In [52]:
import torch
import torch.nn as nn
import scipy.stats as stats
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import optim
import matplotlib.pyplot as plt

In [53]:
from icnnet import ICNNet
from mydataset import MyDataset, get_gaussian_dataset, get_gaussian_transport_dataset
from toy_data_dataloader_gaussian import generate_gaussian_dataset, get_dataset, generate_dataset
from train_picnn import PICNNtrain
from train_wasserstein import train_wasserstein
from train_makkuva import train_makkuva, train_makkuva_epoch
from visualization import plot_transport
from gaussian_transport import get_gaussian_transport

In [54]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## __Generate dataset__


In [55]:
#dataset = get_dataset(d=2, r=100, N=500) #valou
#dataset = generate_gaussian_dataset(d=2, r=400, N=10000) #thomas
dataset = generate_dataset(d=2, r=1000, N=50)
gaussian_dataset = get_gaussian_dataset(dataset)
gaussian_transport_dataset = get_gaussian_transport_dataset(gaussian_dataset)

In [56]:
def get_mean(batch):
    means = torch.mean(batch, dim=1)
    average_mean = torch.mean(means, dim=0)
    return(average_mean)

def get_covariance(batch):
    n = batch.size(1) - 1
    mean = torch.mean(batch, dim=1, keepdim=True)
    batch = batch - mean  # Centering the data
    cov = torch.matmul(batch.transpose(1, 2), batch) / n
    return(torch.mean(cov, dim=0))

mean1 = get_mean(dataset.X)
cov1 = get_covariance(dataset.X)
mean2 = get_mean(dataset.Y)
cov2 = get_covariance(dataset.Y)

def init_z_f(x):
    #return (1/2) * torch.norm(x, dim=-1, keepdim=True)**2
    return(get_gaussian_transport(u=x, cov1 = cov1, cov2 = cov2, m1=mean1, m2=mean2))

def init_z_g(x) :
    #return (1/2) * torch.norm(x, dim=-1, keepdim=True)**2
    return(get_gaussian_transport(u=x, cov1 = cov2, cov2 = cov2, m1=mean2, m2=mean1))

## __Initialization__

### __PICNN training__

In [57]:
input_size = 2
layer_sizes = [input_size,64, 64, 64, 64, 64, 64, 1]
n_layers = len(layer_sizes)

In [58]:
import torch.nn.functional as F

# def get_embedding(C, c):
#     scalar_product = torch.matmul(c.float(), C.t().float())
#     embedding = F.softmax(scalar_product, dim=1)
#     return(embedding)

context_layer_sizes = [12] * n_layers

In [59]:
model_init_f = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
model_init_g = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

In [62]:
n_epoch = 100
lr = 0.001

print('training f')
gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=250, shuffle=True)
PICNNtrain(model_init_f, gaussian_transport_dataloader, init_z_f, lr=lr, epochs=n_epoch)
#PICNNtrain(model_init_f, gaussian_transport_dataloader, lr=0.0001, epochs=1, init_z = lambda x: x)

print('training g')
reversed_gaussian_dataset = MyDataset(gaussian_dataset.Y, gaussian_dataset.C, gaussian_dataset.X)
gaussian_transport_dataset_reversed = get_gaussian_transport_dataset(reversed_gaussian_dataset)
gaussian_transport_dataloader_reversed = DataLoader(gaussian_transport_dataset_reversed, batch_size=250, shuffle=True)
#PICNNtrain(model_init_g, gaussian_transport_dataloader_reversed, lr=0.0001, epochs=25, init_z = lambda x: (1/2) * torch.norm(-x, dim=-1, keepdim=True)**2)
PICNNtrain(model_init_g, gaussian_transport_dataloader_reversed, init_z_g, lr=lr, epochs=n_epoch)

training f
Epoch 1/100 Loss: 0.11959782242774963
Epoch 2/100 Loss: 104.4300537109375
Epoch 3/100 Loss: 4.114497184753418
Epoch 4/100 Loss: 6.26355504989624
Epoch 5/100 Loss: 22.026565551757812
Epoch 6/100 Loss: 33.934425354003906
Epoch 7/100 Loss: 41.351898193359375
Epoch 8/100 Loss: 45.47114562988281
Epoch 9/100 Loss: 47.20033645629883
Epoch 10/100 Loss: 47.06255340576172
Epoch 11/100 Loss: 45.302486419677734
Epoch 12/100 Loss: 41.990604400634766
Epoch 13/100 Loss: 37.09465408325195
Epoch 14/100 Loss: 30.568370819091797
Epoch 15/100 Loss: 22.4964599609375
Epoch 16/100 Loss: 13.364995956420898
Epoch 17/100 Loss: 4.845413684844971
Epoch 18/100 Loss: 0.23548443615436554
Epoch 19/100 Loss: 3.6036977767944336
Epoch 20/100 Loss: 13.103581428527832
Epoch 21/100 Loss: 17.63178062438965
Epoch 22/100 Loss: 12.13708782196045
Epoch 23/100 Loss: 4.162672519683838
Epoch 24/100 Loss: 0.33997035026550293
Epoch 25/100 Loss: 0.9143361449241638
Epoch 26/100 Loss: 3.4462215900421143
Epoch 27/100 Loss: 5.

In [63]:
state_dict_init_f = model_init_f.state_dict()
state_dict_init_g = model_init_g.state_dict()

###### Dorseuil

In [64]:
# print('training f')
# gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=250, shuffle=True)
# train_wasserstein(model_init_f, gaussian_transport_dataloader, lr=0.1, epochs=10, init_z = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2)

In [65]:
# X, Y, C = gaussian_dataset.X, gaussian_dataset.Y, gaussian_dataset.C
# #Calcul de la dérivée du PICNN

# for test in range(20):
#     x_i = X[test, :, :]
#     y_i = Y[test, :, :]
#     c_i = C[test, :, :]

#     locs = c_i[:,0]
#     #print(locs)

#     scales = c_i[:,1]
#     #print(scales)  


#     y_i.requires_grad_(True)
#     x_i.requires_grad_(True)
#     #c_i.requires_grad_(True)    

#     output_model_f = model_init_f(x_i, c_i)
#     grad_model_f = torch.autograd.grad(outputs=output_model_f, inputs=x_i, grad_outputs=torch.ones_like(output_model_f), create_graph=True)[0].detach().numpy()

#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True)
#     plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True)
#     plt.hist(grad_model_f[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()


#     output_model_g = model_init_g(y_i, c_i)
#     grad_model_g = torch.autograd.grad(outputs=output_model_g, inputs=y_i, grad_outputs=torch.ones_like(output_model_g), create_graph=True)[0].detach().numpy()
#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True, color = 'red')
#     #plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True, color = 'blue')
#     plt.hist(grad_model_g[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     #plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()

## __Makkuva__

In [66]:
# state_dict_init_f = torch.load('trained_models/training16/models/model_f_0.pth')
# state_dict_init_g = torch.load('trained_models/training16/models/model_g_0.pth')

In [95]:
ICNNf = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
ICNNg = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

old_ICNNf = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
old_ICNNg = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

# Load the state dictionary into ICNNf and ICNNg
ICNNf.load_state_dict(state_dict_init_f)
ICNNg.load_state_dict(state_dict_init_g)

old_ICNNf.load_state_dict(state_dict_init_f)
old_ICNNg.load_state_dict(state_dict_init_g)

l_ICNNf = [ICNNf, old_ICNNf]
l_ICNNg = [ICNNg, old_ICNNg]

In [96]:
n_points = 2000
test = 21

filepath_pth_f = 'trained_models/training16/models/model_f_'
filepath_pth_g = 'trained_models/training16/models/model_g_'

filepath_plt_f = 'trained_models/training16/plots/model_f_'
filepath_plt_g = 'trained_models/training16/plots/model_g_'

import os
os.makedirs(filepath_pth_f, exist_ok=True)
os.makedirs(filepath_pth_g, exist_ok=True)

os.makedirs(filepath_plt_f, exist_ok=True)
os.makedirs(filepath_plt_g, exist_ok=True)

In [97]:
filename_pth_f = filepath_pth_f + str(0) + '.pth'
filename_pth_g = filepath_pth_g + str(0) + '.pth'
torch.save(ICNNf.state_dict(), filename_pth_f)
torch.save(ICNNg.state_dict(), filename_pth_g)

In [98]:
filename_plt_f = filepath_plt_f + str(0) + '.png'
filename_plt_g = filepath_plt_g + str(0) + '.png'
plot_transport(dataset, test, ICNNf, ICNNg, init_z_f = init_z_f, init_z_g = init_z_g, filename_f = filename_plt_f, filename_g = filename_plt_g, n_points=n_points)

<Figure size 640x480 with 0 Axes>

In [99]:
dataloader = DataLoader(dataset, batch_size=501, shuffle=True)

loss_f = list()
loss_g = list()

prev_param_f = [param.clone().detach() for param in ICNNf.parameters()]
prev_param_g = [param.clone().detach() for param in ICNNg.parameters()]

for epoch in range(1, 501) :
    print('epoch :', epoch)
    mean_loss_f, mean_loss_g, prev_param_f, prev_param_g = train_makkuva_epoch(ICNNf=ICNNf, ICNNg=ICNNg, prev_param_f=prev_param_f, prev_param_g=prev_param_g, dataloader = dataloader, init_z_f = init_z_f, init_z_g = init_z_g, lr=0.01, train_freq_g=10, train_freq_f=1, regularize_f = True, regularize_g = True, lambda_proximal=0.001)
    #mean_loss_f, mean_loss_g = train_makkuva_epoch(ICNNf, ICNNg, None, None, dataloader, init_z_f = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2, init_z_g = lambda x: (1/2) * torch.norm(-x, dim=-1, keepdim=True)**2, lr=0.0001, train_freq_g=10, train_freq_f=1, gaussian_transport=False)

    loss_f.append(mean_loss_f)
    loss_g.append(mean_loss_g)

    filename_pth_f = filepath_pth_f + str(epoch) + '.pth'
    filename_pth_g = filepath_pth_g + str(epoch) + '.pth'
    torch.save(l_ICNNf[epoch%2].state_dict(), filename_pth_f)
    torch.save(l_ICNNg[epoch%2].state_dict(), filename_pth_g)

    filename_plt_f = filepath_plt_f + str(epoch) + '.png'
    filename_plt_g = filepath_plt_g + str(epoch) + '.png'
    plot_transport(dataset, test, ICNNf, ICNNg, init_z_f=init_z_f, init_z_g=init_z_g, filename_f = filename_plt_f, filename_g = filename_plt_g, n_points=n_points)  

epoch : 1
R_f 6.080854655010626e-05
proximal_term 10.694985389709473
loss_g: 20.003807067871094, loss_f: 10.744110107421875
epoch : 2
R_f 0.00017087519518099725
proximal_term 10.706409454345703
loss_g: 17.71279525756836, loss_f: 7.049393177032471
epoch : 3
name layers_z.0.weight
name layers_z.1.weight
name layers_z.2.weight
name layers_z.3.weight
name layers_z.4.weight
name layers_z.5.weight
name layers_z.6.weight
name layers_zu.0.0.weight
name layers_zu.0.0.bias
name layers_zu.1.0.weight
name layers_zu.1.0.bias
name layers_zu.2.0.weight
name layers_zu.2.0.bias
name layers_zu.3.0.weight
name layers_zu.3.0.bias
name layers_zu.4.0.weight
name layers_zu.4.0.bias
name layers_zu.5.0.weight
name layers_zu.5.0.bias
name layers_zu.6.0.weight
name layers_zu.6.0.bias
name layers_x.0.weight
name layers_x.1.weight
name layers_x.2.weight
name layers_x.3.weight
name layers_x.4.weight
name layers_x.5.weight
name layers_x.6.weight
name layers_xu.0.weight
name layers_xu.0.bias
name layers_xu.1.weight
n

UnboundLocalError: local variable 'param_simulated' referenced before assignment

<Figure size 640x480 with 0 Axes>

In [None]:
# for ele in loss_f:
#     print(ele)

# print('stop')

# for ele in loss_g:
#     print(ele)

In [None]:
ICNNf.load_state_dict(torch.load(filepath_pth_f + '0.pth'))
ICNNg.load_state_dict(torch.load(filepath_pth_g + '0.pth'))

filename_plt_f = 'trained_models/training16/plots/model_f_test'
filename_plt_g = 'trained_models/training16/plots/model_g_test'

plot_transport(dataset, test, ICNNf, ICNNg, filename_f = filename_plt_f, filename_g = filename_plt_g)