## Joint Training of Flow and AE model

In [None]:
from collections import OrderedDict
import os
from tqdm import tqdm, trange
from IPython.display import clear_output
import time
import torch.nn as nn
import pandas as pd
import torch
from models.ae import MNISTAutoencoder
from models.flow import Flow
from models.generator import MNISTGenerator
import matplotlib.pyplot as plt
from dataset import MNISTDataset, MNIST_mean, MNIST_std, MNISTFlowDataset
from utils.run_manager import RunBuilder

if torch.cuda.is_available():
    devices = ['cuda']
else:
    devices = ['cpu']
print('starting')


params = OrderedDict(
    lr = [0.001],
    batch_size = [32],
    device = devices,
    shuffle = [True],
    num_workers = [5],
    beta = [10],
    z_dim = [2], 
    manual_seed = [1265],
    loss_func = [nn.MSELoss]
)


ae = torch.load('trained_models/mnist_ae_noact.model')

train_set = MNISTFlowDataset(path='data/MNIST/latent')

run_count = 0
models = []

run_data = []

data_load_time = 0
forward_time = 0
for run in RunBuilder.get_runs(params):
  
    run_count += 1
    device = torch.device(run.device)
    
    model = Flow()
    model = model.to(device)

    loader = torch.utils.data.DataLoader(train_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )
    optimizer = torch.optim.Adam(model.parameters(), lr=run.lr)
    num_batches = len(train_set)/run.batch_size

    for epoch in range(100):
        total_recons_loss = 0
        batch_count = 0
        
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim
        
#         run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        
        
        for batch in tqdm(loader):
            batch_count +=1
            optimizer.zero_grad()

            X = batch
            X = X.to(device=run.device)
            loss = model.loss(X)
            
            loss.backward()
            optimizer.step()
            print(loss)
            
            total_recons_loss += loss.item()
            
            
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim
        
        run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        clear_output(wait=True)
        display(df2)
            
        torch.save(model, 'trained_models/mnist_flow_4_test.model'.format(run.lr,run.beta, run.z_dim))
    models.append(model)

## Flow

In [None]:
from collections import OrderedDict, namedtuple
from itertools import product
import os
from tqdm import tqdm, trange
from IPython.display import clear_output
import time
import torch.nn as nn
import pandas as pd
import torch
from models.ae import MNISTAutoencoder
from models.generator import MNISTGenerator
import matplotlib.pyplot as plt
from dataset import MNISTDataset, MNIST_mean, MNIST_std
from utils.plot_utils import show_latent
from models.flow import Flow
from utils.run_manager import RunBuilder


if torch.cuda.is_available():
    devices = ['cuda']
else:
    devices = ['cpu']
print('starting')


params = OrderedDict(
    lr = [0.001],
    batch_size = [32],
    device = devices,
    shuffle = [True],
    num_workers = [5],
    beta = [10],
    z_dim = [2], 
    manual_seed = [1265],
    loss_func = [nn.MSELoss]
)

train_set = MNISTDataset(path='data/MNIST/processed', normalize=True)
val_set = MNISTDataset(path='data/MNIST/processed', train=False, normalize=True)

run_count = 0
models = []


run_data = []

data_load_time = 0
forward_time = 0
for run in RunBuilder.get_runs(params):
#     torch.cuda.set_device(run.device)
    
    run_count += 1
    device = torch.device(run.device)
    
    ae = MNISTAutoencoder(latent_dim=run.z_dim)
    ae = ae.to(device)
    flow = Flow()
    flow = flow.to(device)
    
    loader = torch.utils.data.DataLoader(train_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )
    
    val_loader = torch.utils.data.DataLoader(val_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )

    optimizer_ae = torch.optim.Adam(ae.parameters(), lr=run.lr)
    optimizer_flow = torch.optim.Adam(flow.parameters(), lr=run.lr)
    
    num_batches = len(train_set)/run.batch_size
    criterion = run.loss_func()
    num_val_batches = len(val_set)/run.batch_size
    
    for epoch in range(100):
        total_recons_loss = 0
        total_val_loss = 0
        batch_count = 0
        total_flow_loss = 0
        
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['val_fidelity'] = total_recons_loss/num_batches
        results['flow'] = total_flow_loss/num_batches
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim
        
#         run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        
#         ae = ae.to(device)
#         flow = flow.to(device)
        
        for batch in tqdm(loader):
            batch_count +=1
            
            # update autoencoder
            optimizer_ae.zero_grad()

            X = batch
            X = X.to(device=run.device)
            z = ae.encoder(X)
            y, _ = flow.forward(z)
            out = ae.decoder(y)
            ae_loss = criterion(out, X)
            ae_loss.backward()
            optimizer_ae.step()
            
            # update flow
            optimizer_flow.zero_grad()
            optimizer_ae.zero_grad()
            z = ae.encoder(X)
            flow_loss = flow.loss(z)
            flow_loss.backward()
            optimizer_flow.step()
            optimizer_ae.step()
            
#             print(' backward time', time.time() - forward_time)
            total_recons_loss += ae_loss.item()
            total_flow_loss += flow_loss.item()
        
        X = val_set.dataset[0:1000].unsqueeze(dim=1)
        X = X.to('cuda')
#         temp = ae.to('cpu')
#         temp_flow = flow.to('cpu')
        latent = ae.encoder(X)
        latent,_ = flow.forward(latent)
        out = ae.decoder(latent)
        latent = latent.to('cpu')

        loss = criterion(out, X)            
        total_val_loss += loss.item()

        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['val_fidelity'] = total_recons_loss
        results['flow'] = total_flow_loss/num_batches
        
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim
        
        run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        clear_output(wait=True)
        show_latent(latent.detach(), val_set.labels[0:1000])
        display(df2)
            
#             m.track_loss(G_adv_loss=losses['beta_kl-divergence'], G_mse_loss=losses[''], D_real_loss=total_D_real, D_fake_loss=total_D_fake, D_real_count=real_count, D_fake_count=fake_count)
#         print(epoch, "total_Gloss:",total_Gloss, "total_Dloss:",total_Dloss, "mse:",total_mse_loss, "adv: ", total_adv_loss)           
#         m.end_epoch()
    torch.save([ae, flow], 'trained_models/mnist_flow_ae_joint_2.model')

Unnamed: 0,run_count,epoch,data_fidelity,val_fidelity,flow,batch_size,lr,device,z_dim
0,1,0,0.674338,1264.383102,-2.870605,32,0.001,cuda,2
1,1,1,0.681847,1278.463477,-7.376571,32,0.001,cuda,2
2,1,2,0.666100,1248.938127,-10.421444,32,0.001,cuda,2
3,1,3,0.674670,1265.007157,-3.624471,32,0.001,cuda,2
4,1,4,0.673245,1262.335154,-10.318977,32,0.001,cuda,2
...,...,...,...,...,...,...,...,...,...
91,1,91,0.663378,1243.834444,1.845046,32,0.001,cuda,2
92,1,92,0.662907,1242.951510,1.241006,32,0.001,cuda,2
93,1,93,0.667557,1251.668871,2.327839,32,0.001,cuda,2
94,1,94,0.652119,1222.723851,1.070565,32,0.001,cuda,2


 45%|████▌     | 845/1875 [00:22<00:25, 39.62it/s]

In [9]:
torch.save([ae, flow], 'trained_models/mnist_flow_ae_joint')

In [10]:
a, b  = torch.load('trained_models/mnist_flow_ae_joint')

In [12]:
b

Flow(
  (nvp): RealNVP(
    (transforms): ModuleList(
      (0): AffineTransform(
        (mlp): MLP(
          (layers): Sequential(
            (0): Linear(in_features=2, out_features=64, bias=True)
            (1): ReLU()
            (2): Linear(in_features=64, out_features=64, bias=True)
            (3): ReLU()
            (4): Linear(in_features=64, out_features=2, bias=True)
          )
        )
      )
      (1): AffineTransform(
        (mlp): MLP(
          (layers): Sequential(
            (0): Linear(in_features=2, out_features=64, bias=True)
            (1): ReLU()
            (2): Linear(in_features=64, out_features=64, bias=True)
            (3): ReLU()
            (4): Linear(in_features=64, out_features=2, bias=True)
          )
        )
      )
      (2): AffineTransform(
        (mlp): MLP(
          (layers): Sequential(
            (0): Linear(in_features=2, out_features=64, bias=True)
            (1): ReLU()
            (2): Linear(in_features=64, out_features=64