In [1]:
import pathlib, sys

# Put the parent of the current working directory at the front of the import
# path so that the `flows` package imported below can be resolved.
# NOTE(review): presumably the notebook lives one level below the repository
# root and is launched from its own directory — confirm.
sys.path.insert(0, str(pathlib.Path().resolve().parent))

In [2]:
import os
from tqdm import tqdm
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import SGD, Adam
from flows.realnvp import RealNVP
from torchvision.utils import make_grid

import matplotlib as mpl
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Notebook configuration: data pipeline, model, optimizer and output settings.
# ---------------------------------------------------------------------------

# Render figures at a higher resolution than matplotlib's default.
mpl.rcParams["figure.dpi"] = 144

# --- Data -------------------------------------------------------------------
root = os.path.expanduser("~/datasets")
# Resize the digits to 32x32, then convert them to tensors.
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor()
])

# download=False: the MNIST files are expected to already exist under `root`.
train_data = datasets.MNIST(root=root, download=False, train=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=256, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model / sampling settings ----------------------------------------------
input_shape = (1,32,32)   # (channels, height, width) of the resized images
hidden_dim = None         # passed straight to RealNVP; NOTE(review): None presumably selects the model's default — confirm
num_flows = 2
image_size = 32           # spatial side length; used to derive the latent shape when sampling
save_dir = "../figs/realnvp"
plot = True               # when True, sample grids are written to `save_dir` during training
image_idx = 0             # running counter used in the saved sample file names

# exist_ok=True creates the directory only when it is missing and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(save_dir, exist_ok=True)

model = RealNVP(input_shape, hidden_dim=hidden_dim, num_flows=num_flows)
model.to(device)

n_epochs = 30

# Alternative optimizer kept for reference:
# optimizer = SGD(model.parameters(), lr=0.001, momentum=0.8, weight_decay=0, nesterov=True)
optimizer = Adam(model.parameters(), lr=0.01)#, weight_decay=5e-5)

# NOTE(review): this global is not passed to pixel_transform in the visible
# code (the function falls back to its own default of 0.05) — confirm whether
# it is still needed further down.
alpha = 0.05

def pixel_transform(x, alpha=0.05, inverse=False):
    """Map pixel intensities to logit space, or back again.

    Forward (``inverse=False``): the input is squeezed into
    ``[alpha/2, 1 - alpha/2]`` via ``alpha*0.5 + (1-alpha)*x`` and passed
    through ``torch.logit``.

    Inverse (``inverse=True``): applies ``torch.sigmoid``, undoes the affine
    squeeze, and clips the result to ``[0, 1]``.

    Args:
        x: tensor to transform.
        alpha: squeeze amount used by the affine step (default 0.05).
        inverse: select the logit-space -> pixel-space direction.

    Returns:
        The transformed tensor.
    """
    if inverse:
        restored = (torch.sigmoid(x)-alpha*0.5)/(1-alpha)
        # Values outside the valid pixel range are clamped to it.
        return restored.clip(0,1)
    return torch.logit(alpha*0.5+(1-alpha)*x)

# Training loop: minimise the loss returned by the model (reported as
# `train_neg_logp`) and, when `plot` is set, save a grid of generated samples
# every 50 batches.
for e in range(n_epochs):
    with tqdm(train_loader, desc=f"{e+1}/{n_epochs} epochs") as t:
        train_neg_logp = 0
        train_total = 0
        model.train()
        for i, (x, _) in enumerate(t):
            # Move pixel intensities to logit space before feeding the flow.
            x = pixel_transform(x)
            _, neg_logp = model(x.to(device))
            optimizer.zero_grad()
            neg_logp.backward()
            optimizer.step()

            # Sample-weighted running average of the loss over the epoch.
            # NOTE(review): assumes `neg_logp` is a per-batch mean — confirm against RealNVP.
            batch_size = x.size(0)
            train_neg_logp += neg_logp.item() * batch_size
            train_total += batch_size
            t.set_postfix({"train_neg_logp": train_neg_logp/train_total})

            if plot and i % 50 == 0:
                model.eval()
                with torch.no_grad():
                    # Latent shape fed to the inverse pass.
                    # NOTE(review): presumably each flow quadruples the channels and
                    # halves each spatial side of a 1-channel input — confirm against RealNVP.
                    c, h, w = (
                        4**num_flows,
                        image_size//2**num_flows,
                        image_size//2**num_flows
                    )
                    z = torch.randn((16, c, h, w))
                    samples = model.backward(z.to(device))
                # Back to pixel space; the inverse transform already clips to [0, 1],
                # so no further clipping is needed.
                samples = pixel_transform(samples, inverse=True)
                # Fix: restore training mode. Previously the model stayed in eval
                # mode for every remaining batch of the epoch after the first plot.
                model.train()

                plt.figure(figsize=(16, 16))
                grid = make_grid(samples.cpu(), nrow=4)
                # make_grid returns (C, H, W); imshow expects (H, W, C).
                plt.imshow(grid.numpy().transpose(1,2,0))
                plt.savefig(os.path.join(save_dir, f"realnvp_{image_idx}.png"))
                # Close the figure so repeated plotting does not accumulate open figures.
                plt.close()
                image_idx += 1

1/30 epochs: 100%|███████████████████████████████████████████| 235/235 [01:07<00:00,  3.47it/s, train_neg_logp=1.25e+5]
2/30 epochs: 100%|███████████████████████████████████████████| 235/235 [01:09<00:00,  3.40it/s, train_neg_logp=3.01e+3]
3/30 epochs: 100%|███████████████████████████████████████████| 235/235 [01:07<00:00,  3.50it/s, train_neg_logp=2.38e+3]
4/30 epochs: 100%|███████████████████████████████████████████| 235/235 [01:05<00:00,  3.61it/s, train_neg_logp=1.89e+3]
5/30 epochs: 100%|███████████████████████████████████████████| 235/235 [01:05<00:00,  3.58it/s, train_neg_logp=1.43e+3]
6/30 epochs: 100%|███████████████████████████████████████████| 235/235 [01:06<00:00,  3.55it/s, train_neg_logp=1.13e+3]
7/30 epochs: 100%|███████████████████████████████████████████████| 235/235 [01:07<00:00,  3.47it/s, train_neg_logp=955]
8/30 epochs: 100%|███████████████████████████████████████████████| 235/235 [01:06<00:00,  3.55it/s, train_neg_logp=751]
9/30 epochs: 100%|██████████████████████