In [None]:
## Standard libraries
import os
import json
import math
import numpy as np
import random
from Sampler import Sampler
from CNNModel import CNNModel
from LightningMNISTClassifier import LightningMNISTClassifier

## Imports for plotting
%matplotlib inline
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
torch.set_float32_matmul_precision('high')

# Folder the MNIST files are read from (and downloaded into when missing)
DATASET_PATH = "../data"
# Root folder for training outputs; the trainer's root directory is created below it
CHECKPOINT_PATH = "../Jarzynski_EBMs"

# Seed the random number generators
pl.seed_everything(42)

# Let cuDNN benchmark its convolution algorithms and keep the fastest one.
# NOTE(review): this is a speed setting, not a determinism setting.
torch.backends.cudnn.benchmark = True

# Use the first CUDA device when one is available, otherwise stay on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

In [None]:
# Per-image preprocessing: convert to a tensor, then normalize to the range [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


def first_indices_of_digit(targets, digit, count):
    """Return the positions of the first `count` labels equal to `digit` (one index per row)."""
    return torch.nonzero(targets == digit)[:count]


# Training set: a deliberately imbalanced three-digit subset of MNIST
train_set = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)

indices_2 = first_indices_of_digit(train_set.targets, 2, 5600)  # up to 5600 samples of 2
indices_3 = first_indices_of_digit(train_set.targets, 3, 2800)  # up to 2800 samples of 3
indices_6 = first_indices_of_digit(train_set.targets, 6, 1400)  # up to 1400 samples of 6

# Concatenation order is 2, 6, 3
indices = torch.cat((indices_2, indices_6, indices_3)).squeeze()
train_set.data = train_set.data[indices]
train_set.targets = train_set.targets[indices]

# Test set: every test image whose label is 2, 6 or 3
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)
indices = torch.isin(test_set.targets, torch.tensor([2, 6, 3]))
test_set.data = test_set.data[indices]
test_set.targets = test_set.targets[indices]

# Data loaders used by the rest of the notebook
loader_options = dict(num_workers=4, pin_memory=True)
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=False, **loader_options)
test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=True, **loader_options)

In [None]:
class DeepEnergyModel(pl.LightningModule):
    """
    Energy-based model for a three-digit MNIST subset, trained against a set of weighted walkers.

    The module owns the energy network (`self.cnn`), a `Sampler` that holds the walkers
    together with their weights, and a pre-trained classifier that is only used during
    validation. It reads the notebook-level globals `DATASET_PATH` and `device`.
    """

    def __init__(self, img_shape, batch_size, alpha= 1e-2, lr=1e-4, steps= 30, step_size= 5e-5, num_batch = 6,
                 noise_level = 1e-4, sample_size = 1024, sampler_batch = 256, resample_std = 1.5, **CNN_args):
        """
        Pytorch Lightning class for the training and validation

        Inputs:
            img_shape - Shape of the images as tensors (default 1*28*28 for MNIST images)
            batch_size - Batch size for drawing data from the training set
                         (saved as a hyperparameter; not otherwise read inside this class)
            alpha - constant which controls the regularization term
            lr - Initial learning rate for ADAM
            steps - Number of ULA steps for every model parameter update
            step_size - ULA step size
            num_batch - The total number of mini-batches drawn
                        and run ULA on in each iteration of parameter update
            noise_level - Noise for ULA
            sample_size - The total number of walkers
            sampler_batch - number of walkers in each mini-batch
            resample_std - critical standard deviation for adaptive resampling
            CNN_args - forwarded unchanged to the CNNModel constructor
        """
        super().__init__()
        # Must stay in __init__ itself so that the constructor arguments are captured
        self.save_hyperparameters()

        pretrained_filename = "MNIST_checkpoint.ckpt" # load the pre-trained classifier
        self.classifier = LightningMNISTClassifier.load_from_checkpoint(pretrained_filename)
        self.classifier.eval()
        self.cnn = CNNModel(**CNN_args)
        self.sampler = Sampler(self.cnn, img_shape=img_shape, sample_size=self.hparams.sample_size)

        # run Langevin dynamics on the walkers without updating their weights
        # to make sure they are samples from the initial distribution ~ exp(-U_\theta)
        self.sampler.examples = self.sampler.pure_generate_samples(self.cnn, self.sampler.examples,
                                                                   steps=20000,
                                                                   step_size=self.hparams.step_size,
                                                                   noise_level=self.hparams.noise_level)

        self.langevin_steps = 0

        # images of the full manufactured training set, used to estimate the cross entropy
        self.data = self._load_reference_images().to(device)

        # Initialize the normalization constant and the cross-entropy estimate.
        # The walkers are evaluated once and the result is shared by both quantities.
        with torch.no_grad():
            self.sampler.normalization = (-self.cnn(self.sampler.examples)).exp().mean()
            self.sampler.ce = self.sampler.normalization.log() + self.cnn(self.data).mean()
            self.sampler.normal_0 = self.sampler.normalization.clone().detach()

    @staticmethod
    def _load_reference_images():
        """Return the normalized images of the manufactured training subset as a single batch tensor."""
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,))
                                       ])

        data_set = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)

        # up to 5600 twos, 1400 sixes and 2800 threes, concatenated in that order
        per_digit = [(data_set.targets == digit).nonzero()[:count]
                     for digit, count in ((2, 5600), (6, 1400), (3, 2800))]
        indices = torch.cat(per_digit).squeeze()

        # keep images and labels aligned; previously only the images were filtered
        data_set.data, data_set.targets = data_set.data[indices], data_set.targets[indices]

        # one batch spanning the whole subset; loader arguments are unchanged
        data_loader = data.DataLoader(data_set, batch_size=data_set.data.shape[0], shuffle=True,
                                      drop_last=False, num_workers=4, pin_memory=True)
        images, _ = next(iter(data_loader))
        return images

    def _relative_weight_std(self):
        """Standard deviation of the walker weights after dividing them by their mean."""
        weights = self.sampler.weights
        return torch.std(weights / weights.mean())

    def forward(self, x): # CNN model
        """Return the output of the energy network for `x`."""
        return self.cnn(x)

    def configure_optimizers(self):
        """Create the Adam optimizer and a linearly decaying learning-rate schedule."""

        # Energy models can have issues with momentum as the loss surfaces changes with its parameters.
        # Hence, we set it to 0 by default.
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, betas=(0.0, 0.999))

        # LambdaLR multiplies the base learning rate by the factor returned below.
        initial_factor = 1
        min_factor = 0
        n_decay = 500 # number of scheduler steps after which the factor reaches min_factor
        # NOTE(review): the scheduler is returned without an explicit interval, so Lightning
        # presumably steps it once per epoch - confirm against the Lightning version in use.

        def linear_decay(epoch):
            # floor at 1e-10 so the factor never becomes zero or negative
            return max((initial_factor - epoch*(initial_factor - min_factor)/n_decay), 1e-10)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_decay)

        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """
        Compute the loss on one batch of real images and one mini-batch of walkers,
        then advance the walkers and refresh the cross-entropy estimate.

        Returns the scalar loss (contrastive term plus regularization).
        """
        real_imgs, _ = batch
        real_imgs = real_imgs.detach()

        # Obtain samples from the set of walkers: `sampler_batch` distinct indices, uniformly at random
        index = torch.multinomial(torch.ones((self.hparams.sample_size,)),
                                  self.hparams.sampler_batch,
                                  replacement=False).to(device)
        fake_imgs = self.sampler.examples[index].detach()

        self.langevin_steps += self.hparams.steps

        # Predict the energy for all images
        real_out = self.cnn(real_imgs)
        fake_out = self.cnn(fake_imgs)

        # Calculate losses; walker terms are averaged with the (detached) walker weights
        weights = self.sampler.weights[index].clone().detach()

        reg_loss = self.hparams.alpha * ((real_out ** 2).mean() + ((fake_out ** 2)*weights/weights.sum()).sum())
        cdiv_loss = -(fake_out*weights/weights.sum()).sum() + real_out.mean()

        loss = cdiv_loss + reg_loss

        # Track the gradient norm of the energy network. No backward pass has been run on
        # `loss` inside this method, so any `.grad` found here was left by an earlier step.
        parameters = [p for p in self.cnn.parameters() if p.grad is not None and p.requires_grad]
        if len(parameters) == 0:
            total_norm = 0.0
        else:
            total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in parameters]))

        # read the sampler's step counter before the walkers are advanced
        langevin_steps = self.sampler.langevin_steps

        # Do ULA and weight updates on the walkers (the return value is not needed here)
        self.sampler.sample_new_exmps(steps=self.hparams.steps,
                                      step_size=self.hparams.step_size,
                                      noise_level=self.hparams.noise_level,
                                      batch_size=self.hparams.sampler_batch,
                                      num_batch=self.hparams.num_batch)

        with torch.no_grad():
            self.sampler.ce = self.sampler.normalization.log() + self.cnn(self.data).mean()

        # adaptive resampling: only once the weights have spread out and only after epoch 150
        if self._relative_weight_std() > self.hparams.resample_std and self.current_epoch > 150:
            self.sampler.resample_multinomial()

        # Logging
        self.log('loss', loss)
        self.log('loss_regularization', reg_loss)
        self.log('loss_contrastive_divergence', cdiv_loss)

        self.log('energy_avg_real', real_out.mean())
        self.log('energy_avg_fake', fake_out.mean())

        self.log('largest_log_weight', self.sampler.log_weights.max())
        self.log('smallest_log_weight', self.sampler.log_weights.min())
        self.log('mean_log_weight', self.sampler.log_weights.mean())

        self.log('f_norm', total_norm)
        self.log('langevin_steps', langevin_steps)
        # recomputed on purpose: the weights may just have been resampled
        self.log('weight_std', self._relative_weight_std())

        self.log('cross_entropy', self.sampler.ce)
        self.log('normalization', self.sampler.normalization)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Log the share of total walker weight that falls on each digit.

        `batch` is not used: the walkers are passed through the pre-trained classifier
        and the weights of the walkers assigned to each class are summed.
        """
        walkers = self.sampler.examples.clone().detach()
        pred = self.classifier(walkers).argmax(dim=1)

        # NOTE(review): assumes classifier class 0 is digit 2, class 1 is digit 6 and
        # class 2 is digit 3 - confirm against the classifier's training labels.
        class_index_2 = pred == 0
        class_index_6 = pred == 1
        class_index_3 = pred == 2

        weights = self.sampler.weights.clone().detach()
        weight_all = weights.sum()
        weight_2 = weights[class_index_2].sum()/weight_all
        weight_6 = weights[class_index_6].sum()/weight_all
        weight_3 = weights[class_index_3].sum()/weight_all

        self.log('weight_2', weight_2)
        self.log('weight_6', weight_6)
        self.log('weight_3', weight_3)

In [None]:
class SamplerCallback(pl.Callback):
    """
    Logs a grid of randomly chosen walkers to the trainer's logger at regular epoch intervals.

    Inputs:
        num_imgs - Number of walker images placed in each grid
        every_n_epochs - Log a grid only on epochs divisible by this value
                         (otherwise tensorboard gets quite large)
    """

    def __init__(self, num_imgs=128, every_n_epochs=4):
        super().__init__()
        self.num_imgs = num_imgs             # Number of images to plot
        self.every_n_epochs = every_n_epochs # Only save those images every N epochs

    def on_validation_epoch_end(self, trainer, pl_module):
        """Add an image grid of walkers to the logger when the current epoch is due."""
        # honour the configured interval (it used to be ignored in favour of a fixed 20)
        if trainer.current_epoch % self.every_n_epochs != 0:
            return

        print(trainer.current_epoch)

        # The walkers are only indexed and copied here, so gradient tracking is not
        # needed and the global autograd mode is left untouched.
        walkers = pl_module.sampler.examples
        chosen = torch.randint(0, walkers.shape[0], size=(self.num_imgs,))  # drawn with replacement
        exmp_imgs = walkers[chosen].clone().detach()

        grid = torchvision.utils.make_grid(exmp_imgs, nrow=8, normalize=True, value_range=(-1,1))
        # NOTE(review): `add_image` assumes a TensorBoard-style logger experiment - confirm.
        trainer.logger.experiment.add_image("sampler", grid, global_step=trainer.current_epoch)
                

In [None]:
def train_model(**kwargs):
    """
    Train a DeepEnergyModel and return the best checkpointed version of it.

    Inputs:
        kwargs - forwarded unchanged to the DeepEnergyModel constructor

    Returns:
        The model reloaded from the checkpoint with the lowest logged 'cross_entropy'.
        If training ended before any checkpoint was written, the in-memory model
        that was just trained is returned instead.
    """
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "MNIST"),
                         accelerator='gpu', devices=1,
                         max_epochs=600,
                         log_every_n_steps=50,
                         gradient_clip_val=1,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor='cross_entropy'),
                                    SamplerCallback(every_n_epochs=2),
                                    LearningRateMonitor("epoch")
                                   ])

    pl.seed_everything(42)
    model = DeepEnergyModel(**kwargs)
    trainer.fit(model, train_loader, test_loader)

    # No testing as we are more interested in other properties

    best_path = trainer.checkpoint_callback.best_model_path
    if not best_path:
        # nothing was checkpointed, so there is nothing to reload
        return model

    return DeepEnergyModel.load_from_checkpoint(best_path)

In [None]:
# Train on single-channel 28x28 images; the model receives the training loader's batch size
model_kwargs = {
    "img_shape": (1, 28, 28),
    "batch_size": train_loader.batch_size,
}
model = train_model(**model_kwargs)