In [1]:
from collections import OrderedDict
import os
from tqdm import tqdm, trange
from IPython.display import clear_output
import time
import torch.nn as nn
import pandas as pd
import torch
from models.ae import MNISTAutoencoder
from models.flow import Flow
from models.generator import MNISTGenerator
import matplotlib.pyplot as plt
from dataset import MNISTDataset, MNIST_mean, MNIST_std, MNISTFlowDataset
from utils.run_manager import RunBuilder

if torch.cuda.is_available():
    devices = ['cuda']
else:
    devices = ['cpu']
print('starting')


params = OrderedDict(
    lr = [0.001],
    batch_size = [32],
    device = devices,
    shuffle = [True],
    num_workers = [5],
    beta = [10],
    z_dim = [2], 
    manual_seed = [1265],
    loss_func = [nn.MSELoss]
)


ae = torch.load('trained_models/mnist_ae_noact.model')

train_set = MNISTFlowDataset(path='data/MNIST/latent')

run_count = 0
models = []

run_data = []

data_load_time = 0
forward_time = 0
for run in RunBuilder.get_runs(params):
  
    run_count += 1
    device = torch.device(run.device)
    
    model = Flow()
    model = model.to(device)

    loader = torch.utils.data.DataLoader(train_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )
    optimizer = torch.optim.Adam(model.parameters(), lr=run.lr)
    num_batches = len(train_set)/run.batch_size

    for epoch in range(100):
        total_recons_loss = 0
        batch_count = 0
        
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim
        
#         run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        
        
        for batch in tqdm(loader):
            batch_count +=1
            optimizer.zero_grad()

            X = batch
            X = X.to(device=run.device)
            loss = model.loss(X)
            
            loss.backward()
            optimizer.step()
            print(loss)
            
            total_recons_loss += loss.item()
            
            
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim
        
        run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        clear_output(wait=True)
        display(df2)
            
        torch.save(model, 'trained_models/mnist_flow_4_test.model'.format(run.lr,run.beta, run.z_dim))
    models.append(model)

Unnamed: 0,run_count,epoch,data_fidelity,batch_size,lr,device,z_dim
0,1,0,16.876468,32,0.001,cuda,2


  0%|          | 8/1875 [00:00<01:11, 26.28it/s]

tensor(8.1964, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2564, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.5899, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.7277, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2776, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2131, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3671, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3328, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0819, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9236, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6015, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9022, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8349, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8127, device='cuda:0', grad_fn=<NegBackward>)


  1%|▏         | 24/1875 [00:00<00:34, 54.30it/s]

tensor(8.0585, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2047, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7724, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6091, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9247, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2974, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9581, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5352, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6620, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8237, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3587, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2665, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6789, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0348, device='cuda:0', grad_fn=<NegBackward>)


  2%|▏         | 37/1875 [00:00<00:35, 52.09it/s]

tensor(7.7613, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.7381, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7740, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8512, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2386, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4837, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8797, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6646, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6747, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1694, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1869, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2137, device='cuda:0', grad_fn=<NegBackward>)


  3%|▎         | 51/1875 [00:01<00:31, 58.57it/s]

tensor(8.6535, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4447, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8562, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8036, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3803, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3047, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8563, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1885, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0681, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7450, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9415, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9888, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0115, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5854, device='cuda:0', grad_fn=<NegBackward>)


  3%|▎         | 65/1875 [00:01<00:29, 60.55it/s]

tensor(8.0527, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0874, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1150, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9204, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8707, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6367, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1435, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6625, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1781, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3589, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0352, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5475, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5154, device='cuda:0', grad_fn=<NegBackward>)


  4%|▍         | 80/1875 [00:01<00:28, 63.06it/s]

tensor(7.3932, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8177, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3761, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9058, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1072, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4165, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3315, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3960, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5698, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1786, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1516, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0424, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2359, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5316, device='cuda:0', grad_fn=<NegBackward>)


  5%|▍         | 87/1875 [00:01<00:29, 61.47it/s]

tensor(8.6662, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5991, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0421, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3544, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3601, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7387, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5487, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7642, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5476, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2122, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7324, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3115, device='cuda:0', grad_fn=<NegBackward>)


  5%|▌         | 102/1875 [00:01<00:27, 65.55it/s]

tensor(7.0887, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3733, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4625, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0864, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5476, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5417, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0644, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0147, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2127, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.7710, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3974, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4765, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4122, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4602, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0184, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9908, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7068, device='cuda:0', grad_fn=<NegBackward>)


  6%|▋         | 119/1875 [00:02<00:26, 65.17it/s]

tensor(7.6648, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6789, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9864, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6941, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7509, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0151, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7426, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2475, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0321, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5388, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.7891, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2991, device='cuda:0', grad_fn=<NegBackward>)


  7%|▋         | 135/1875 [00:02<00:24, 70.67it/s]

tensor(7.3560, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3453, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8735, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3102, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2566, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3132, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1292, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3417, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9085, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7697, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1288, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.6043, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0283, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3589, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4480, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9249, device='cuda:0', grad_fn=<NegBackward>)


  8%|▊         | 150/1875 [00:02<00:25, 68.24it/s]

tensor(7.3823, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3063, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3776, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.5028, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4659, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4966, device='cuda:0', grad_fn=<NegBackward>)
tensor(9.0656, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8895, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6042, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4406, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8952, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3257, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4937, device='cuda:0', grad_fn=<NegBackward>)


  9%|▉         | 165/1875 [00:02<00:24, 69.35it/s]

tensor(8.2347, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0881, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2724, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.6979, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0478, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3919, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9001, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4332, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1777, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8118, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4641, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9385, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1898, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0673, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0628, device='cuda:0', grad_fn=<NegBackward>)


 10%|▉         | 181/1875 [00:02<00:23, 72.39it/s]

tensor(8.3747, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6573, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.8234, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.6687, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4232, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6888, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4743, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8602, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8673, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8022, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4638, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3970, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6472, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5853, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0372, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3999, device='cuda:0', grad_fn=<NegBackward>)


 11%|█         | 199/1875 [00:03<00:20, 80.03it/s]

tensor(7.4621, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8149, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5632, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1390, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8504, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5369, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4282, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.7506, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9830, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3058, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.7986, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3457, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5976, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4698, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4046, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5361, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5663, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1695, device='cuda:0', grad_fn=<NegBackward>)


 11%|█         | 209/1875 [00:03<00:20, 82.83it/s]

tensor(7.4404, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0798, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3286, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9943, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5840, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4064, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9810, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.7250, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5293, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.7870, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2510, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6193, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3615, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0030, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2529, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9473, device='cuda:0', grad_fn=<NegBackward>)


 12%|█▏        | 226/1875 [00:03<00:21, 76.11it/s]

tensor(7.7279, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5232, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0136, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9067, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7203, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4498, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.7169, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5201, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2932, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5653, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8591, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0403, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5022, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8213, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4066, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4582, device='cuda:0', grad_fn=<NegBackward>)


 13%|█▎        | 243/1875 [00:03<00:20, 79.41it/s]

tensor(7.6125, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0949, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5588, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8258, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1604, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7075, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1242, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2421, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1568, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5389, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0643, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3108, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5637, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0265, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5953, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2059, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3071, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9333, device='cuda:0', grad_fn=<NegBackward>)


 14%|█▍        | 261/1875 [00:03<00:19, 83.04it/s]

tensor(8.0833, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1913, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.7958, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9417, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3548, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9785, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0905, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1878, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4952, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2658, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4739, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4155, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4893, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6684, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2295, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5024, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2205, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4543, device='cuda:0', grad_fn=<NegBackward>)


 15%|█▍        | 279/1875 [00:04<00:18, 85.10it/s]

tensor(7.2939, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4306, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7486, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5779, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2495, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9812, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3375, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4709, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0223, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3683, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3765, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1966, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6161, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7580, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2927, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2483, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6443, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4018, device='cuda:0', grad_fn=<NegBackward>)


 16%|█▌        | 297/1875 [00:04<00:20, 76.38it/s]

tensor(7.4406, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4475, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2320, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1652, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6029, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4361, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4462, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4758, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6231, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4741, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5090, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4247, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6362, device='cuda:0', grad_fn=<NegBackward>)


 17%|█▋        | 314/1875 [00:04<00:20, 76.71it/s]

tensor(7.8688, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.9371, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0429, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7565, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9991, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2266, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3774, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4355, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1954, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4253, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7429, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5172, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1006, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2902, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1799, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8426, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3821, device='cuda:0', grad_fn=<NegBackward>)


 18%|█▊        | 332/1875 [00:04<00:19, 80.27it/s]

tensor(7.2050, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2743, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1049, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2906, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2687, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3277, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9531, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4601, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8677, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2472, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4366, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1515, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7180, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.5708, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9295, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3662, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3224, device='cuda:0', grad_fn=<NegBackward>)


 18%|█▊        | 341/1875 [00:04<00:20, 75.28it/s]

tensor(6.7813, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3479, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7227, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9452, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9569, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8692, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3199, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8728, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9970, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3796, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8076, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2648, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8431, device='cuda:0', grad_fn=<NegBackward>)


 19%|█▉        | 358/1875 [00:05<00:20, 74.16it/s]

tensor(7.3460, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5124, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5038, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8334, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2526, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7206, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7508, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7117, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.7197, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6249, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6932, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0853, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6827, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1066, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7556, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7385, device='cuda:0', grad_fn=<NegBackward>)


 20%|██        | 375/1875 [00:05<00:20, 72.14it/s]

tensor(8.5299, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7322, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4782, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2884, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1557, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7064, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3133, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8264, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3909, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3950, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3276, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3283, device='cuda:0', grad_fn=<NegBackward>)
tensor(9.3552, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0187, device='cuda:0', grad_fn=<NegBackward>)


 20%|██        | 383/1875 [00:05<00:22, 67.65it/s]

tensor(7.7292, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9515, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7046, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5832, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5126, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6792, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0336, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2478, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7142, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8337, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1656, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3775, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9176, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4016, device='cuda:0', grad_fn=<NegBackward>)


 21%|██▏       | 400/1875 [00:05<00:20, 73.58it/s]

tensor(7.2996, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0951, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0982, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1685, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5649, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8910, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7984, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8391, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5069, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1484, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2397, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0138, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2987, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.6553, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5203, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1023, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5998, device='cuda:0', grad_fn=<NegBackward>)


 22%|██▏       | 418/1875 [00:06<00:18, 79.42it/s]

tensor(8.2430, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8469, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1974, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3828, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3435, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6855, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6753, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6039, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9042, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4777, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3939, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3130, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6171, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6956, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4527, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5804, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5037, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0635, device='cuda:0', grad_fn=<NegBackward>)


 23%|██▎       | 436/1875 [00:06<00:17, 80.87it/s]

tensor(7.6166, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4900, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6644, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5254, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9838, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5313, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0411, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2541, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4421, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7104, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4290, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9772, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2789, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1230, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4162, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5088, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5690, device='cuda:0', grad_fn=<NegBackward>)


 24%|██▍       | 453/1875 [00:06<00:20, 69.26it/s]

tensor(8.1544, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5449, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3083, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.8432, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8312, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2364, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7701, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6638, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3946, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6395, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6096, device='cuda:0', grad_fn=<NegBackward>)


 25%|██▌       | 470/1875 [00:06<00:18, 74.47it/s]

tensor(7.4933, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0259, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9688, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2231, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9589, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8290, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0285, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1432, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9835, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4541, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7477, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1652, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5746, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1654, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8988, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7348, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6065, device='cuda:0', grad_fn=<NegBackward>)


 26%|██▌       | 488/1875 [00:06<00:17, 79.51it/s]

tensor(7.6377, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2667, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5344, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7801, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1187, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8884, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2805, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5774, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3354, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1750, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3044, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2109, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0707, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3429, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1861, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0982, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1042, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3154, device='cuda:0', grad_fn=<NegBackward>)


 27%|██▋       | 497/1875 [00:07<00:17, 81.06it/s]

tensor(6.8765, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7307, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4617, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7552, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7144, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4601, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1724, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7761, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2231, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.6297, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2278, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4277, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3424, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4457, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0996, device='cuda:0', grad_fn=<NegBackward>)


 27%|██▋       | 514/1875 [00:07<00:20, 65.67it/s]

tensor(7.8729, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3629, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3905, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3652, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4834, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3110, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3147, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3020, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5213, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3156, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2952, device='cuda:0', grad_fn=<NegBackward>)


 28%|██▊       | 521/1875 [00:07<00:22, 60.46it/s]

tensor(7.5861, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3958, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3789, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9017, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2468, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2196, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2166, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6110, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2686, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2984, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0149, device='cuda:0', grad_fn=<NegBackward>)


 29%|██▊       | 536/1875 [00:07<00:21, 63.59it/s]

tensor(7.9838, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.6843, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4676, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6588, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7153, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3713, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3038, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8092, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9339, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0777, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6668, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4883, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8768, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3360, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5630, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2742, device='cuda:0', grad_fn=<NegBackward>)


 30%|██▉       | 554/1875 [00:07<00:17, 73.64it/s]

tensor(7.0112, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8218, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2460, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8564, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6021, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2387, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5557, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4957, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8362, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4039, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5814, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7845, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2623, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1025, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2135, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3904, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4851, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8614, device='cuda:0', grad_fn=<NegBackward>)


 31%|███       | 572/1875 [00:08<00:16, 79.04it/s]

tensor(7.8139, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6733, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4529, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0096, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4186, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0118, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4179, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1905, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9834, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5911, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4696, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0751, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3402, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4654, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3500, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6563, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2800, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3226, device='cuda:0', grad_fn=<NegBackward>)


 31%|███▏      | 590/1875 [00:08<00:15, 81.22it/s]

tensor(7.6108, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.9579, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0091, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1153, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1694, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2064, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4849, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9403, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1352, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5661, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9141, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6117, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2489, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2232, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1926, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1804, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8050, device='cuda:0', grad_fn=<NegBackward>)


 32%|███▏      | 607/1875 [00:08<00:17, 72.49it/s]

tensor(7.2055, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3696, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3610, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4254, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0208, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3535, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9865, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8045, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1484, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8343, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3826, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2051, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2550, device='cuda:0', grad_fn=<NegBackward>)


 33%|███▎      | 615/1875 [00:08<00:21, 59.66it/s]

tensor(7.4654, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0396, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0090, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2582, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2803, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.2330, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5968, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5770, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7619, device='cuda:0', grad_fn=<NegBackward>)


 34%|███▎      | 629/1875 [00:09<00:20, 60.18it/s]

tensor(7.6220, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1824, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4627, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0882, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8337, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2641, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2650, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8000, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2797, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0910, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0041, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4942, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1283, device='cuda:0', grad_fn=<NegBackward>)


 34%|███▍      | 643/1875 [00:09<00:19, 62.87it/s]

tensor(7.5855, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4594, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.4450, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7133, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6616, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.7191, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1838, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8747, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1337, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0320, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6102, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4951, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6463, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1019, device='cuda:0', grad_fn=<NegBackward>)


 35%|███▌      | 658/1875 [00:09<00:18, 65.81it/s]

tensor(6.9662, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9410, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7129, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4185, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3344, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3902, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4789, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3242, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8087, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1698, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3647, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1973, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3001, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4667, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9697, device='cuda:0', grad_fn=<NegBackward>)


 36%|███▌      | 672/1875 [00:09<00:18, 66.18it/s]

tensor(7.3933, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3127, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9373, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0514, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1349, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.1460, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6850, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.8195, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0417, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9460, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1901, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3449, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3138, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4378, device='cuda:0', grad_fn=<NegBackward>)


 36%|███▌      | 679/1875 [00:09<00:19, 62.64it/s]

tensor(8.0891, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2081, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1750, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2820, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.3909, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7544, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3299, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8885, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7598, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1621, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.0458, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1474, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2436, device='cuda:0', grad_fn=<NegBackward>)


 37%|███▋      | 695/1875 [00:10<00:17, 67.30it/s]

tensor(7.9389, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3995, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9208, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6599, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6272, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.4472, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2910, device='cuda:0', grad_fn=<NegBackward>)
tensor(6.9647, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0609, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8360, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8336, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.7873, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.6694, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5662, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.3918, device='cuda:0', grad_fn=<NegBackward>)


 37%|███▋      | 703/1875 [00:10<00:17, 67.78it/s]

tensor(7.1987, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.2013, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.1765, device='cuda:0', grad_fn=<NegBackward>)
tensor(8.0849, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5305, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.9559, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.5586, device='cuda:0', grad_fn=<NegBackward>)
tensor(7.8054, device='cuda:0', grad_fn=<NegBackward>)


 38%|███▊      | 708/1875 [00:10<00:17, 68.45it/s]


KeyboardInterrupt: 

## Train Semi Supervised Flow

In [None]:
from collections import OrderedDict
import os
from tqdm import tqdm, trange
from IPython.display import clear_output
import time
import torch.nn as nn
import pandas as pd
import torch
from models.ae import MNISTAutoencoder
from models.flow import FlowSemiSupervised
from models.generator import MNISTGenerator
import matplotlib.pyplot as plt
from dataset import MNISTDataset, MNIST_mean, MNIST_std, MNISTFlowSupervisedDataset
from utils.run_manager import RunBuilder

if torch.cuda.is_available():
    devices = ['cuda']
else:
    devices = ['cpu']
print('starting')


params = OrderedDict(
    lr = [0.001],
    batch_size = [5],
    device = devices,
    shuffle = [True],
    num_workers = [5],
    beta = [10],
    z_dim = [2], 
    manual_seed = [1265],
    loss_func = [nn.MSELoss]
)


ae = torch.load('trained_models/mnist_ae_noact.model')

train_set = MNISTFlowSupervisedDataset(path='data/MNIST/latent')

run_count = 0
models = []

run_data = []

data_load_time = 0
forward_time = 0
for run in RunBuilder.get_runs(params):
  
    run_count += 1
    device = torch.device(run.device)
    
    model = FlowSemiSupervised()
    model = model.to(device)

    loader = torch.utils.data.DataLoader(train_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )
    optimizer = torch.optim.Adam(model.parameters(), lr=run.lr)
    num_batches = len(train_set)/run.batch_size

    for epoch in range(100):
        total_recons_loss = 0
        batch_count = 0
        
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim
        
#         run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        
        
        for batch in tqdm(loader):
            batch_count +=1
            optimizer.zero_grad()

            X, labels = batch
            X = X.to(device=run.device)
            loss = model.loss(X, labels)
            
            loss.backward()
            optimizer.step()
            
#             print('loss', loss.item())
            total_recons_loss += loss.item()
            
            
        results = OrderedDict()
        results['run_count'] = run_count
        results['epoch'] = epoch
        results['data_fidelity'] = total_recons_loss/num_batches
        results['batch_size'] = run.batch_size
        results['lr'] = run.lr
        results['device'] = run.device
        results['z_dim'] = run.z_dim
        
        run_data.append(results)
        df2 = pd.DataFrame.from_dict(run_data, orient='columns')
        clear_output(wait=True)
        display(df2)
            
        torch.save(model, 'trained_models/mnist_flow_semisupervised_1.model')
    models.append(model)

Unnamed: 0,run_count,epoch,data_fidelity,batch_size,lr,device,z_dim
0,1,0,90.553761,5,0.001,cuda,2
1,1,1,72.175842,5,0.001,cuda,2
2,1,2,68.334987,5,0.001,cuda,2
3,1,3,67.877819,5,0.001,cuda,2
4,1,4,68.088346,5,0.001,cuda,2
...,...,...,...,...,...,...,...
85,1,85,95.020182,5,0.001,cuda,2
86,1,86,79.650522,5,0.001,cuda,2
87,1,87,75.742274,5,0.001,cuda,2
88,1,88,73.581813,5,0.001,cuda,2


 44%|████▍     | 5337/12000 [01:22<01:40, 66.37it/s]

In [2]:
multi = torch.distributions.MultivariateNormal(torch.zeros(2).to(device), torch.eye(2).to(device))
single = torch.distributions.Normal(torch.tensor(0.).to(device), torch.tensor(1.).to(device))

In [20]:
a = torch.randn((2))

multi.log_prob(a.to(device))


tensor(-3.2666, device='cuda:0')

In [21]:
single.log_prob(a.to(device)).sum()

tensor(-3.2666, device='cuda:0')

In [5]:
model.nvp.priors[2].mean

tensor([2.1631, 6.6574], device='cuda:0')

In [23]:
prior = torch.distributions.Normal(torch.tensor(0.).to(device), torch.tensor(1.).to(device))

model.nvp.priors[0].log_prob(torch.randn((1)).to(device))

tensor([-24.8367,  -0.9224], device='cuda:0', dtype=torch.float64)

In [3]:
m = torch.distributions.MultivariateNormal(torch.zeros(2).to(device), torch.eye(2).to(device))

In [8]:
m.log_prob(torch.randn((2), dtype=torch.float16).to(device))

tensor(-4.2778, device='cuda:0')

In [3]:
from dataset import GMM
gmm = GMM()

In [6]:
type(gmm.means[0])

list

In [3]:
torch.zeros(2)

tensor([0., 0.])

## Save Latent representation

In [5]:
import torch
from dataset import MNISTDataset
ae = torch.load('trained_models/mnist_ae_noact.model')

ae.eval()
ae = ae.to('cpu')

ae_set = MNISTDataset(path='data/MNIST/processed',train=False, normalize=True)

latent = ae.encoder(ae_set.dataset.unsqueeze(dim=1))
lat = latent.detach().clone()

torch.save((lat, ae_set.labels), 'data/MNIST/latent/test.pt')

In [5]:
import torch
a = torch.tensor([0,1,2])

In [6]:
for i in a:
    print(a[i])

tensor(0)
tensor(1)
tensor(2)
