### __Imports__

In [19]:
import torch
import torch.nn as nn
import scipy.stats as stats
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import optim
import matplotlib.pyplot as plt

In [20]:
from icnnet import ICNNet
from mydataset import MyDataset, get_gaussian_dataset, get_gaussian_transport_dataset
from toy_data_dataloader_gaussian import generate_gaussian_dataset, get_dataset, generate_dataset
from train_picnn import PICNNtrain
from train_wasserstein import train_wasserstein
from train_makkuva import train_makkuva, train_makkuva_epoch
from visualization import plot_transport

In [21]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## __Generate dataset__


In [22]:
#dataset = get_dataset(d=2, r=100, N=500) #valou
#dataset = generate_gaussian_dataset(d=2, r=400, N=10000) #thomas
dataset = generate_dataset(d=2, r=200, N=200)
gaussian_dataset = get_gaussian_dataset(dataset)
gaussian_transport_dataset = get_gaussian_transport_dataset(gaussian_dataset)

## __Initialization__

### __PICNN training__

In [23]:
input_size = 2
layer_sizes = [input_size,64,64,64, 1]
n_layers = len(layer_sizes)

In [24]:
import torch.nn.functional as F

def get_embedding(C, c):
    scalar_product = torch.matmul(c.float(), C.t().float())
    embedding = F.softmax(scalar_product, dim=1)
    return(embedding)

context_layer_sizes = [2] * n_layers

In [25]:
model_init_f = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
model_init_g = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

In [26]:
print('training f')
gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=250, shuffle=True)
PICNNtrain(model_init_f, gaussian_transport_dataloader, lr=0.001, epochs=150, init_z = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2)
#PICNNtrain(model_init_f, gaussian_transport_dataloader, lr=0.0001, epochs=1, init_z = lambda x: x)


print('training g')
reversed_gaussian_dataset = MyDataset(gaussian_dataset.Y, gaussian_dataset.C, gaussian_dataset.X)
gaussian_transport_dataset_reversed = get_gaussian_transport_dataset(reversed_gaussian_dataset)
gaussian_transport_dataloader_reversed = DataLoader(gaussian_transport_dataset_reversed, batch_size=250, shuffle=True)
#PICNNtrain(model_init_g, gaussian_transport_dataloader_reversed, lr=0.0001, epochs=25, init_z = lambda x: (1/2) * torch.norm(-x, dim=-1, keepdim=True)**2)
PICNNtrain(model_init_g, gaussian_transport_dataloader_reversed, lr=0.001, epochs=150, init_z = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2)

training f
Epoch 1/150 Loss: 14.007146835327148
Epoch 2/150 Loss: 10.978131294250488
Epoch 3/150 Loss: 11.847502708435059
Epoch 4/150 Loss: 10.56534194946289
Epoch 5/150 Loss: 9.083248138427734
Epoch 6/150 Loss: 8.584258079528809
Epoch 7/150 Loss: 8.428940773010254
Epoch 8/150 Loss: 7.916776657104492
Epoch 9/150 Loss: 7.059391975402832
Epoch 10/150 Loss: 6.232521057128906
Epoch 11/150 Loss: 5.73097562789917
Epoch 12/150 Loss: 5.496456623077393
Epoch 13/150 Loss: 5.206870079040527
Epoch 14/150 Loss: 4.701727867126465
Epoch 15/150 Loss: 4.13549280166626
Epoch 16/150 Loss: 3.7138140201568604
Epoch 17/150 Loss: 3.4637210369110107
Epoch 18/150 Loss: 3.2625577449798584
Epoch 19/150 Loss: 2.9978554248809814
Epoch 20/150 Loss: 2.670083522796631
Epoch 21/150 Loss: 2.364135980606079
Epoch 22/150 Loss: 2.154249906539917
Epoch 23/150 Loss: 2.030212879180908
Epoch 24/150 Loss: 1.9112662076950073
Epoch 25/150 Loss: 1.7414461374282837
Epoch 26/150 Loss: 1.5500085353851318
Epoch 27/150 Loss: 1.3984699

In [27]:
state_dict_init_f = model_init_f.state_dict()
state_dict_init_g = model_init_g.state_dict()

In [28]:
# test=85
# n_points = 1000
# plot_transport(dataset, test, model_init_f, model_init_g, n_points=n_points)

In [29]:
# print('training f')
# gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=250, shuffle=True)
# train_wasserstein(model_init_f, gaussian_transport_dataloader, lr=0.1, epochs=10, init_z = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2)

In [30]:
# X, Y, C = gaussian_dataset.X, gaussian_dataset.Y, gaussian_dataset.C
# #Calcul de la dérivée du PICNN

# for test in range(20):
#     x_i = X[test, :, :]
#     y_i = Y[test, :, :]
#     c_i = C[test, :, :]

#     locs = c_i[:,0]
#     #print(locs)

#     scales = c_i[:,1]
#     #print(scales)  


#     y_i.requires_grad_(True)
#     x_i.requires_grad_(True)
#     #c_i.requires_grad_(True)    

#     output_model_f = model_init_f(x_i, c_i)
#     grad_model_f = torch.autograd.grad(outputs=output_model_f, inputs=x_i, grad_outputs=torch.ones_like(output_model_f), create_graph=True)[0].detach().numpy()

#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True)
#     plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True)
#     plt.hist(grad_model_f[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()


#     output_model_g = model_init_g(y_i, c_i)
#     grad_model_g = torch.autograd.grad(outputs=output_model_g, inputs=y_i, grad_outputs=torch.ones_like(output_model_g), create_graph=True)[0].detach().numpy()
#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True, color = 'red')
#     #plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True, color = 'blue')
#     plt.hist(grad_model_g[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     #plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()

## __Makkuva__

In [1]:
# state_dict_init_f = torch.load('trained_models/training6/models/model_f_0.pth')
# state_dict_init_g = torch.load('trained_models/training6/models/model_g_0.pth')

In [2]:
ICNNf = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
ICNNg = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

old_ICNNf = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
old_ICNNg = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

# Load the state dictionary into ICNNf and ICNNg
ICNNf.load_state_dict(state_dict_init_f)
ICNNg.load_state_dict(state_dict_init_g)

old_ICNNf.load_state_dict(state_dict_init_f)
old_ICNNg.load_state_dict(state_dict_init_g)

l_ICNNf = [ICNNf, old_ICNNf]
l_ICNNg = [ICNNg, old_ICNNg]

NameError: name 'ICNNet' is not defined

In [None]:
n_points = 1000
test = 0

filepath_pth_f = 'trained_models/training6/models/model_f_'
filepath_pth_g = 'trained_models/training6/models/model_g_'

filepath_plt_f = 'trained_models/training6/plots/model_f_'
filepath_plt_g = 'trained_models/training6/plots/model_g_'

In [None]:
filename_pth_f = filepath_pth_f + str(0) + '.pth'
filename_pth_g = filepath_pth_g + str(0) + '.pth'
torch.save(ICNNf.state_dict(), filename_pth_f)
torch.save(ICNNg.state_dict(), filename_pth_g)

In [None]:
filename_plt_f = filepath_plt_f + str(0) + '.png'
filename_plt_g = filepath_plt_g + str(0) + '.png'
plot_transport(dataset, test, ICNNf, ICNNg, filename_f = filename_plt_f, filename_g = filename_plt_g)

<Figure size 1280x960 with 0 Axes>

In [None]:
dataloader = DataLoader(dataset, batch_size=500, shuffle=True)

loss_f = list()
loss_g = list()

for epoch in range(1, 31) :
    print('epoch :', epoch, end=('\r'))
    mean_loss_f, mean_loss_g = train_makkuva_epoch(l_ICNNf[epoch%2], l_ICNNg[epoch%2], l_ICNNf[1 - epoch%2], l_ICNNg[1 - epoch%2], dataloader, init_z_f = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2, init_z_g = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2, lr=0.001, train_freq_g=10, train_freq_f=1, gaussian_transport=False, regularize_f = True, regularize_g = True)
    #mean_loss_f, mean_loss_g = train_makkuva_epoch(ICNNf, ICNNg, None, None, dataloader, init_z_f = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2, init_z_g = lambda x: (1/2) * torch.norm(-x, dim=-1, keepdim=True)**2, lr=0.0001, train_freq_g=10, train_freq_f=1, gaussian_transport=False)

    loss_f.append(mean_loss_f)
    loss_g.append(mean_loss_g)

    filename_pth_f = filepath_pth_f + str(epoch) + '.pth'
    filename_pth_g = filepath_pth_g + str(epoch) + '.pth'
    torch.save(l_ICNNf[epoch%2].state_dict(), filename_pth_f)
    torch.save(l_ICNNg[epoch%2].state_dict(), filename_pth_g)

    filename_plt_f = filepath_plt_f + str(epoch) + '.png'
    filename_plt_g = filepath_plt_g + str(epoch) + '.png'
    plot_transport(dataset, test, l_ICNNf[epoch%2], l_ICNNg[epoch%2], filename_f = filename_plt_f, filename_g = filename_plt_g)

train_freq_g 1
train_freq_f 1
loss_g: 5.22470760345459, loss_f: 3.7670202255249023
train_freq_g 1
train_freq_f 1
loss_g: 5.22470760345459, loss_f: 3.7670202255249023
train_freq_g 1
train_freq_f 1
loss_g: 4.409957408905029, loss_f: 2.8092312812805176
train_freq_g 1
train_freq_f 1
loss_g: 4.409956932067871, loss_f: 2.8092312812805176
train_freq_g 1
train_freq_f 1
loss_g: 3.376021146774292, loss_f: 1.7686762809753418
train_freq_g 1
train_freq_f 1
loss_g: 3.376020908355713, loss_f: 1.7686761617660522
train_freq_g 1
train_freq_f 1
loss_g: 1.8238725662231445, loss_f: 0.6411782503128052
train_freq_g 1
train_freq_f 1
loss_g: 1.8238725662231445, loss_f: 0.6411781907081604
train_freq_g 1
train_freq_f 1
loss_g: -0.4731605648994446, loss_f: -0.33082646131515503
train_freq_g 1
train_freq_f 1


: 

In [None]:
print(len(l_ICNNf))

2


In [None]:
for ele in loss_f:
    print(ele)

print('stop')

for ele in loss_g:
    print(ele)

3.7670202255249023
3.7670202255249023
2.8092315196990967
2.8092312812805176
1.7686762809753418
1.7686762809753418
0.6411781907081604
0.6411782503128052
-0.33082661032676697
-0.3308264911174774
-2.9172003269195557
-2.9172000885009766
-11.605838775634766
-11.605839729309082
-6.383660316467285
-6.383660316467285
-1.5622224807739258
-1.5622214078903198
2.2154698371887207
2.2155182361602783
-1.558901071548462
-1.5589052438735962
2.2143967151641846
2.215229034423828
-1.5539886951446533
-1.5544174909591675
2.211855888366699
2.21500563621521
-1.5496015548706055
-1.5500071048736572
stop
5.224708080291748
5.22470760345459
4.409956932067871
4.409957408905029
3.376020908355713
3.376021146774292
1.8238725662231445
1.823872685432434
-0.47316065430641174
-0.47316059470176697
-1.9266314506530762
-1.926631212234497
-1.5410505533218384
-1.5410505533218384
5.427972793579102
5.427972316741943
5.0130720138549805
5.0130720138549805
4.734468460083008
4.734989166259766
2.835937261581421
2.835942029953003
4.72

In [None]:
ICNNf.load_state_dict(torch.load(filepath_pth_f + '0.pth'))
ICNNg.load_state_dict(torch.load(filepath_pth_g + '0.pth'))

filename_plt_f = 'trained_models/training6/plots/model_f_test'
filename_plt_g = 'trained_models/training6/plots/model_g_test'

plot_transport(dataset, test, ICNNf, ICNNg, filename_f = filename_plt_f, filename_g = filename_plt_g)

<Figure size 1280x960 with 0 Axes>