In [10]:
# Printed first, ahead of the imports below, as a start-of-run marker.
print('..running')
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#from pytorch_model_summary import summary
import yaml
import ssl
# NOTE(review): this swaps the default HTTPS context factory for the
# unverified one, so HTTPS requests that rely on the default context in this
# process skip certificate verification. Presumably a workaround for a dataset
# download failing certificate checks -- confirm it is still needed, and prefer
# fixing the local certificate store over disabling verification globally.
ssl._create_default_https_context = ssl._create_unverified_context

from util import samples_generated, samples_real, plot_curve
import idf
from train import evaluation, training 
from data import load_data
from neural_networks import nnetts

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

batch_size = 100

# load_data is a project helper; its result is unpacked as (train, val, test).
train_data, val_data, test_data = load_data('mnist')

# Only the training split is reshuffled; validation and test keep their order.
training_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_data, test_data)
)

# Directory that receives the artifacts written by this script.
# makedirs also creates missing parents such as 'results', and exist_ok=True
# makes a re-run a no-op instead of racing an exists() check against mkdir().
result_dir = 'results/exp_test'
os.makedirs(result_dir, exist_ok=True)
# Run identifier; used as the file-name stem for files saved in result_dir.
name = 'idf-4'

# --- Experiment settings -----------------------------------------------------
D = 784  # input dimension
M = D  # number of neurons in the scale (s) and translation (t) nets
lr = 1e-3  # learning rate
num_epochs = 100  # upper bound on the number of epochs
# Early-stopping patience handed to training(): one twentieth of num_epochs,
# capped at 5 (so 5 for the 100 epochs configured here).
max_patience = min(num_epochs // 20, 5)
num_flows = 4  # number of invertible transformations
lam = 0.0  # regularization hyperparameter
n_mixtures = 5  # number of latent mixing variables

# Snapshot of every setting above (plus batch_size) for the run record.
hyperparameters = {
    'D': D,
    'M': M,
    'lr': lr,
    'num_epochs': num_epochs,
    'max_patience': max_patience,
    'num_flows': num_flows,
    'batch_size': batch_size,
    'lambda': lam,
    'n_mixtures': n_mixtures,
}

# Store the settings next to the other results of this run.
with open(result_dir + '/hyperparameters.yaml', 'w') as file:
    yaml.dump(hyperparameters, stream=file)

# Build the model: nnetts(D, M) provides the networks handed to idf.IDF4,
# and the assembled model is moved to the selected device.
netts = nnetts(D, M)
model = idf.IDF4(netts, num_flows, n_mixtures=n_mixtures, D=D)
model = model.to(device)
#print(summary(model, torch.zeros(1, 64), show_input=False, show_hierarchical=False))

# Hand the optimizer only those parameters that still require gradients.
optimizer = torch.optim.Adamax(
    [param for param in model.parameters() if param.requires_grad],
    lr=lr,
)

# Training procedure; the returned values are written to val_loss.txt below.
nll_val = training(
    name=name,
    result_dir=result_dir,
    max_patience=max_patience,
    num_epochs=num_epochs,
    model=model,
    optimizer=optimizer,
    training_loader=training_loader,
    val_loader=val_loader,
    device=device,
    lam=lam,
)

# Persist the validation values returned by training(), one per line
# (presumably one entry per completed epoch -- confirm against train.py).
with open(result_dir + '/val_loss.txt', "w") as file:
    file.writelines(f"{item}\n" for item in nll_val)

# evaluation() receives the result_dir/name path prefix and the test loader;
# presumably it reloads the model saved during training -- confirm in train.py.
test_loss = evaluation(name=result_dir + '/' + name, test_loader=test_loader)

# Context manager so the handle is closed even if the write raises.
with open(result_dir + '/test_loss.txt', "w") as f:
    f.write(str(test_loss))

#samples_generated(result_dir + '/' + name, test_loader, 28)
#plot_curve(result_dir + '/' + name, nll_val)

..running
Epoch: 0, train nll=1468.9906005859375, val nll=1469.1412864583333
saved!
Epoch: 1, train nll=1089.5665283203125, val nll=1092.6356575520833
saved!
Epoch: 2, train nll=782.2622680664062, val nll=784.7627434895834
saved!
Epoch: 3, train nll=562.2422485351562, val nll=570.44330078125
saved!
Epoch: 4, train nll=531.15576171875, val nll=519.812802734375
saved!
Epoch: 5, train nll=454.41583251953125, val nll=456.92081770833335
saved!
Epoch: 6, train nll=526.9252319335938, val nll=523.189998046875
Epoch: 7, train nll=525.2789916992188, val nll=535.4600208333334
Epoch: 8, train nll=556.8896484375, val nll=560.9361516927083
Epoch: 9, train nll=552.8855590820312, val nll=559.79312109375
Epoch: 10, train nll=544.7901000976562, val nll=548.098865234375
Epoch: 11, train nll=581.2445678710938, val nll=567.8503971354166


KeyboardInterrupt: 

In [9]:
import numpy as np
import matplotlib.pyplot as plt
# GENERATIONS-------
# Reload the model stored under result_dir/name + '.model'. map_location keeps
# the load working when the file was written on a GPU but this session runs on
# a different device.
# NOTE(review): torch.load unpickles the file, so only load files this project
# wrote itself. Recent PyTorch releases default to weights_only=True, which may
# reject a fully pickled model -- confirm against the installed version.
model_best = torch.load(result_dir + '/' + name + '.model', map_location=device)
model_best.eval()

# Grid layout of the figure: num_x rows by num_y columns, one sample per cell.
num_x = 1
num_y = 1
x = model_best.sample(num_x * num_y)
# Move to host memory before converting: Tensor.numpy() raises for tensors
# that live on a CUDA device.
x = x.detach().cpu().numpy()

# squeeze=False always returns a 2-D array of axes, so the same loop serves
# the 1x1 grid and larger ones.
fig, axes = plt.subplots(num_x, num_y, squeeze=False)
# Assumes each sample holds 28 * 28 values (flattened 28x28 image) -- this
# matches the reshape the single-image version of this code relied on.
images = np.reshape(x, (-1, 28, 28))
for ax, plottable_image in zip(axes.flat, images):
    ax.imshow(plottable_image, cmap='gray')
    ax.axis('off')

plt.savefig(result_dir + '/' + name + '_generated_images' + '.pdf', bbox_inches='tight')
plt.close()