In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from scipy import linalg

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
from tqdm import tnrange
import seaborn as sns
import pandas as pd
import numpy as np
from torchsummary import summary
import sys

In [2]:
sys.path.insert(0, '../models/')
sys.path.insert(0, '../losses/')
sys.path.insert(0, '../metrics/')

In [3]:
from generative_metric import compute_generative_metric
from CVAE_first import CVAE
from CVAE_first import idx2onehot
from sample import Sample
from ELBO import calculate_loss
from inception import InceptionV3
from calculate_fid import get_activations, calculate_frechet_distance
from blur import calc_blur

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transforms = transforms.Compose([transforms.ToTensor()])

In [5]:
from itertools import product
from train import train
from test import test

In [6]:
parameters = dict(lr = [0.01, 0.001], batch_size = [10,  100, 1000], 
                  shuffle = [True, False])
param_values = [v for v in parameters.values()]

In [11]:
for lr, BATCH_SIZE, shuffle in product(*param_values):
    print(lr, BATCH_SIZE, shuffle)

0.01 10 True
0.01 10 False
0.01 100 True
0.01 100 False
0.01 1000 True
0.01 1000 False
0.001 10 True
0.001 10 False
0.001 100 True
0.001 100 False
0.001 1000 True
0.001 1000 False


In [8]:
N_EPOCHS = 10           # times to run the model on complete data
INPUT_DIM = 28 * 28 * 3     # size of each input
HIDDEN_DIM = 1024
IMAGE_CHANNELS = 3# hidden dimension
LATENT_DIM = 100        # latent vector dimension
N_CLASSES = 15          # number of classes in the data

In [9]:
best_test_loss = float('inf')

In [12]:
dataframe = {'model': [], 'optimizer': [], 'epoch': [], 'train_losses': [], 'train_rcl_losses': [], 'train_kl_losses': [], 'test_losses': [],
             'test_rcl_losses': [], 'test_kl_losses': [], 'fid_any_color_any_digit': [],
             'fid_color_blue_any_digit': [], 'fid_any_color_digit_0': [], 
             'fid_color_blue_digit_zero': [],
             'blur': []}

data2 = {'model': [], 'optimizer': [],'epoch': [], 'n_dim': [], 'kld_avg_dim': []}

for lr, BATCH_SIZE, shuffle in product(*param_values):
    print(lr, BATCH_SIZE, shuffle)
    
    train_dataset = datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms)

    test_dataset = datasets.MNIST(
        './data',
        train=False,
        download=True,
        transform=transforms
    )

    train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)
    test_iterator = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)
    
    modelt = CVAE(IMAGE_CHANNELS,N_CLASSES, HIDDEN_DIM, LATENT_DIM, )

    optimizert = optim.Adam(modelt.parameters(), lr=lr)


    
    for e in tnrange(N_EPOCHS,desc='Epochs'):

        train_loss, tr_rcl_loss, tr_kld_loss, tr_kl_per_lt = train(modelt, optimizert, 
                                                                  train_iterator, device, BATCH_SIZE)
        test_loss, test_rcl_loss, test_kld_loss, test_kl_per_lt, blur = test(modelt, optimizert, 
                                                                            test_iterator, device, BATCH_SIZE)

        fid1 = compute_generative_metric(test_iterator, modelt, device, LATENT_DIM, BATCH_SIZE, 
                                        color_value= None, digit_value=None)
        fid2 = compute_generative_metric(test_iterator, modelt, device, LATENT_DIM, BATCH_SIZE, 
                                        color_value= 1, digit_value=None)
        fid3 = compute_generative_metric(test_iterator, modelt, device, LATENT_DIM, BATCH_SIZE, 
                                        color_value= None, digit_value=0)
        fid4 = compute_generative_metric(test_iterator, modelt, device, LATENT_DIM, BATCH_SIZE, 
                                        color_value= 1, digit_value=0)
        
        train_loss /= len(train_dataset)
        tr_rcl_loss /= len(train_dataset)
        tr_kld_loss /= len(train_dataset)
        test_loss /= len(test_dataset)
        test_rcl_loss /= len(test_dataset)
        test_kld_loss /= len(test_dataset)
        dataframe['epoch'].append(e)
        dataframe['model'].append(modelt)
        dataframe['optimizer'].append(optimizert)
        dataframe['train_losses'].append(train_loss)
        dataframe['train_rcl_losses'].append(tr_rcl_loss)
        dataframe['train_kl_losses'].append(tr_kld_loss)
        dataframe['test_losses'].append(test_loss)
        dataframe['test_rcl_losses'].append(test_rcl_loss)
        dataframe['test_kl_losses'].append(test_kld_loss)

        dataframe['fid_any_color_any_digit'].append(fid1)
        dataframe['fid_color_blue_any_digit'].append(fid2)
        dataframe['fid_any_color_digit_0'].append(fid3)
        dataframe['fid_color_blue_digit_zero'].append(fid4)
        dataframe['blur'].append(blur)
        print(f'Epoch {e}, Train Loss: {train_loss:.2f}, Train KLD Loss: {tr_kld_loss:.2f}, Test Loss: {test_loss:.2f}, Test KLD Loss: {test_kld_loss:.2f}' )
        #sns.pairplot(data = pd.DataFrame(tr_kl_per_lt), height=3, vars=["dimension"])

        df = pd.DataFrame(tr_kl_per_lt)
        df = df.sort_values(by=['KL_Divergence'])
        n_dim = np.max(df['Latent_Dimension'])

        kld_avg_dim = np.zeros(n_dim)

        for i in range(n_dim):
            kld_avg_dim[i] = np.mean(df['KL_Divergence'][df['Latent_Dimension'] == i])
        kld_avg_dim = np.sort(kld_avg_dim)[::-1]

        data2['epoch'].append(e)
        data2['n_dim'].append(n_dim)
        data2['model'].append(modelt)
        data2['optimizer'].append(optimizert)
        data2['kld_avg_dim'].append(kld_avg_dim)

        #     print(f'Epoch {e}, Train RCL loss: {tr_rcl_loss:.2f}, Test RCL Loss: {test_rcl_loss:.2f}')
        if best_test_loss > test_loss:
            best_test_loss = test_loss
            patience_counter = 1
        else:
            patience_counter += 1

        if patience_counter > 3:
            break
        
    

0.01 10 True


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 421.85, Train KLD Loss: 120.99, Test Loss: 740.46, Test KLD Loss: 5.28


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 764.16, Train KLD Loss: 4.58, Test Loss: 645.95, Test KLD Loss: 3.93


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 2, Train Loss: 539.87, Train KLD Loss: 8.37, Test Loss: 439.94, Test KLD Loss: 14.49


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 3, Train Loss: 445.88, Train KLD Loss: 7.18, Test Loss: 442.88, Test KLD Loss: 4.83


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 4, Train Loss: 500.78, Train KLD Loss: 5.07, Test Loss: 519.42, Test KLD Loss: 2.71


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 5, Train Loss: 590.81, Train KLD Loss: 5.57, Test Loss: 719.55, Test KLD Loss: 8.81
0.01 10 False


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 62298776.88, Train KLD Loss: 62298541.88, Test Loss: 328.42, Test KLD Loss: 7.72


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 27211360658001585374096881876992.00, Train KLD Loss: 27211360658001585374096881876992.00, Test Loss: 463.20, Test KLD Loss: 66.02


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 2, Train Loss: 396.76, Train KLD Loss: 76.07, Test Loss: 470.69, Test KLD Loss: 176.61


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 3, Train Loss: 328.93, Train KLD Loss: 23.42, Test Loss: 385.43, Test KLD Loss: 61.00
0.01 100 True


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 247749.73, Train KLD Loss: 247564.44, Test Loss: 145.22, Test KLD Loss: 7.07


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 139.83, Train KLD Loss: 8.63, Test Loss: 135.88, Test KLD Loss: 9.47


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 2, Train Loss: 133.50, Train KLD Loss: 10.04, Test Loss: 131.90, Test KLD Loss: 10.19


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 3, Train Loss: 130.24, Train KLD Loss: 10.95, Test Loss: 127.66, Test KLD Loss: 11.44


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 4, Train Loss: 129.36, Train KLD Loss: 11.68, Test Loss: 127.09, Test KLD Loss: 11.27


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 5, Train Loss: 126.20, Train KLD Loss: 11.73, Test Loss: 125.50, Test KLD Loss: 11.59


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 6, Train Loss: 125.13, Train KLD Loss: 11.95, Test Loss: 124.38, Test KLD Loss: 11.91


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 7, Train Loss: 124.70, Train KLD Loss: 12.12, Test Loss: 122.78, Test KLD Loss: 12.22


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 8, Train Loss: 123.96, Train KLD Loss: 12.15, Test Loss: 122.52, Test KLD Loss: 11.93


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 9, Train Loss: 123.52, Train KLD Loss: 12.27, Test Loss: 123.05, Test KLD Loss: 12.02
0.01 100 False


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 75236.91, Train KLD Loss: 75022.98, Test Loss: 164.25, Test KLD Loss: 6.34


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 148.96, Train KLD Loss: 6.31, Test Loss: 146.68, Test KLD Loss: 7.13
0.01 1000 True


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…





  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 6891914.73, Train KLD Loss: 6891459.31, Test Loss: 212.61, Test KLD Loss: 7.90
0.01 1000 False


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…





  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 986541.86, Train KLD Loss: 986100.43, Test Loss: 205.32, Test KLD Loss: 7.01
0.001 10 True


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 141.64, Train KLD Loss: 12.87, Test Loss: 114.77, Test KLD Loss: 15.01


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 112.95, Train KLD Loss: 15.54, Test Loss: 108.86, Test KLD Loss: 15.05


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 2, Train Loss: 108.41, Train KLD Loss: 16.01, Test Loss: 109.14, Test KLD Loss: 15.47


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 3, Train Loss: 106.71, Train KLD Loss: 16.32, Test Loss: 104.74, Test KLD Loss: 16.53


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 4, Train Loss: 105.30, Train KLD Loss: 16.50, Test Loss: 103.97, Test KLD Loss: 15.83


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 5, Train Loss: 104.16, Train KLD Loss: 16.67, Test Loss: 103.89, Test KLD Loss: 16.46


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 6, Train Loss: 103.45, Train KLD Loss: 16.82, Test Loss: 103.56, Test KLD Loss: 16.71


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 7, Train Loss: 102.98, Train KLD Loss: 16.92, Test Loss: 102.29, Test KLD Loss: 16.78


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 8, Train Loss: 102.51, Train KLD Loss: 17.02, Test Loss: 102.58, Test KLD Loss: 17.37


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 9, Train Loss: 102.11, Train KLD Loss: 17.06, Test Loss: 102.28, Test KLD Loss: 17.18
0.001 10 False


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 139.66, Train KLD Loss: 13.00, Test Loss: 118.67, Test KLD Loss: 15.83


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 113.09, Train KLD Loss: 15.81, Test Loss: 112.80, Test KLD Loss: 15.35


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 2, Train Loss: 108.17, Train KLD Loss: 16.23, Test Loss: 110.75, Test KLD Loss: 15.58
0.001 100 True


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 193.67, Train KLD Loss: 10.69, Test Loss: 124.23, Test KLD Loss: 13.72
0.001 100 False


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 204.55, Train KLD Loss: 8.57, Test Loss: 136.05, Test KLD Loss: 12.76
0.001 1000 True


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…





  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 530.55, Train KLD Loss: 6.64, Test Loss: 291.91, Test KLD Loss: 4.95
0.001 1000 False


HBox(children=(IntProgress(value=0, description='Epochs', max=10, style=ProgressStyle(description_width='initi…





  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 537.93, Train KLD Loss: 8.66, Test Loss: 264.77, Test KLD Loss: 10.23


In [10]:
stats = pd.DataFrame(dataframe)