In [287]:
import torch
import torch.nn as nn
import scipy.stats as stats
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import optim
import matplotlib.pyplot as plt

In [288]:
from icnnet import ICNNet
from mydataset import MyDataset, get_gaussian_dataset, get_gaussian_transport_dataset
from toy_data_dataloader_gaussian import generate_gaussian_dataset, get_dataset, generate_dataset
from train_picnn import PICNNtrain
from train_wasserstein import train_wasserstein
from train_makkuva import train_makkuva, train_makkuva_epoch
from visualization import plot_transport
from gaussian_transport import get_gaussian_transport

In [289]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## __Generate dataset__


In [290]:
#dataset = get_dataset(d=2, r=100, N=500) #valou
#dataset = generate_gaussian_dataset(d=2, r=400, N=10000) #thomas
dataset = generate_dataset(d=2, r=1000, N=50)
gaussian_dataset = get_gaussian_dataset(dataset)
gaussian_transport_dataset = get_gaussian_transport_dataset(gaussian_dataset)

In [291]:
def get_mean(batch):
    means = torch.mean(batch, dim=1)
    average_mean = torch.mean(means, dim=0)
    return(average_mean)

def get_covariance(batch):
    n = batch.size(1) - 1
    mean = torch.mean(batch, dim=1, keepdim=True)
    batch = batch - mean  # Centering the data
    cov = torch.matmul(batch.transpose(1, 2), batch) / n
    return(torch.mean(cov, dim=0))

mean1 = get_mean(dataset.X)
cov1 = get_covariance(dataset.X)
mean2 = get_mean(dataset.Y)
cov2 = get_covariance(dataset.Y)

def init_z_f(x):
    #return (1/2) * torch.norm(x, dim=-1, keepdim=True)**2
    return(get_gaussian_transport(u=x, cov1 = cov1, cov2 = cov2, m1=mean1, m2=mean2))

def init_z_g(x) :
    #return (1/2) * torch.norm(x, dim=-1, keepdim=True)**2
    return(get_gaussian_transport(u=x, cov1 = cov2, cov2 = cov2, m1=mean2, m2=mean1))

## __Initialization__

### __PICNN training__

In [292]:
input_size = 2
layer_sizes = [input_size,64, 64, 64, 1]
n_layers = len(layer_sizes)

In [293]:
import torch.nn.functional as F

# def get_embedding(C, c):
#     scalar_product = torch.matmul(c.float(), C.t().float())
#     embedding = F.softmax(scalar_product, dim=1)
#     return(embedding)

context_layer_sizes = [12] * n_layers

In [294]:
model_init_f = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
model_init_g = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

In [295]:
n_epoch = 500
lr = 0.001

print('training f')
gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=250, shuffle=True)
PICNNtrain(model_init_f, gaussian_transport_dataloader, init_z_f, lr=lr, epochs=n_epoch)
#PICNNtrain(model_init_f, gaussian_transport_dataloader, lr=0.0001, epochs=1, init_z = lambda x: x)

print('training g')
reversed_gaussian_dataset = MyDataset(gaussian_dataset.Y, gaussian_dataset.C, gaussian_dataset.X)
gaussian_transport_dataset_reversed = get_gaussian_transport_dataset(reversed_gaussian_dataset)
gaussian_transport_dataloader_reversed = DataLoader(gaussian_transport_dataset_reversed, batch_size=250, shuffle=True)
#PICNNtrain(model_init_g, gaussian_transport_dataloader_reversed, lr=0.0001, epochs=25, init_z = lambda x: (1/2) * torch.norm(-x, dim=-1, keepdim=True)**2)
PICNNtrain(model_init_g, gaussian_transport_dataloader_reversed, init_z_g, lr=lr, epochs=n_epoch)

training f
Epoch 1/500 Loss: 38.021671295166016
Epoch 2/500 Loss: 37.6040153503418
Epoch 3/500 Loss: 37.47692108154297
Epoch 4/500 Loss: 37.34761428833008
Epoch 5/500 Loss: 37.213741302490234
Epoch 6/500 Loss: 37.0707893371582
Epoch 7/500 Loss: 36.913421630859375
Epoch 8/500 Loss: 36.73466873168945
Epoch 9/500 Loss: 36.5279655456543
Epoch 10/500 Loss: 36.287384033203125
Epoch 11/500 Loss: 36.00726318359375
Epoch 12/500 Loss: 35.6817741394043
Epoch 13/500 Loss: 35.30423355102539
Epoch 14/500 Loss: 34.86676025390625
Epoch 15/500 Loss: 34.35957336425781
Epoch 16/500 Loss: 33.77012634277344
Epoch 17/500 Loss: 33.082374572753906
Epoch 18/500 Loss: 32.27539825439453
Epoch 19/500 Loss: 31.320959091186523
Epoch 20/500 Loss: 30.17909049987793
Epoch 21/500 Loss: 28.79749298095703
Epoch 22/500 Loss: 27.10463523864746
Epoch 23/500 Loss: 25.004173278808594
Epoch 24/500 Loss: 22.377017974853516
Epoch 25/500 Loss: 19.078102111816406
Epoch 26/500 Loss: 14.976136207580566
Epoch 27/500 Loss: 10.06456851

In [296]:
state_dict_init_f = model_init_f.state_dict()
state_dict_init_g = model_init_g.state_dict()

###### Dorseuil

In [297]:
# print('training f')
# gaussian_transport_dataloader = DataLoader(gaussian_transport_dataset, batch_size=250, shuffle=True)
# train_wasserstein(model_init_f, gaussian_transport_dataloader, lr=0.1, epochs=10, init_z = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2)

In [298]:
# X, Y, C = gaussian_dataset.X, gaussian_dataset.Y, gaussian_dataset.C
# #Calcul de la dérivée du PICNN

# for test in range(20):
#     x_i = X[test, :, :]
#     y_i = Y[test, :, :]
#     c_i = C[test, :, :]

#     locs = c_i[:,0]
#     #print(locs)

#     scales = c_i[:,1]
#     #print(scales)  


#     y_i.requires_grad_(True)
#     x_i.requires_grad_(True)
#     #c_i.requires_grad_(True)    

#     output_model_f = model_init_f(x_i, c_i)
#     grad_model_f = torch.autograd.grad(outputs=output_model_f, inputs=x_i, grad_outputs=torch.ones_like(output_model_f), create_graph=True)[0].detach().numpy()

#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True)
#     plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True)
#     plt.hist(grad_model_f[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()


#     output_model_g = model_init_g(y_i, c_i)
#     grad_model_g = torch.autograd.grad(outputs=output_model_g, inputs=y_i, grad_outputs=torch.ones_like(output_model_g), create_graph=True)[0].detach().numpy()
#     plt.hist(X[test, :, 0],  bins=15, label = 'X', density = True, color = 'red')
#     #plt.hist(Y[test, :, 0],  bins=15, label = 'Y', density = True, color = 'blue')
#     plt.hist(grad_model_g[:, 0],  bins=15, label = 'grad_model', density = True, alpha = 0.5)
#     # plt.hist(X_pred,  bins=15, label = 'X_pred', density = True, alpha = 0.5)
#     interval_x = np.linspace(-3, 3, 300)
#     interval_y = np.linspace(-3*scales[0] + locs[0], 3*scales[0] + locs[0], 300)

#     plt.plot(interval_x, stats.norm.pdf(interval_x, loc=0, scale=1), label = 'X_distrib', color = 'blue')
#     #plt.plot(interval_y, stats.norm.pdf(interval_y, loc = locs[0], scale = scales[0]), label = 'Y_distrib', color = 'orange')

#     plt.legend()
#     plt.show()

## __Makkuva__

In [299]:
# state_dict_init_f = torch.load('trained_models/training14/models/model_f_0.pth')
# state_dict_init_g = torch.load('trained_models/training14/models/model_g_0.pth')

In [311]:
ICNNf = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
ICNNg = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

old_ICNNf = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')
old_ICNNg = ICNNet(layer_sizes = layer_sizes, context_layer_sizes=context_layer_sizes, init_bunne = 'TR')

# Load the state dictionary into ICNNf and ICNNg
ICNNf.load_state_dict(state_dict_init_f)
ICNNg.load_state_dict(state_dict_init_g)

old_ICNNf.load_state_dict(state_dict_init_f)
old_ICNNg.load_state_dict(state_dict_init_g)

l_ICNNf = [ICNNf, old_ICNNf]
l_ICNNg = [ICNNg, old_ICNNg]

In [312]:
n_points = 2000
test = 21

filepath_pth_f = 'trained_models/training14/models/model_f_'
filepath_pth_g = 'trained_models/training14/models/model_g_'

filepath_plt_f = 'trained_models/training14/plots/model_f_'
filepath_plt_g = 'trained_models/training14/plots/model_g_'

In [313]:
filename_pth_f = filepath_pth_f + str(0) + '.pth'
filename_pth_g = filepath_pth_g + str(0) + '.pth'
torch.save(ICNNf.state_dict(), filename_pth_f)
torch.save(ICNNg.state_dict(), filename_pth_g)

In [314]:
filename_plt_f = filepath_plt_f + str(0) + '.png'
filename_plt_g = filepath_plt_g + str(0) + '.png'
plot_transport(dataset, test, ICNNf, ICNNg, init_z_f = init_z_f, init_z_g = init_z_g, filename_f = filename_plt_f, filename_g = filename_plt_g, n_points=n_points)

<Figure size 640x480 with 0 Axes>

In [315]:
dataloader = DataLoader(dataset, batch_size=501, shuffle=True)

loss_f = list()
loss_g = list()

prev_param_f = [param.clone().detach() for param in ICNNf.parameters()]
prev_param_g = [param.clone().detach() for param in ICNNg.parameters()]

for epoch in range(1, 501) :
    print('epoch :', epoch)
    mean_loss_f, mean_loss_g, prev_param_f, prev_param_g = train_makkuva_epoch(ICNNf=ICNNf, ICNNg=ICNNg, prev_param_f=prev_param_f, prev_param_g=prev_param_g, dataloader = dataloader, init_z_f = init_z_f, init_z_g = init_z_g, lr=0.001, train_freq_g=10, train_freq_f=1, regularize_f = False, regularize_g = True, lambda_proximal=0.0001)
    #mean_loss_f, mean_loss_g = train_makkuva_epoch(ICNNf, ICNNg, None, None, dataloader, init_z_f = lambda x: (1/2) * torch.norm(x, dim=-1, keepdim=True)**2, init_z_g = lambda x: (1/2) * torch.norm(-x, dim=-1, keepdim=True)**2, lr=0.0001, train_freq_g=10, train_freq_f=1, gaussian_transport=False)

    loss_f.append(mean_loss_f)
    loss_g.append(mean_loss_g)

    filename_pth_f = filepath_pth_f + str(epoch) + '.pth'
    filename_pth_g = filepath_pth_g + str(epoch) + '.pth'
    torch.save(l_ICNNf[epoch%2].state_dict(), filename_pth_f)
    torch.save(l_ICNNg[epoch%2].state_dict(), filename_pth_g)

    filename_plt_f = filepath_plt_f + str(epoch) + '.png'
    filename_plt_g = filepath_plt_g + str(epoch) + '.png'
    plot_transport(dataset, test, ICNNf, ICNNg, init_z_f=init_z_f, init_z_g=init_z_g, filename_f = filename_plt_f, filename_g = filename_plt_g, n_points=n_points)  

epoch : 1
R_f 1.964739931281656e-05
proximal_term 0.5196471214294434
loss_g: 5.789059162139893, loss_f: 0.47906094789505005
epoch : 2
R_f 3.323352575534955e-05
proximal_term 0.5196450352668762
loss_g: 5.805776119232178, loss_f: 0.5597701072692871
epoch : 3
R_f 4.9505921197123826e-05
proximal_term 0.5196444988250732
loss_g: 5.781257152557373, loss_f: 0.6664747595787048
epoch : 4
R_f 7.236286182887852e-05
proximal_term 0.5196460485458374
loss_g: 5.6972761154174805, loss_f: 0.7026309967041016
epoch : 5
R_f 9.092772961594164e-05
proximal_term 0.5196494460105896
loss_g: 5.572621822357178, loss_f: 0.5870129466056824
epoch : 6
R_f 0.00010683241998776793
proximal_term 0.5196506381034851
loss_g: 5.521111488342285, loss_f: 0.48737409710884094
epoch : 7
R_f 0.00012394830991979688
proximal_term 0.5196495652198792
loss_g: 5.535539627075195, loss_f: 0.4607309401035309
epoch : 8
R_f 0.00014457738143391907
proximal_term 0.5196478366851807
loss_g: 5.5677032470703125, loss_f: 0.4786405563354492
epoch : 

In [None]:
# for ele in loss_f:
#     print(ele)

# print('stop')

# for ele in loss_g:
#     print(ele)

In [None]:
ICNNf.load_state_dict(torch.load(filepath_pth_f + '0.pth'))
ICNNg.load_state_dict(torch.load(filepath_pth_g + '0.pth'))

filename_plt_f = 'trained_models/training14/plots/model_f_test'
filename_plt_g = 'trained_models/training14/plots/model_g_test'

plot_transport(dataset, test, ICNNf, ICNNg, filename_f = filename_plt_f, filename_g = filename_plt_g)

TypeError: plot_transport() missing 2 required positional arguments: 'init_z_f' and 'init_z_g'