In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
from pathlib import Path

sys.path.append(str(Path(os.getcwd()).parent))

In [None]:
import math
import time

from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from real_nvp.model_creation import create_multiscale_flow
from utils import show_imgs

# Data and checkpoints live in ../data relative to the notebook's working directory
DATASET_PATH = Path.cwd().parent / "data"
MODELS_DIR = DATASET_PATH / "saved_models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Global seed (via Lightning) so later random steps such as the train/val split are reproducible
pl.seed_everything(42)

# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# First CUDA GPU when available, CPU otherwise
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", DEVICE)

In [None]:
# DataLoader settings shared by the train / val / test loaders
BATCH_SIZE = 512  # images per batch
NUM_WORKERS = 4  # loader worker processes

# Epoch budget forwarded to the Lightning trainer through train_flow(max_epochs=...)
MAX_EPOCHS = 300

In [None]:
# Convert images from 0-1 to 0-255 (integers)
def discretize(sample):
    return (sample * 255).to(torch.int32)

# Transformations applied on each image => make them a tensor and discretize
transform = transforms.Compose([transforms.ToTensor(), discretize])

# Loading the training dataset and splitting it into 50,000 training / 10,000 validation images.
# random_split draws from torch's global RNG, which pl.seed_everything seeded in an earlier cell.
train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])

# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)


def _make_loader(dataset, shuffle=False):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide batch size and worker count.

    Args:
        dataset: map-style dataset (here an MNIST dataset or a ``random_split`` subset).
        shuffle: whether to reshuffle the sample order every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` that keeps the final partial batch (``drop_last=False``).
    """
    return data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        drop_last=False,
        num_workers=NUM_WORKERS,
        # DataLoader raises ValueError for persistent_workers=True when num_workers == 0,
        # so only keep workers alive when there actually are worker processes.
        persistent_workers=NUM_WORKERS > 0,
    )


# Shuffle the training data so each epoch sees differently composed batches;
# evaluation order does not affect the metrics, so val/test stay in a fixed order.
train_loader = _make_loader(train_set, shuffle=True)
val_loader = _make_loader(val_set)
test_loader = _make_loader(test_set)

In [None]:
def _load_checkpoint_weights(flow_model, checkpoint_path, device):
    """Load ``state_dict`` from the checkpoint at ``checkpoint_path`` into ``flow_model``; return the checkpoint dict."""
    # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True); Lightning checkpoints that
    # carry non-tensor metadata may need weights_only=False there — confirm for your torch version.
    ckpt = torch.load(checkpoint_path, map_location=device)
    flow_model.load_state_dict(ckpt["state_dict"])
    return ckpt


def train_flow(
    flow_model: pl.LightningModule,
    model_name: str,
    ckeckpoint_dir: Path,
    train_loader: data.DataLoader,
    val_loader: data.DataLoader,
    test_loader: data.DataLoader,
    max_epochs: int = 200,
    check_val_every_n_epoch: int = 5,
    device: torch.device = DEVICE,
):
    """Train a flow model (or reload it from its saved checkpoint) and evaluate it on val/test.

    Checkpoint location:
        The best weights (lowest ``val_bpd``) are written to
        ``ckeckpoint_dir / model_name / f"{model_name}.ckpt"``.

    Reuse on later runs:
        - If that file already exists, training is skipped and its weights are loaded.
        - After evaluation, the result dict is stored in the same file under ``"result"``,
          so subsequent runs skip the expensive test pass too.

    Args:
        flow_model: LightningModule to train; must expose ``import_samples``, which is
            used to normalise the timing figure.
        model_name: run name; used for the sub-directory and the checkpoint file name.
        ckeckpoint_dir: parent directory for per-model folders. The misspelled name is
            kept for caller compatibility.
        train_loader: data used for fitting.
        val_loader: data used for checkpoint selection and the reported "val" metrics.
        test_loader: data used for the reported "test" metrics.
        max_epochs: maximum number of training epochs.
        check_val_every_n_epoch: validation frequency. ``val_bpd`` is presumably only
            logged on these epochs, so checkpoints can only be chosen there.
        device: selects the Lightning accelerator and the ``map_location`` used when
            loading checkpoints.

    Returns:
        ``(flow_model, result)``. ``flow_model`` holds the best (or reloaded) weights.
        ``result`` has the keys:
            - ``"val"`` and ``"test"``: the lists returned by ``trainer.test``.
            - ``"time"``: test duration divided by the number of test batches and by
              ``flow_model.import_samples``.
    """
    model_dir = ckeckpoint_dir / model_name
    model_dir.mkdir(parents=True, exist_ok=True)
    # The single file this function writes during training and looks for on re-runs.
    pretrained_filename = model_dir / f"{model_name}.ckpt"

    # Save the best val_bpd weights to exactly `pretrained_filename`. Without dirpath/filename,
    # Lightning writes into lightning_logs/version_*/checkpoints and the reuse check never fires.
    checkpoint_callback = ModelCheckpoint(
        dirpath=model_dir,
        filename=model_name,
        save_weights_only=True,
        mode="min",
        monitor="val_bpd",
    )
    callbacks = [checkpoint_callback, LearningRateMonitor("epoch")]

    # Create a PyTorch Lightning trainer
    trainer = pl.Trainer(
        default_root_dir=str(model_dir),
        accelerator="gpu" if str(device).startswith("cuda") else "cpu",
        devices=1,
        max_epochs=max_epochs,
        gradient_clip_val=1.0,
        callbacks=callbacks,
        check_val_every_n_epoch=check_val_every_n_epoch,
    )
    # Private logger attributes (presumably TensorBoardLogger): enable graph logging, disable hp_metric.
    trainer.logger._log_graph = True
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    result = None

    # Check whether pretrained model exists. If yes, load it and skip training
    if pretrained_filename.exists():
        print("Found pretrained model, loading...")
        ckpt = _load_checkpoint_weights(flow_model, pretrained_filename, device)
        result = ckpt.get("result", None)
    else:
        print("Start training", model_name)
        trainer.fit(
            model=flow_model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
        )
        # fit() leaves the last-epoch weights in flow_model; switch to the best checkpoint
        # so the evaluation below really tests the best model. The path is empty when no
        # validation ran (e.g. max_epochs < check_val_every_n_epoch).
        best_path = checkpoint_callback.best_model_path
        if best_path:
            _load_checkpoint_weights(flow_model, best_path, device)

    # Test best model on validation and test set if no result has been found.
    # Testing can be expensive due to the importance sampling.
    if result is None:
        val_result = trainer.test(model=flow_model, dataloaders=val_loader, verbose=False)
        start_time = time.time()
        test_result = trainer.test(model=flow_model, dataloaders=test_loader, verbose=False)
        duration = time.time() - start_time
        result = {"test": test_result, "val": val_result, "time": duration / len(test_loader) / flow_model.import_samples}

        # Persist the metrics next to the weights so `ckpt.get("result")` finds them next time.
        if pretrained_filename.exists():
            ckpt = torch.load(pretrained_filename, map_location=device)
            ckpt["result"] = result
            torch.save(ckpt, pretrained_filename)

    return flow_model, result

In [None]:
# Build the multiscale RealNVP flow. The example input is the first training image with a
# leading batch dimension; presumably it becomes the module's `example_input_array`
# (used for graph logging) — confirm in real_nvp.model_creation.
multiscale_real_nvp_model = create_multiscale_flow(
    example_input_array=train_set[0][0][None],
    learning_rate=1e-3,
    step_size=1,
    gamma=0.99,
)

In [None]:
# Train (or reload) the flow. Keep the returned metrics instead of displaying the raw
# (model, result) tuple, whose repr is dominated by the full module printout.
multiscale_real_nvp_model, multiscale_real_nvp_result = train_flow(
    flow_model=multiscale_real_nvp_model,
    model_name="multiscale_real_nvp_mnist_v1",
    ckeckpoint_dir=MODELS_DIR,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    max_epochs=MAX_EPOCHS,
    device=DEVICE,
)
multiscale_real_nvp_result