In [5]:
# https://github.com/jmtomczak/git_flow/blob/main/models/idf.py
# https://arxiv.org/pdf/2011.15056

import os
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from pytorch_model_summary import summary
import yaml
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [6]:
# Define the transformation to flatten the images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten the image
])

# Custom dataset class to return only images
class MNISTWithoutLabels(datasets.MNIST):
    def __getitem__(self, index):
        img, _ = super().__getitem__(index)  # Ignore the label
        return img

# Download and load the training data
train_dataset = MNISTWithoutLabels(root='./data', train=True, download=True, transform=transform)

# Download and load the test data
test_dataset = MNISTWithoutLabels(root='./data', train=False, download=True, transform=transform)

# Define the validation split size
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size

# Split the dataset
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Define the batch size
batch_size = 1000

# Create data loaders
training_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In this example, we go wild and use a dataset that is simpler than MNIST! We use a scipy dataset called Digits. It consists of ~1500 images of size 8x8, and each pixel can take values in $\{0, 1, \ldots, 16\}$.

In [11]:
from util import samples_generated, samples_real, plot_curve
import idf
from train import evaluation, training 
from data import load_data
from neural_networks import nnetts

#train_data, val_data, test_data = load_data(name = 'sklearn')
#training_loader = DataLoader(train_data, batch_size=64, shuffle=True)
#val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
#test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

result_dir = 'results/exp_1'
if not(os.path.exists(result_dir)):
    os.mkdir(result_dir)
name = 'idf'

D = 784   # input dimension
M = 784  # the number of neurons in scale (s) and translation (t) nets
lr = 1e-3 # learning rate
num_epochs = 100 # max. number of epochs
max_patience = 20 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped
num_flows = 4 # The number of invertible transformations

hyperparameters = {'D': D, 
                   'M': M,
                   'lr': lr,
                   'num_epochs': num_epochs,
                   'max_patience': max_patience,
                   'num_flows': num_flows,
                   'batch_size': batch_size
                    }

with open(result_dir + '/hyperparameters.yaml', 'w') as file:
    yaml.dump(hyperparameters, file)

In [22]:
netts = nnetts(D, M)
# Init IDF
model = idf.IDF(netts, num_flows, D=D)
# Print the summary (like in Keras)
#print(summary(model, torch.zeros(1, 64), show_input=False, show_hierarchical=False))

IDF by JT.


In [23]:
# OPTIMIZER
optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)

In [24]:
# Training procedure
nll_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=model, optimizer=optimizer,
                       training_loader=training_loader, val_loader=val_loader)

Epoch: 0, val nll=1842.3470208333333
saved!
Epoch: 1, val nll=1801.054625
saved!
Epoch: 2, val nll=1760.0371458333334
saved!
Epoch: 3, val nll=1719.0781666666667
saved!
Epoch: 4, val nll=1677.7449375
saved!
Epoch: 5, val nll=1637.5582291666667
saved!
Epoch: 6, val nll=1597.37925
saved!
Epoch: 7, val nll=1558.2645416666667
saved!
Epoch: 8, val nll=1518.6429166666667
saved!
Epoch: 9, val nll=1479.850875
saved!
Epoch: 10, val nll=1441.0365625
saved!
Epoch: 11, val nll=1403.1942708333333
saved!
Epoch: 12, val nll=1364.7166875
saved!
Epoch: 13, val nll=1326.5913333333333
saved!
Epoch: 14, val nll=1290.3303333333333
saved!
Epoch: 15, val nll=1255.7547083333334
saved!
Epoch: 16, val nll=1225.3941458333334
saved!
Epoch: 17, val nll=1187.9784375
saved!
Epoch: 18, val nll=1156.7870833333334
saved!
Epoch: 19, val nll=1124.9948541666668
saved!
Epoch: 20, val nll=1092.7890833333333
saved!
Epoch: 21, val nll=1067.6906458333333
saved!
Epoch: 22, val nll=1041.75621875
saved!
Epoch: 23, val nll=1016.41

In [20]:
with open(result_dir + '/train_loss.txt', "w") as file:
    for item in nll_val:
        file.write(f"{item}\n")

test_loss = evaluation(name=result_dir + '/' + name, test_loader=test_loader)
f = open(result_dir + '/test_loss.txt', "w")
f.write(str(test_loss))
f.close()

samples_generated(result_dir + '/' + name, test_loader, 28)
plot_curve(result_dir + '/' + name, nll_val)

FINAL LOSS: nll=695.8311625
