In [1]:
import torch
import torch.nn as nn
import scipy.stats as stats
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import optim
import matplotlib.pyplot as plt

In [2]:
from icnnet import ICNNet
from mydataset import MyDataset, get_gaussian_dataset, get_gaussian_transport_dataset
from toy_data_dataloader_gaussian import generate_gaussian_dataset, get_dataset, generate_dataset
from train_picnn import PICNNtrain
from train_wasserstein import train_wasserstein
from train_makkuva import train_makkuva, train_makkuva_epoch
from visualization import plot_transport
from gaussian_transport import get_gaussian_transport

In [3]:
# Enable IPython's autoreload so edits to the project modules imported above
# (icnnet, train_picnn, train_makkuva, visualization, ...) are picked up before
# each cell executes, without restarting the kernel. Mode 2 reloads all modules.
%load_ext autoreload
%autoreload 2

## __Generate dataset__


In [4]:
# Build the paired toy dataset and its Gaussian-derived views.
# Earlier variants, kept for reference:
#   get_dataset(d=2, r=100, N=500)                   (valou)
#   generate_gaussian_dataset(d=2, r=400, N=10000)   (thomas)
dataset = generate_dataset(N=500, r=1000, d=2)
gaussian_dataset = get_gaussian_dataset(dataset)
gaussian_transport_dataset = get_gaussian_transport_dataset(gaussian_dataset)

In [5]:
def get_mean(batch):
    """Average of the per-sample means of a batch.

    ``batch`` is first reduced over dim 1 (the points within each sample);
    the resulting per-sample means are then averaged over dim 0.
    """
    per_sample_means = batch.mean(dim=1)
    return per_sample_means.mean(dim=0)

def get_covariance(batch):
    """Average unbiased sample covariance over a batch of point clouds.

    Each entry along dim 0 is centred on its own mean (taken over dim 1).
    Its covariance is formed with the ``n - 1`` (Bessel) denominator, where
    ``n = batch.size(1)``. The per-entry covariance matrices are then averaged
    over dim 0. ``transpose(1, 2)`` means ``batch`` must be 3-D.
    """
    n_points = batch.size(1)
    centred = batch - batch.mean(dim=1, keepdim=True)
    per_sample_cov = centred.transpose(1, 2) @ centred / (n_points - 1)
    return per_sample_cov.mean(dim=0)

# Pooled Gaussian statistics of the source samples (dataset.X) and the target
# samples (dataset.Y). They parameterise the closed-form Gaussian transport that
# PICNNtrain uses as the initialisation target for f and g.
mean1 = get_mean(dataset.X)
cov1 = get_covariance(dataset.X)
mean2 = get_mean(dataset.Y)
cov2 = get_covariance(dataset.Y)

def init_z_f(x):
    """Initialisation target for f: Gaussian transport from (mean1, cov1) to (mean2, cov2).

    Presumably the closed-form Gaussian OT quantity from N(m1, cov1) to
    N(m2, cov2); see ``gaussian_transport.get_gaussian_transport`` to confirm.
    A simpler alternative target is ``0.5 * ||x||^2``.
    """
    return get_gaussian_transport(u=x, cov1=cov1, cov2=cov2, m1=mean1, m2=mean2)

def init_z_g(x):
    """Initialisation target for g: the reverse direction, (mean2, cov2) -> (mean1, cov1).

    Every source/target argument is swapped relative to ``init_z_f``.
    Previously only the means were swapped and ``cov2`` was passed for BOTH
    covariances. That gave g an inconsistent target: the target mean was taken
    from X, but the target covariance was taken from Y.
    """
    return get_gaussian_transport(u=x, cov1=cov2, cov2=cov1, m1=mean2, m2=mean1)

## __Initialization__

### __PICNN training__

In [6]:
input_size = 2
# Input layer, four hidden layers of width 64, and a single scalar output.
layer_sizes = [input_size] + [64] * 4 + [1]
n_layers = len(layer_sizes)

In [7]:
import torch.nn.functional as F

# Context width for the PICNN: one entry (12) per entry of `layer_sizes`.
context_layer_sizes = [12 for _ in range(n_layers)]

In [8]:
# Two PICNNs with identical architecture and 'TR' initialisation; f and g are
# warm-started separately in the next cell.
model_init_f, model_init_g = (
    ICNNet(layer_sizes=layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne='TR')
    for _ in range(2)
)

In [9]:
n_epoch, lr = 200, 0.001

# --- f: fit to the Gaussian-transport dataset built from gaussian_dataset ---
print('training f')
gaussian_transport_dataloader = DataLoader(
    gaussian_transport_dataset, batch_size=250, shuffle=True
)
# NOTE(review): the saved log for this run shows the loss reaching ~1.0 at
# epoch 7 and then climbing steadily back to ~15.7 by epoch 26. Consider a
# lower lr or early stopping.
PICNNtrain(model_init_f, gaussian_transport_dataloader, init_z_f,
           epochs=n_epoch * 2, lr=lr)

# --- g: same procedure on the dataset rebuilt as (Y, C, X), i.e. X and Y swapped ---
print('training g')
reversed_gaussian_dataset = MyDataset(
    gaussian_dataset.Y, gaussian_dataset.C, gaussian_dataset.X
)
gaussian_transport_dataset_reversed = get_gaussian_transport_dataset(reversed_gaussian_dataset)
gaussian_transport_dataloader_reversed = DataLoader(
    gaussian_transport_dataset_reversed, batch_size=250, shuffle=True
)
PICNNtrain(model_init_g, gaussian_transport_dataloader_reversed, init_z_g,
           epochs=n_epoch, lr=lr)

training f
Epoch 1/400 Loss: 10620.298828125
Epoch 2/400 Loss: 2087.87353515625
Epoch 3/400 Loss: 446.8060302734375
Epoch 4/400 Loss: 100.56060409545898
Epoch 5/400 Loss: 20.45000648498535
Epoch 6/400 Loss: 2.7931596636772156
Epoch 7/400 Loss: 1.0180346965789795
Epoch 8/400 Loss: 3.0920804738998413
Epoch 9/400 Loss: 5.649677515029907
Epoch 10/400 Loss: 7.882863759994507
Epoch 11/400 Loss: 9.686431884765625
Epoch 12/400 Loss: 11.0792818069458
Epoch 13/400 Loss: 12.162493228912354
Epoch 14/400 Loss: 12.994392395019531
Epoch 15/400 Loss: 13.633917808532715
Epoch 16/400 Loss: 14.130440711975098
Epoch 17/400 Loss: 14.515892028808594
Epoch 18/400 Loss: 14.818449020385742
Epoch 19/400 Loss: 15.053760051727295
Epoch 20/400 Loss: 15.238939762115479
Epoch 21/400 Loss: 15.382769107818604
Epoch 22/400 Loss: 15.493432521820068
Epoch 23/400 Loss: 15.578864574432373
Epoch 24/400 Loss: 15.643084526062012
Epoch 25/400 Loss: 15.690220355987549
Epoch 26/400 Loss: 15.723490238189697
Epoch 27/400 Loss: 15.

In [None]:
# References to the PICNN-initialised weights used to seed the Makkuva networks.
# state_dict() shares storage with the models; model_init_f/g are not trained
# further in this notebook, so these stay valid.
state_dict_init_f, state_dict_init_g = model_init_f.state_dict(), model_init_g.state_dict()

###### Dorseuil

In [None]:
# print('training f')
# gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=250, shuffle=True)
# train_wasserstein(model_init_f, gaussian_transport_dataloader, lr=0.1, epochs=10, init_z = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2)

In [None]:
# X, Y, C = gaussian_dataset.X, gaussian_dataset.Y, gaussian_dataset.C
# #Calcul de la dérivée du PICNN

# for test in range(20):
#     x_i = X[test, :, :]
#     y_i = Y[test, :, :]
#     c_i = C[test, :, :]

#     locs = c_i[:,0]
#     #print(locs)

#     scales = c_i[:,1]
#     #print(scales)  


#     y_i.requires_grad_(True)
#     x_i.requires_grad_(True)
#     #c_i.requires_grad_(True)    

#     output_model_f = model_init_f(x_i, c_i)
#     grad_model_f = torch.autograd.grad(outputs=output_model_f, inputs=x_i, grad_outputs=torch.ones_like(output_model_f), create_graph=True)[0].detach().numpy()

#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True)
#     plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True)
#     plt.hist(grad_model_f[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()


#     output_model_g = model_init_g(y_i, c_i)
#     grad_model_g = torch.autograd.grad(outputs=output_model_g, inputs=y_i, grad_outputs=torch.ones_like(output_model_g), create_graph=True)[0].detach().numpy()
#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True, color = 'red')
#     #plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True, color = 'blue')
#     plt.hist(grad_model_g[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     #plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()

## __Makkuva__

In [None]:
# state_dict_init_f = torch.load('trained_models/training19/models/model_f_0.pth')
# state_dict_init_g = torch.load('trained_models/training19/models/model_g_0.pth')

In [None]:
# Fresh f/g networks for the Makkuva stage. They are built exactly like the
# PICNN init models, then loaded with the PICNN-initialised weights.
ICNNf, ICNNg = [
    ICNNet(layer_sizes=layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne='TR')
    for _ in range(2)
]

ICNNf.load_state_dict(state_dict_init_f)
ICNNg.load_state_dict(state_dict_init_g)


In [None]:
# Visualisation settings forwarded to plot_transport: `n_points` and the index
# `test` of the dataset entry to plot (presumably a sample index; confirm in
# visualization.plot_transport).
n_points = 2000
test = 21

import os

# Every artefact of this run lives under one directory. Change `run_dir` to
# start a new run without overwriting an old one.
run_dir = 'trained_models/training19'
models_dir = run_dir + '/models'
plots_dir = run_dir + '/plots'

# Filename *prefixes*: later cells append the epoch number and the extension
# (e.g. model_f_0.pth, model_f_0.png).
filepath_pth_f = models_dir + '/model_f_'
filepath_pth_g = models_dir + '/model_g_'
filepath_plt_f = plots_dir + '/model_f_'
filepath_plt_g = plots_dir + '/model_g_'

# Create only the parent directories. The previous code called makedirs on the
# prefixes themselves, which also created stray empty `model_f_/` and
# `model_g_/` directories next to the checkpoints and plots.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

In [None]:
# Epoch-0 checkpoint: the PICNN-initialised weights before any Makkuva step.
filename_pth_f, filename_pth_g = (prefix + '0.pth' for prefix in (filepath_pth_f, filepath_pth_g))
for _net, _path in ((ICNNf, filename_pth_f), (ICNNg, filename_pth_g)):
    torch.save(_net.state_dict(), _path)

In [None]:
# Plot the transport of the initial (epoch-0) networks for sample `test`.
filename_plt_f, filename_plt_g = (prefix + '0.png' for prefix in (filepath_plt_f, filepath_plt_g))
plot_transport(
    dataset, test, ICNNf, ICNNg,
    init_z_f=init_z_f, init_z_g=init_z_g,
    filename_f=filename_plt_f, filename_g=filename_plt_g,
    n_points=n_points,
)

In [None]:
# NOTE(review): batch_size=501 is an unusual value; confirm it is intended.
dataloader = DataLoader(dataset, batch_size=501, shuffle=True)

# Per-epoch mean losses returned by train_makkuva_epoch.
loss_f, loss_g = [], []

# Detached copies of the current weights. train_makkuva_epoch takes them and
# returns updated ones each epoch (presumably the anchor of the proximal term
# weighted by lambda_proximal; confirm in train_makkuva).
prev_param_f = [p.clone().detach() for p in ICNNf.parameters()]
prev_param_g = [p.clone().detach() for p in ICNNg.parameters()]

for epoch in range(1, 100 + 1):
    print('epoch :', epoch)
    mean_loss_f, mean_loss_g, prev_param_f, prev_param_g = train_makkuva_epoch(
        ICNNf=ICNNf, ICNNg=ICNNg,
        prev_param_f=prev_param_f, prev_param_g=prev_param_g,
        dataloader=dataloader,
        init_z_f=init_z_f, init_z_g=init_z_g,
        lr=0.0005,
        train_freq_f=1, train_freq_g=10,
        regularize_f=True, regularize_g=True,
        lambda_proximal=0.0001,
    )
    loss_f.append(mean_loss_f)
    loss_g.append(mean_loss_g)

    # Per-epoch plots and checkpoints, numbered by epoch.
    suffix = str(epoch)
    filename_plt_f = filepath_plt_f + suffix + '.png'
    filename_plt_g = filepath_plt_g + suffix + '.png'
    plot_transport(
        dataset, test, ICNNf, ICNNg,
        init_z_f=init_z_f, init_z_g=init_z_g,
        filename_f=filename_plt_f, filename_g=filename_plt_g,
        n_points=n_points,
    )

    filename_pth_f = filepath_pth_f + suffix + '.pth'
    filename_pth_g = filepath_pth_g + suffix + '.pth'
    torch.save(ICNNf.state_dict(), filename_pth_f)
    torch.save(ICNNg.state_dict(), filename_pth_g)

In [None]:
# Re-plot every saved Makkuva checkpoint for additional dataset entries.
l_test = [12]
for epoch in range(1, 101):
    print('epoch :', epoch)

    filename_pth_f = filepath_pth_f + str(epoch) + '.pth'
    filename_pth_g = filepath_pth_g + str(epoch) + '.pth'

    # Load each network's OWN checkpoint. The previous code loaded the f
    # checkpoint into ICNNg as well. Both nets are built with identical
    # arguments, so the load succeeded silently and every g plot showed f's
    # weights. The local names (ckpt_*) also avoid overwriting
    # state_dict_init_f/g, which hold the PICNN initialisation used to seed
    # ICNNf/ICNNg. These checkpoints are produced locally by this notebook.
    ckpt_f = torch.load(filename_pth_f)
    ckpt_g = torch.load(filename_pth_g)
    ICNNf.load_state_dict(ckpt_f)
    ICNNg.load_state_dict(ckpt_g)

    # Use a dedicated loop variable so the global `test` used by other cells
    # is not rebound.
    for test_idx in l_test:
        filename_plt_f = filepath_plt_f + 'test_' + str(test_idx) + '_' + str(epoch) + '.png'
        filename_plt_g = filepath_plt_g + 'test_' + str(test_idx) + '_' + str(epoch) + '.png'
        plot_transport(
            dataset, test_idx, ICNNf, ICNNg,
            init_z_f=init_z_f, init_z_g=init_z_g,
            filename_f=filename_plt_f, filename_g=filename_plt_g,
            n_points=n_points,
        )

In [None]:
# for ele in loss_f:
#     print(ele)

# print('stop')

# for ele in loss_g:
#     print(ele)