In [1]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader

In [2]:
from icnnet import ICNNet
from mydataset import MyDataset, get_gaussian_dataset, get_gaussian_transport_dataset
from toy_data_dataloader_gaussian import generate_gaussian_dataset, get_dataset, generate_dataset
from train_picnn import PICNNtrain
from train_wasserstein import train_wasserstein
from train_makkuva import train_makkuva, train_makkuva_epoch
from visualization import plot_transport

In [3]:
# Enable IPython's autoreload extension in mode 2: all imported modules are
# re-imported before each cell executes, so edits to the project's local .py
# files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## __Generate dataset__


In [4]:
# Seed the global NumPy and PyTorch generators before anything random happens so
# that the generated data, the network initialisation and the DataLoader shuffling
# are repeatable across "Restart & Run All".
# NOTE(review): assumes the project helpers draw from the global NumPy / PyTorch
# RNGs -- confirm in toy_data_dataloader_gaussian.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Alternative dataset configurations kept for reference (tags are the authors' own).
#dataset = get_dataset(d=2, r=100, N=500) #valou
#dataset = generate_gaussian_dataset(d=2, r=400, N=10000) #thomas
dataset = generate_dataset(d=2, r=1000, N=50)

# Derived datasets: `gaussian_dataset` is built from `dataset`, and
# `gaussian_transport_dataset` from `gaussian_dataset`; the latter is what the
# PICNN pre-training DataLoader consumes.
gaussian_dataset = get_gaussian_dataset(dataset)
gaussian_transport_dataset = get_gaussian_transport_dataset(gaussian_dataset)

  u = torch.tensor(u)


## __Initialization__

### __PICNN training__

In [5]:
# Layer widths handed to ICNNet: the input width, three hidden layers of 64 units,
# and a single scalar output.
input_size = 2
hidden_widths = [64] * 3
layer_sizes = [input_size, *hidden_widths, 1]

# Number of entries in `layer_sizes` (input and output included).
n_layers = len(layer_sizes)

In [6]:
import torch.nn.functional as F

def get_embedding(C, c):
    """Softmax-normalised scalar products between the rows of `c` and the rows of `C`.

    Both inputs are cast to float32, the product `c @ C^T` is formed, and a
    softmax is applied along dimension 1 of the result.

    Args:
        C: tensor whose rows are compared against (transposed with `.t()`, so at
           most 2-D).
        c: tensor whose rows are embedded; its last dimension must match the
           last dimension of `C`.

    Returns:
        Float tensor of softmax weights; for 2-D inputs each row sums to 1.
    """
    logits = c.float() @ C.float().t()
    return F.softmax(logits, dim=1)

# One context width (12) per entry of `layer_sizes`; passed to ICNNet as `context_layer_sizes`.
context_layer_sizes = [12 for _ in range(n_layers)]

In [7]:
# Two identically configured networks (built in the order f, then g) that the
# PICNN pre-training below fits separately.
model_init_f, model_init_g = (
    ICNNet(layer_sizes=layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne='TR')
    for _ in range(2)
)

In [10]:
def _picnn_init_z(x):
    """Return (1/2) * ||x||^2 taken over the last axis, which is kept with size 1."""
    return 0.5 * torch.norm(x, dim=-1, keepdim=True) ** 2


# Hyper-parameters shared by the two pre-training runs.
# Earlier experiments, kept for reference: `init_z = lambda x: x` with epochs=1 for f,
# and the same half squared norm applied to `-x` with epochs=25 for g.
picnn_settings = dict(lr=0.0001, epochs=250, init_z=_picnn_init_z)

print('training f')
gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, shuffle=True, batch_size=250)
PICNNtrain(model_init_f, gaussian_transport_dataloader, **picnn_settings)

print('training g')
# Same data with the X and Y attributes of `gaussian_dataset` exchanged
# (presumably MyDataset takes (X, C, Y) -- confirm in mydataset), so that g is
# fitted on the reverse direction.
reversed_gaussian_dataset = MyDataset(
    gaussian_dataset.Y,
    gaussian_dataset.C,
    gaussian_dataset.X,
)
gaussian_transport_dataset_reversed = get_gaussian_transport_dataset(reversed_gaussian_dataset)
gaussian_transport_dataloader_reversed = DataLoader(
    gaussian_transport_dataset_reversed, shuffle=True, batch_size=250
)
PICNNtrain(model_init_g, gaussian_transport_dataloader_reversed, **picnn_settings)

training f
Epoch 1/250 Loss: 0.057322945445775986
Epoch 2/250 Loss: 0.056910786777734756
Epoch 3/250 Loss: 0.056491900235414505
Epoch 4/250 Loss: 0.05607828125357628
Epoch 5/250 Loss: 0.05567314475774765
Epoch 6/250 Loss: 0.05527341738343239
Epoch 7/250 Loss: 0.054877109825611115
Epoch 8/250 Loss: 0.054483067244291306
Epoch 9/250 Loss: 0.054088570177555084
Epoch 10/250 Loss: 0.053700536489486694
Epoch 11/250 Loss: 0.05331666022539139
Epoch 12/250 Loss: 0.05293552204966545
Epoch 13/250 Loss: 0.052555982023477554
Epoch 14/250 Loss: 0.05217800661921501
Epoch 15/250 Loss: 0.051803626120090485
Epoch 16/250 Loss: 0.05143744498491287
Epoch 17/250 Loss: 0.05107295885682106
Epoch 18/250 Loss: 0.05070872604846954
Epoch 19/250 Loss: 0.05034996196627617
Epoch 20/250 Loss: 0.0499957799911499
Epoch 21/250 Loss: 0.04964367300271988
Epoch 22/250 Loss: 0.049294330179691315
Epoch 23/250 Loss: 0.048947401344776154
Epoch 24/250 Loss: 0.0486033633351326
Epoch 25/250 Loss: 0.0482613667845726
Epoch 26/250 Lo

In [11]:
# Snapshot of the pre-trained weights used to initialise the Makkuva networks.
# `state_dict()` returns references to the live parameter tensors, so without a
# deep copy these "init" snapshots would silently change whenever model_init_f /
# model_init_g are trained again (e.g. when the pre-training cell is re-run).
state_dict_init_f = copy.deepcopy(model_init_f.state_dict())
state_dict_init_g = copy.deepcopy(model_init_g.state_dict())

In [None]:
# print('training f')
# gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=250, shuffle=True)
# train_wasserstein(model_init_f, gaussian_transport_dataloader, lr=0.1, epochs=10, init_z = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2)

In [None]:
# X, Y, C = gaussian_dataset.X, gaussian_dataset.Y, gaussian_dataset.C
# #Compute the derivative of the PICNN

# for test in range(20):
#     x_i = X[test, :, :]
#     y_i = Y[test, :, :]
#     c_i = C[test, :, :]

#     locs = c_i[:,0]
#     #print(locs)

#     scales = c_i[:,1]
#     #print(scales)  


#     y_i.requires_grad_(True)
#     x_i.requires_grad_(True)
#     #c_i.requires_grad_(True)    

#     output_model_f = model_init_f(x_i, c_i)
#     grad_model_f = torch.autograd.grad(outputs=output_model_f, inputs=x_i, grad_outputs=torch.ones_like(output_model_f), create_graph=True)[0].detach().numpy()

#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True)
#     plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True)
#     plt.hist(grad_model_f[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()


#     output_model_g = model_init_g(y_i, c_i)
#     grad_model_g = torch.autograd.grad(outputs=output_model_g, inputs=y_i, grad_outputs=torch.ones_like(output_model_g), create_graph=True)[0].detach().numpy()
#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True, color = 'red')
#     #plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True, color = 'blue')
#     plt.hist(grad_model_g[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     #plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()

## __Makkuva__

In [None]:
# state_dict_init_f = torch.load('trained_models/training9/models/model_f_0.pth')
# state_dict_init_g = torch.load('trained_models/training9/models/model_g_0.pth')

In [43]:
def _new_icnn():
    """Build one ICNNet with the notebook's layer sizes, context sizes and 'TR' initialisation."""
    return ICNNet(layer_sizes=layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne='TR')


# Two copies of each potential: the "current" one and an "old" one.
ICNNf, ICNNg = _new_icnn(), _new_icnn()
old_ICNNf, old_ICNNg = _new_icnn(), _new_icnn()

# Every copy starts from the pre-trained weights captured in
# state_dict_init_f / state_dict_init_g.
for network in (ICNNf, old_ICNNf):
    network.load_state_dict(state_dict_init_f)
for network in (ICNNg, old_ICNNg):
    network.load_state_dict(state_dict_init_g)

# Index 0 is the "current" network and index 1 the "old" one; the training loop
# picks between them with `epoch % 2`.
l_ICNNf = [ICNNf, old_ICNNf]
l_ICNNg = [ICNNg, old_ICNNg]

In [44]:
# Arguments forwarded to plot_transport.
# NOTE(review): `test` is presumably the index of the sample/context to
# visualise and `n_points` the number of plotted points -- confirm in visualization.
n_points = 2000
test = 16

# Path prefixes for checkpoints (.pth) and figures (.png); the epoch number and
# the extension are appended where the files are written.
filepath_pth_f = 'trained_models/training9/models/model_f_'
filepath_pth_g = 'trained_models/training9/models/model_g_'

filepath_plt_f = 'trained_models/training9/plots/model_f_'
filepath_plt_g = 'trained_models/training9/plots/model_g_'

# Create the output folders up front: torch.save does not create missing parent
# directories, so on a fresh checkout the first checkpoint write would fail.
for _prefix in (filepath_pth_f, filepath_pth_g, filepath_plt_f, filepath_plt_g):
    os.makedirs(os.path.dirname(_prefix), exist_ok=True)

In [45]:
# Store the freshly initialised networks as checkpoint number 0.
filename_pth_f = f'{filepath_pth_f}0.pth'
filename_pth_g = f'{filepath_pth_g}0.pth'
for network, target in ((ICNNf, filename_pth_f), (ICNNg, filename_pth_g)):
    torch.save(network.state_dict(), target)

In [46]:
# Figure for checkpoint number 0, i.e. the networks before any Makkuva epoch.
filename_plt_f, filename_plt_g = (
    f'{prefix}0.png' for prefix in (filepath_plt_f, filepath_plt_g)
)
plot_transport(
    dataset,
    test,
    ICNNf,
    ICNNg,
    filename_f=filename_plt_f,
    filename_g=filename_plt_g,
    n_points=n_points,
)

<Figure size 1280x960 with 0 Axes>

In [47]:
# Makkuva training: 200 epochs, with a checkpoint and a transport plot written
# after every epoch.
dataloader = DataLoader(dataset, batch_size=500, shuffle=True)

# Per-epoch mean losses returned by train_makkuva_epoch.
loss_f = list()
loss_g = list()


def _makkuva_init_z(x):
    """Return (1/2) * ||x||^2 taken over the last axis, which is kept with size 1."""
    return (1/2) * torch.norm(x, dim=-1, keepdim=True)**2


for epoch in range(1, 201) :
    print('epoch :', epoch, end=('\r'))

    # l_ICNNf / l_ICNNg each hold two networks; which of the two is passed first
    # flips with the parity of the epoch.
    # NOTE(review): presumably the first pair is the one being trained and the
    # second pair the previous-epoch reference -- confirm in train_makkuva.
    current = epoch % 2
    other = 1 - current

    mean_loss_f, mean_loss_g = train_makkuva_epoch(
        l_ICNNf[current], l_ICNNg[current], l_ICNNf[other], l_ICNNg[other], dataloader,
        init_z_f=_makkuva_init_z, init_z_g=_makkuva_init_z,
        lr=0.001, train_freq_g=10, train_freq_f=1,
        gaussian_transport=False, regularize_f=False, regularize_g=True,
    )
    # Alternative kept for reference (no "old" networks, lr=0.0001, init_z_g on -x):
    #mean_loss_f, mean_loss_g = train_makkuva_epoch(ICNNf, ICNNg, None, None, dataloader, init_z_f = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2, init_z_g = lambda x: (1/2) * torch.norm(-x, dim=-1, keepdim=True)**2, lr=0.0001, train_freq_g=10, train_freq_f=1, gaussian_transport=False)

    loss_f.append(mean_loss_f)
    loss_g.append(mean_loss_g)

    # Checkpoint the pair that was passed first this epoch.
    filename_pth_f = filepath_pth_f + str(epoch) + '.pth'
    filename_pth_g = filepath_pth_g + str(epoch) + '.pth'
    torch.save(l_ICNNf[current].state_dict(), filename_pth_f)
    torch.save(l_ICNNg[current].state_dict(), filename_pth_g)

    filename_plt_f = filepath_plt_f + str(epoch) + '.png'
    filename_plt_g = filepath_plt_g + str(epoch) + '.png'
    plot_transport(dataset, test, l_ICNNf[current], l_ICNNg[current], filename_f = filename_plt_f, filename_g = filename_plt_g, n_points=n_points)

    # plot_transport appears to leave a figure open (a single call already
    # produced "<Figure size 1280x960 with 0 Axes>" in this notebook). Close the
    # leftovers so 200 epochs do not accumulate open figures in memory.
    plt.close('all')

loss_g: 0.7187504172325134, loss_f: -0.12365482747554779
loss_g: 0.7187504172325134, loss_f: -0.12365474551916122
loss_g: 0.9674621820449829, loss_f: 0.3908659815788269
loss_g: 0.9674623012542725, loss_f: 0.3908659517765045
loss_g: 0.6865757703781128, loss_f: -0.22369220852851868
loss_g: 0.6865753531455994, loss_f: -0.22369146347045898
loss_g: 1.0721560716629028, loss_f: -0.006209456827491522
loss_g: 1.072157621383667, loss_f: -0.006239877548068762
loss_g: 0.9000243544578552, loss_f: -0.07690823078155518
loss_g: 0.9001421332359314, loss_f: -0.07452258467674255
loss_g: 1.0317668914794922, loss_f: 0.3227865695953369
loss_g: 1.0298584699630737, loss_f: 0.3249962627887726
loss_g: 0.677795946598053, loss_f: -0.25361230969429016
loss_g: 0.6760386824607849, loss_f: -0.2901286780834198
loss_g: 0.9672154784202576, loss_f: 0.14428408443927765
loss_g: 0.9963656067848206, loss_f: 0.05697343870997429
loss_g: 0.7823339104652405, loss_f: 0.0631287470459938
loss_g: 0.8589330315589905, loss_f: 0.102934

In [None]:
# Print every f loss, then the literal marker 'stop', then every g loss -- one value per line.
for entry in [*loss_f, 'stop', *loss_g]:
    print(entry)

In [None]:
# Restore the epoch-0 checkpoints into ICNNf / ICNNg and plot them again.
for network, prefix in ((ICNNf, filepath_pth_f), (ICNNg, filepath_pth_g)):
    network.load_state_dict(torch.load(prefix + '0.pth'))

# NOTE(review): unlike the other plot_transport calls in this notebook, these
# names carry no '.png' extension and `n_points` is not passed -- confirm that
# both are intended.
filename_plt_f = 'trained_models/training9/plots/model_f_test'
filename_plt_g = 'trained_models/training9/plots/model_g_test'

plot_transport(
    dataset,
    test,
    ICNNf,
    ICNNg,
    filename_f=filename_plt_f,
    filename_g=filename_plt_g,
)