In [1]:
# https://github.com/jmtomczak/git_flow/blob/main/models/idf.py
# https://arxiv.org/pdf/2011.15056

import os
import torch
from torch.utils.data import DataLoader
from pytorch_model_summary import summary
import ssl
# NOTE(security): swaps the default HTTPS context factory for the unverified one, which
# turns off TLS certificate verification for urllib/http.client HTTPS requests that do
# not pass their own context. Presumably a workaround for a download that fails
# certificate checks — TODO confirm it is still needed and remove it otherwise.
ssl._create_default_https_context = ssl._create_unverified_context

In this example, we go wild and use a dataset that is even simpler than MNIST! We use the Digits dataset that ships with scikit-learn (`sklearn.datasets.load_digits`). It consists of ~1500 images of size 8x8, and each pixel can take values in $\{0, 1, \ldots, 16\}$.

In [2]:
from util import samples_generated, samples_real, plot_curve
import idf
from train import evaluation, training 
from data import load_data
from neural_networks import nnetts

# Load the three dataset splits returned by load_data().
train_data, val_data, test_data = load_data()

# Mini-batch size shared by all loaders, named once instead of repeating the literal.
batch_size = 64

# Only the training split is shuffled; validation and test keep a fixed order.
training_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Directory and file-name prefix used for everything this notebook writes to disk.
result_dir = 'results/'
# makedirs(..., exist_ok=True) creates the directory when it is missing and is a no-op
# on re-runs, avoiding the check-then-create race of os.path.exists() + os.mkdir().
os.makedirs(result_dir, exist_ok=True)
name = 'idf'

# --- Hyperparameters ---

# Input dimension.
D = 64

# Number of neurons in the scale (s) and translation (t) nets; passed to nnetts(D, M).
M = 256

# Learning rate given to the optimizer.
lr = 1e-3

# Maximum number of training epochs.
num_epochs = 1000

# Early-stopping patience: training is stopped if it doesn't improve for longer
# than this many epochs.
max_patience = 20

# Number of invertible transformations.
num_flows = 8


In [3]:
# Networks handed to the IDF model, built from the input dimension D and width M.
netts = nnetts(D, M)
# Init IDF
model = idf.IDF(netts, num_flows, D=D)
# Print the summary (like in Keras). The dummy input is sized from D rather than a
# hard-coded 64, so the summary stays valid if the input dimension is changed.
print(summary(model, torch.zeros(1, D), show_input=False, show_hierarchical=False))

IDF by JT.
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Linear-1            [1, 256]          12,544          12,544
       LeakyReLU-2            [1, 256]               0               0
          Linear-3            [1, 256]          65,792          65,792
       LeakyReLU-4            [1, 256]               0               0
          Linear-5             [1, 16]           4,112           4,112
          Linear-6            [1, 256]          12,544          12,544
       LeakyReLU-7            [1, 256]               0               0
          Linear-8            [1, 256]          65,792          65,792
       LeakyReLU-9            [1, 256]               0               0
         Linear-10             [1, 16]           4,112           4,112
         Linear-11            [1, 256]          12,544          12,544
      LeakyReLU-12            [1, 256]               0           

In [4]:
# OPTIMIZER
# Adamax over the trainable parameters only (those whose requires_grad flag is set).
optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad], lr=lr)

In [5]:
# Training procedure
# Runs training() with the early-stopping settings configured above; its return value
# is kept in nll_val and handed to plot_curve() afterwards.
nll_val = training(
    name=result_dir + name,
    model=model,
    optimizer=optimizer,
    training_loader=training_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    max_patience=max_patience,
)

Epoch: 0, val nll=178.46530691964287
saved!
Epoch: 1, val nll=174.52134207589285
saved!
Epoch: 2, val nll=172.11358258928573
saved!
Epoch: 3, val nll=170.31215541294642
saved!
Epoch: 4, val nll=168.84606305803572
saved!
Epoch: 5, val nll=167.5669740513393
saved!
Epoch: 6, val nll=166.548203125
saved!
Epoch: 7, val nll=165.3306166294643
saved!
Epoch: 8, val nll=164.31797293526785
saved!
Epoch: 9, val nll=163.46844029017856
saved!
Epoch: 10, val nll=162.44203683035715
saved!
Epoch: 11, val nll=161.80875279017857
saved!
Epoch: 12, val nll=161.04617606026787
saved!
Epoch: 13, val nll=160.30952287946428
saved!
Epoch: 14, val nll=159.58923549107143
saved!
Epoch: 15, val nll=159.12208844866072
saved!
Epoch: 16, val nll=158.85460658482143
saved!
Epoch: 17, val nll=158.1549609375
saved!
Epoch: 18, val nll=157.35230050223214
saved!
Epoch: 19, val nll=156.98951590401785
saved!
Epoch: 20, val nll=156.41614397321428
saved!
Epoch: 21, val nll=155.9995786830357
saved!
Epoch: 22, val nll=155.709853515

In [33]:
# Evaluate on the test loader; evaluation() receives the same path prefix that was
# given to training().
test_loss = evaluation(name=result_dir + name, test_loader=test_loader)

# Persist the test loss. The context manager closes the file even if the write raises.
with open(result_dir + name + '_test_loss.txt', "w") as f:
    f.write(str(test_loss))

# Presumably renders a grid of real test images under the result prefix —
# TODO confirm in util.samples_real.
samples_real(result_dir + name, test_loader)

# Presumably plots the validation NLL curve returned by training() —
# TODO confirm in util.plot_curve.
plot_curve(result_dir + name, nll_val)

FINAL LOSS: nll=142.69848250489375
