In [None]:
# NOTE(review): scratch assignment — `a` is not referenced by any visible cell;
# candidate for deletion once confirmed unused elsewhere in the notebook.
a = 5

In [3]:
from collections import OrderedDict
import os
from tqdm import tqdm, trange
from IPython.display import clear_output
import time
import torch.nn as nn
import pandas as pd
import torch
from models.ae import MNISTAutoencoder
from models.flow import Flow
from models.generator import MNISTGenerator
import matplotlib.pyplot as plt
from dataset import MNISTDataset, MNIST_mean, MNIST_std, MNISTFlowDataset
from utils.run_manager import RunBuilder

# Single target device: the GPU when CUDA is available, otherwise the CPU.
devices = ['cuda' if torch.cuda.is_available() else 'cpu']
print('starting')


# Hyper-parameter grid. Each key maps to a list of candidate values; the
# training loop iterates RunBuilder.get_runs(params) — presumably one run per
# combination of values (confirm in utils.run_manager).
params = OrderedDict([
    ('lr', [0.001]),
    ('batch_size', [32]),
    ('device', devices),
    ('shuffle', [True]),
    ('num_workers', [5]),
    ('beta', [10]),
    ('z_dim', [2]),
    ('manual_seed', [1265]),
    ('loss_func', [nn.MSELoss]),
])


# NOTE(review): `ae` is not used by the training cell; the latent-encoding cell
# further down reloads it itself. Kept only in case other cells read `ae`.
ae = torch.load('trained_models/mnist_ae_zdim_2.model')

# Flow training data. Presumably reads the latent codes that the encoding cell
# below saves to 'data/MNIST/latent/training.pt' — that cell must have been run
# at least once before this one (TODO confirm against MNISTFlowDataset).
train_set = MNISTFlowDataset(path='data/MNIST/latent')

# Book-keeping shared across runs (the timing counters that used to live here
# were never updated or read, so they have been dropped).
run_count = 0   # incremented once per run by the training loop
models = []     # the model from each finished run
run_data = []   # one results row per (run, epoch), rendered as a DataFrame
for run in RunBuilder.get_runs(params):
    run_count += 1
    device = torch.device(run.device)

    # manual_seed was declared in params but never applied; seeding here makes
    # each run's initialisation and shuffle order reproducible.
    torch.manual_seed(run.manual_seed)

    # Build a fresh model for every run. These lines used to be commented out,
    # so `model` silently came from leftover kernel state (NameError on a fresh
    # kernel) and every run in the grid kept training the same object.
    # NOTE(review): constructor call copied from the commented-out code;
    # confirm Flow() really takes no arguments (e.g. z_dim).
    model = Flow().to(device)

    loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=run.batch_size,
        shuffle=run.shuffle,
        num_workers=run.num_workers,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=run.lr)
    # NOTE: run.loss_func and run.beta are not consumed here — the training
    # objective comes from model.loss().

    # Defaults to the previously hard-coded 100; add `num_epochs=[...]` to
    # params to override it per run.
    num_epochs = getattr(run, 'num_epochs', 100)

    for epoch in range(num_epochs):
        total_recons_loss = 0.0
        batch_count = 0

        for batch in tqdm(loader):
            optimizer.zero_grad()
            X = batch.to(device)
            loss = model.loss(X)
            loss.backward()
            optimizer.step()

            total_recons_loss += loss.item()
            batch_count += 1

        # Average over the batches actually seen rather than
        # len(train_set) / batch_size, which is fractional (and wrong) whenever
        # the dataset size is not a multiple of the batch size.
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss / max(batch_count, 1)
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim

        run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        clear_output(wait=True)
        display(df2)

        # Checkpoint after every epoch. The old path had no '{}' placeholders,
        # so its .format() arguments were silently ignored; z_dim is now
        # actually substituted (same filename for the current z_dim=2).
        torch.save(model, 'trained_models/mnist_flow_zdim_{}.model'.format(run.z_dim))
    models.append(model)

Unnamed: 0,run_count,epoch,data_fidelity,batch_size,lr,device,z_dim
0,1,0,0.424574,32,0.001,cuda,2
1,1,1,0.418779,32,0.001,cuda,2
2,1,2,0.419978,32,0.001,cuda,2
3,1,3,0.418247,32,0.001,cuda,2
4,1,4,0.417236,32,0.001,cuda,2
5,1,5,0.419768,32,0.001,cuda,2
6,1,6,0.419205,32,0.001,cuda,2
7,1,7,0.417698,32,0.001,cuda,2
8,1,8,0.418449,32,0.001,cuda,2
9,1,9,0.415519,32,0.001,cuda,2


 86%|████████▋ | 1620/1875 [00:31<00:04, 51.72it/s]


KeyboardInterrupt: 

In [3]:
# Encode the normalised MNIST training set with the trained autoencoder and
# cache (latent codes, labels) for the flow-training cell.
# NOTE(review): torch.load unpickles the whole model object — only load
# checkpoints from trusted sources.
ae = torch.load('trained_models/mnist_ae_zdim_2.model')

ae.eval()
ae = ae.to('cpu')

ae_set = MNISTDataset(path='data/MNIST/processed', normalize=True)

# Inference only: without no_grad the encoder records an autograd graph over
# the entire dataset in a single call, which is pure memory overhead here.
# unsqueeze(dim=1) presumably adds the channel axis, (N, H, W) -> (N, 1, H, W)
# — confirm against MNISTDataset.
with torch.no_grad():
    latent = ae.encoder(ae_set.dataset.unsqueeze(dim=1))
lat = latent.detach().clone()

# torch.save does not create missing parent directories.
os.makedirs('data/MNIST/latent', exist_ok=True)
torch.save((lat, ae_set.labels), 'data/MNIST/latent/training.pt')

  self.dataset = torch.tensor(self.dataset, dtype=torch.float32)
