In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from scipy import linalg

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
from tqdm import tnrange
import torchvision
import seaborn as sns
import pandas as pd
import numpy as np
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
import sys

In [2]:
sys.path.insert(0, '../models/')
sys.path.insert(0, '../losses/')
sys.path.insert(0, '../metrics/')

In [3]:
from generative_metric import compute_generative_metric
from CVAE_first import CVAE
from CVAE_first import idx2onehot
from sample import Sample
from ELBO import calculate_loss
from inception import InceptionV3
from calculate_fid import get_activations, calculate_frechet_distance
from blur import calc_blur

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transforms = transforms.Compose([transforms.ToTensor()])

In [5]:
from itertools import product
from train import train
from test import test

In [6]:
parameters = dict(lr = [ 0.01, 0.001], batch_size = [50, 100])
param_values = [v for v in parameters.values()]

In [7]:
for lr, BATCH_SIZE in product(*param_values):
    print(lr, BATCH_SIZE)

0.01 50
0.01 100
0.001 50
0.001 100


In [8]:
N_EPOCHS = 15          # times to run the model on complete data
INPUT_DIM = 28 * 28 * 3     # size of each input
HIDDEN_DIM = 1024
IMAGE_CHANNELS = 3# hidden dimension
LATENT_DIM = 100        # latent vector dimension
N_CLASSES = 15          # number of classes in the data

In [9]:
best_test_loss = float('inf')

In [10]:
tb = SummaryWriter(log_dir = './runs_3/')

In [11]:
dataframe = {'model': [], 'optimizer': [], 'epoch': [], 'train_losses': [], 'train_rcl_losses': [], 'train_kl_losses': [], 'test_losses': [],
             'test_rcl_losses': [], 'test_kl_losses': [], 'fid_any_color_any_digit': [],
             'fid_color_blue_any_digit': [], 'fid_any_color_digit_0': [], 
             'fid_color_blue_digit_zero': [],
             'blur': []}

data2 = {'model': [], 'optimizer': [],'epoch': [], 'n_dim': [], 'kld_avg_dim': []}

for lr, BATCH_SIZE in product(*param_values):
    print(lr, BATCH_SIZE)
    shuffle = True
    
    train_dataset = datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms)

    test_dataset = datasets.MNIST(
        './data',
        train=False,
        download=True,
        transform=transforms
    )

    train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)
    test_iterator = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)
    
    modelt = CVAE(IMAGE_CHANNELS,N_CLASSES, HIDDEN_DIM, LATENT_DIM, )

    optimizert = optim.Adam(modelt.parameters(), lr=lr)


    
    for e in tnrange(N_EPOCHS,desc='Epochs'):

        train_loss, tr_rcl_loss, tr_kld_loss, tr_kl_per_lt = train(modelt, optimizert, 
                                                                  train_iterator, device, BATCH_SIZE)
        test_loss, test_rcl_loss, test_kld_loss, test_kl_per_lt, blur = test(modelt, optimizert, 
                                                                            test_iterator, device, BATCH_SIZE)

        fid1, grid, grid2 = compute_generative_metric(test_iterator, modelt, device, LATENT_DIM, BATCH_SIZE, 
                                        color_value= None, digit_value=None)
        
        tb.add_image('Real image', grid, global_step = e)
        tb.add_image('Generated images', grid2, global_step = e)
        
        fid2, grid, grid2 = compute_generative_metric(test_iterator, modelt, device, LATENT_DIM, BATCH_SIZE, 
                                        color_value= 0, digit_value=None)
        
        tb.add_image('Real image', grid, global_step = e)
        tb.add_image('Generated images', grid2, global_step = e)  
        
        
        fid3, grid, grid2 = compute_generative_metric(test_iterator, modelt, device, LATENT_DIM, BATCH_SIZE, 
                                        color_value= None, digit_value=5)
        
        tb.add_image('Real image', grid, global_step = e)
        tb.add_image('Generated images', grid2, global_step = e)  
        
        fid4, grid, grid2 = compute_generative_metric(test_iterator, modelt, device, LATENT_DIM, BATCH_SIZE, 
                                        color_value= 0, digit_value=5)
        
        tb.add_image('Real image', grid, global_step = e)
        tb.add_image('Generated images', grid2, global_step = e)  
        
        
        train_loss /= len(train_dataset)
        tr_rcl_loss /= len(train_dataset)
        tr_kld_loss /= len(train_dataset)
        test_loss /= len(test_dataset)
        test_rcl_loss /= len(test_dataset)
        test_kld_loss /= len(test_dataset)
        
        tb.add_scalar('Test loss', test_loss, e)
        tb.add_scalar('Train loss', train_loss, e)
        tb.add_scalar('Test KLD loss', test_kld_loss, e)
        tb.add_scalar('Train KLD loss', tr_kld_loss, e)
        tb.add_scalar('Test RCL loss', test_rcl_loss, e)
        tb.add_scalar('Train RCL loss', tr_rcl_loss, e)
        tb.add_scalar('FID any color and digit', fid1, e)
        tb.add_scalar('FID color red only', fid2, e)
        tb.add_scalar('FID digit 5 only', fid3, e)
        tb.add_scalar('FID red 5 only', fid4, e)
        
        tb.add_histogram('encoder.conv1.conv.bias', modelt.encoder.conv1.conv.bias, e)
        tb.add_histogram('encoder.conv1.conv.weight', modelt.encoder.conv1.conv.weight, e)
        tb.add_histogram('encoder.conv1.conv.weight.grad', modelt.encoder.conv1.conv.weight.grad, e)

        tb.add_histogram('encoder.conv2.conv.bias', modelt.encoder.conv2.conv.bias, e)
        tb.add_histogram('encoder.conv2.conv.weight', modelt.encoder.conv2.conv.weight, e)
        tb.add_histogram('encoder.conv2.conv.weight.grad', modelt.encoder.conv2.conv.weight.grad, e)

        tb.add_histogram('encoder.conv3.conv.bias', modelt.encoder.conv3.conv.bias, e)
        tb.add_histogram('encoder.conv3.conv.weight', modelt.encoder.conv3.conv.weight, e)
        tb.add_histogram('encoder.conv3.conv.weight.grad', modelt.encoder.conv3.conv.weight.grad, e)

        tb.add_histogram('encoder.mu.bias', modelt.encoder.mu.bias, e)
        tb.add_histogram('encoder.mu.weight', modelt.encoder.mu.weight, e)
        tb.add_histogram('encoder.mu.weight.grad', modelt.encoder.mu.weight.grad, e)

        tb.add_histogram('encoder.var.bias', modelt.encoder.var.bias, e)
        tb.add_histogram('encoder.var.weight', modelt.encoder.var.weight, e)
        tb.add_histogram('encoder.var.weight.grad', modelt.encoder.var.weight.grad, e)


        tb.add_histogram('decoder.latent_to_hidden.bias', modelt.decoder.latent_to_hidden.bias, e)
        tb.add_histogram('decoder.latent_to_hidden.weight', modelt.decoder.latent_to_hidden.weight, e)
        tb.add_histogram('decoder.latent_to_hidden.weight.grad', modelt.decoder.latent_to_hidden.weight.grad, e)

        tb.add_histogram('decoder.conv1.conv.bias', modelt.decoder.conv1.conv.bias, e)
        tb.add_histogram('decoder.conv1.conv.weight', modelt.decoder.conv1.conv.weight, e)
        tb.add_histogram('decoder.conv1.conv.weight.grad', modelt.decoder.conv1.conv.weight.grad, e)

        tb.add_histogram('decoder.conv2.conv.bias', modelt.decoder.conv2.conv.bias, e)
        tb.add_histogram('decoder.conv2.conv.weight', modelt.decoder.conv2.conv.weight, e)
        tb.add_histogram('decoder.conv2.conv.weight.grad', modelt.decoder.conv2.conv.weight.grad, e)

        tb.add_histogram('decoder.conv3.conv.bias', modelt.decoder.conv3.conv.bias, e)
        tb.add_histogram('decoder.conv3.conv.weight', modelt.decoder.conv3.conv.weight, e)
        tb.add_histogram('decoder.conv3.conv.weight.grad', modelt.decoder.conv3.conv.weight.grad, e)

        
        dataframe['epoch'].append(e)
        dataframe['model'].append(modelt)
        dataframe['optimizer'].append(optimizert)
        dataframe['train_losses'].append(train_loss)
        dataframe['train_rcl_losses'].append(tr_rcl_loss)
        dataframe['train_kl_losses'].append(tr_kld_loss)
        dataframe['test_losses'].append(test_loss)
        dataframe['test_rcl_losses'].append(test_rcl_loss)
        dataframe['test_kl_losses'].append(test_kld_loss)

        dataframe['fid_any_color_any_digit'].append(fid1)
        dataframe['fid_color_blue_any_digit'].append(fid2)
        dataframe['fid_any_color_digit_0'].append(fid3)
        dataframe['fid_color_blue_digit_zero'].append(fid4)
        dataframe['blur'].append(blur)
        print(f'Epoch {e}, Train Loss: {train_loss:.2f}, Train KLD Loss: {tr_kld_loss:.2f}, Test Loss: {test_loss:.2f}, Test KLD Loss: {test_kld_loss:.2f}' )
        #sns.pairplot(data = pd.DataFrame(tr_kl_per_lt), height=3, vars=["dimension"])

        df = pd.DataFrame(tr_kl_per_lt)
        df = df.sort_values(by=['KL_Divergence'])
        n_dim = np.max(df['Latent_Dimension'])

        kld_avg_dim = np.zeros(n_dim)

        for i in range(n_dim):
            kld_avg_dim[i] = np.mean(df['KL_Divergence'][df['Latent_Dimension'] == i])
        kld_avg_dim = np.sort(kld_avg_dim)[::-1]

        data2['epoch'].append(e)
        data2['n_dim'].append(n_dim)
        data2['model'].append(modelt)
        data2['optimizer'].append(optimizert)
        data2['kld_avg_dim'].append(kld_avg_dim)

        #     print(f'Epoch {e}, Train RCL loss: {tr_rcl_loss:.2f}, Test RCL Loss: {test_rcl_loss:.2f}')
        if best_test_loss > test_loss:
            best_test_loss = test_loss
            patience_counter = 1
        else:
            patience_counter += 1

        if patience_counter > 3:
            break
        
    

0.01 50


HBox(children=(IntProgress(value=0, description='Epochs', max=15, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 2040.36, Train KLD Loss: 1871.45, Test Loss: 145.21, Test KLD Loss: 7.18


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 142.58, Train KLD Loss: 8.90, Test Loss: 137.39, Test KLD Loss: 9.48


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 2, Train Loss: 139.29, Train KLD Loss: 9.86, Test Loss: 136.18, Test KLD Loss: 10.23


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 3, Train Loss: 137.14, Train KLD Loss: 10.43, Test Loss: 134.52, Test KLD Loss: 10.07


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 4, Train Loss: 135.30, Train KLD Loss: 10.80, Test Loss: 135.37, Test KLD Loss: 10.88


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 5, Train Loss: 133.45, Train KLD Loss: 10.93, Test Loss: 132.63, Test KLD Loss: 10.97


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 6, Train Loss: 132.97, Train KLD Loss: 11.26, Test Loss: 132.23, Test KLD Loss: 11.57


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 7, Train Loss: 131.51, Train KLD Loss: 11.23, Test Loss: 130.13, Test KLD Loss: 10.62


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 8, Train Loss: 131.58, Train KLD Loss: 11.39, Test Loss: 129.23, Test KLD Loss: 11.30


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 9, Train Loss: 130.75, Train KLD Loss: 11.57, Test Loss: 127.61, Test KLD Loss: 11.22


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 10, Train Loss: 129.56, Train KLD Loss: 11.55, Test Loss: 126.77, Test KLD Loss: 11.09


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 11, Train Loss: 132.74, Train KLD Loss: 11.84, Test Loss: 129.45, Test KLD Loss: 11.80


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 12, Train Loss: 131.23, Train KLD Loss: 11.93, Test Loss: 133.57, Test KLD Loss: 11.91


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 13, Train Loss: 130.74, Train KLD Loss: 11.74, Test Loss: 127.90, Test KLD Loss: 11.67
0.01 100


HBox(children=(IntProgress(value=0, description='Epochs', max=15, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 82639.56, Train KLD Loss: 82441.72, Test Loss: 156.35, Test KLD Loss: 3.96
0.001 50


HBox(children=(IntProgress(value=0, description='Epochs', max=15, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 177.88, Train KLD Loss: 11.44, Test Loss: 120.25, Test KLD Loss: 15.16


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 113.65, Train KLD Loss: 15.65, Test Loss: 107.94, Test KLD Loss: 16.36


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 2, Train Loss: 106.58, Train KLD Loss: 16.56, Test Loss: 104.28, Test KLD Loss: 16.41


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 3, Train Loss: 103.12, Train KLD Loss: 17.04, Test Loss: 101.57, Test KLD Loss: 18.01


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 4, Train Loss: 101.10, Train KLD Loss: 17.28, Test Loss: 100.12, Test KLD Loss: 17.17


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 5, Train Loss: 101.19, Train KLD Loss: 17.63, Test Loss: 99.02, Test KLD Loss: 17.38


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 6, Train Loss: 98.50, Train KLD Loss: 17.57, Test Loss: 97.97, Test KLD Loss: 17.50


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 7, Train Loss: 97.73, Train KLD Loss: 17.56, Test Loss: 97.29, Test KLD Loss: 18.04


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 8, Train Loss: 97.11, Train KLD Loss: 17.61, Test Loss: 96.87, Test KLD Loss: 17.42


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 9, Train Loss: 96.60, Train KLD Loss: 17.64, Test Loss: 97.21, Test KLD Loss: 17.64


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 10, Train Loss: 96.12, Train KLD Loss: 17.68, Test Loss: 96.30, Test KLD Loss: 17.07


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 11, Train Loss: 95.75, Train KLD Loss: 17.72, Test Loss: 96.10, Test KLD Loss: 17.70


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 12, Train Loss: 95.45, Train KLD Loss: 17.75, Test Loss: 95.86, Test KLD Loss: 17.37


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 13, Train Loss: 95.07, Train KLD Loss: 17.76, Test Loss: 95.53, Test KLD Loss: 17.62


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 14, Train Loss: 94.86, Train KLD Loss: 17.79, Test Loss: 95.07, Test KLD Loss: 17.57
0.001 100


HBox(children=(IntProgress(value=0, description='Epochs', max=15, style=ProgressStyle(description_width='initi…

  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 0, Train Loss: 204.38, Train KLD Loss: 10.89, Test Loss: 123.82, Test KLD Loss: 14.70


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 1, Train Loss: 116.01, Train KLD Loss: 15.60, Test Loss: 109.03, Test KLD Loss: 16.58


  batch = Variable(batch, volatile=True)
  "See the documentation of nn.Upsample for details.".format(mode))


Epoch 2, Train Loss: 106.83, Train KLD Loss: 16.74, Test Loss: 103.97, Test KLD Loss: 16.86
