In [1]:
# Progress marker printed as soon as the cell starts executing.
print('..running')
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pytorch_model_summary import summary
import yaml
import ssl
# SECURITY NOTE(review): this swaps the stdlib's default HTTPS context factory
# for one that does not verify server certificates, so it affects HTTPS
# requests made through the stdlib defaults anywhere in this process.
# Presumably a workaround for certificate errors during a dataset download --
# confirm it is still required and remove it otherwise.
ssl._create_default_https_context = ssl._create_unverified_context

# Project-local modules: sampling/plot helpers, the flow model, the train and
# evaluation loops, dataset loading and the network factory.
from util import samples_generated, samples_real, plot_curve
import idf
from train import evaluation, training 
from data import load_data
from neural_networks import nnetts

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One batch size is shared by all three loaders below.
batch_size = 1000

# load_data is project-local; it is unpacked here as (train, validation, test).
train_data, val_data, test_data = load_data('mnist')

# Only the training split is reshuffled; validation and test keep their order.
training_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    val_data,
    batch_size=batch_size,
    shuffle=False,
)
test_loader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
)

# Directory that receives every artefact written by this run (hyperparameters,
# loss logs, and the files the helpers derive from result_dir + '/' + name).
result_dir = 'results/exp_test'
# makedirs also creates the missing parent ('results'), which os.mkdir cannot,
# and exist_ok=True replaces the racy exists()-then-mkdir() check.
os.makedirs(result_dir, exist_ok=True)
# Run identifier; combined with result_dir to form the prefix passed to the
# training, evaluation, sampling and plotting helpers.
name = 'idf-4'

# ---- Experiment configuration ----
# Input dimension (presumably 28 * 28 flattened MNIST pixels -- confirm
# against load_data).
D = 784
# The number of neurons in scale (s) and translation (t) nets.
M = D
# Learning rate handed to the optimizer.
lr = 1e-3
# Max. number of epochs.
num_epochs = 10
# Early-stopping patience: training is meant to stop once it has not improved
# for longer than this many epochs.
# NOTE(review): patience (20) exceeds num_epochs (10), so early stopping
# presumably can never trigger in this configuration -- confirm in
# train.training.
max_patience = 20
# The number of invertible transformations.
num_flows = 4
# Regularization hyperparameter.
lam = 0.

# Snapshot of the configuration, stored next to the results for reproducibility.
# 'lambda' is a reserved word, so that key is added by subscript rather than
# as a keyword argument.
hyperparameters = dict(
    D=D,
    M=M,
    lr=lr,
    num_epochs=num_epochs,
    max_patience=max_patience,
    num_flows=num_flows,
    batch_size=batch_size,
)
hyperparameters['lambda'] = lam

with open(result_dir + '/hyperparameters.yaml', 'w') as yaml_file:
    yaml.dump(hyperparameters, yaml_file)

# Network factory (project-local) that the flow model receives, built from the
# input dimension and the hidden width.
netts = nnetts(D, M)

# Build the flow with num_flows transformations and move it to the device.
model = idf.IDF4(netts, num_flows, D=D)
model = model.to(device)

# Only parameters that require gradients are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adamax(trainable_params, lr=lr)

# Training procedure. The returned sequence is written to disk and plotted
# further down (presumably one validation NLL per epoch -- confirm in
# train.training).
nll_val = training(
    name=name,
    result_dir=result_dir,
    max_patience=max_patience,
    num_epochs=num_epochs,
    model=model,
    optimizer=optimizer,
    training_loader=training_loader,
    val_loader=val_loader,
    device=device,
    lam=lam,
)

# Persist the values returned by training(), one per line.
# NOTE(review): the file is named train_loss.txt but the data comes from
# nll_val -- confirm which of the two names is the intended one.
with open(result_dir + '/train_loss.txt', "w") as file:
    file.writelines(f"{item}\n" for item in nll_val)

# Evaluate on the test split; the helper is given the result_dir/name prefix
# (presumably to reload the checkpoint saved during training -- confirm in
# train.evaluation).
test_loss = evaluation(name=result_dir + '/' + name, test_loader=test_loader)
# A context manager guarantees the handle is closed even if the write raises,
# matching the style used for train_loss.txt above.
with open(result_dir + '/test_loss.txt', "w") as f:
    f.write(str(test_loss))

# Produce generated samples and the loss curve under the same prefix.
# 28 is passed through to the helper (presumably the image side length --
# confirm in util.samples_generated).
samples_generated(result_dir + '/' + name, test_loader, 28)
plot_curve(result_dir + '/' + name, nll_val)

..running
Epoch: 0, train nll=3728.421875
val nll=1863.9881041666667
saved!
Epoch: 1, train nll=3706.7265625
val nll=1853.1076666666668
saved!
Epoch: 2, train nll=3684.103759765625
val nll=1841.9745
saved!
Epoch: 3, train nll=3661.23193359375
val nll=1830.4989375
saved!
Epoch: 4, train nll=3637.6728515625
val nll=1818.72975
saved!
Epoch: 5, train nll=3613.4111328125
val nll=1806.5425
saved!
Epoch: 6, train nll=3588.4775390625
val nll=1794.0138333333334
saved!
Epoch: 7, train nll=3562.36328125
val nll=1781.0469166666667
saved!
Epoch: 8, train nll=3536.0390625
val nll=1767.7661458333334
saved!
Epoch: 9, train nll=3511.66552734375
val nll=1754.1437916666666
saved!
FINAL LOSS: nll=1754.1552125
