In [None]:
## Standard libraries
import os
import json
import math
import numpy as np
import random
from Sampler import Sampler
from CNNModel import CNNModel
from LightningMNISTClassifier import LightningMNISTClassifier
import torch
from LightningCIFARClassifier import LitResnet
from models import ResNetModel
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from utils import GaussianBlur
## Imports for plotting
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline
from matplotlib.colors import to_rgb
import matplotlib
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
torch.set_float32_matmul_precision('high')

# Folder intended for dataset downloads (e.g. CIFAR10).
# NOTE(review): the dataset constructors further down pass root='./data' instead of this constant.
DATASET_PATH = "../data"
# Base folder handed to the trainer as the root for its logs and checkpoints.
CHECKPOINT_PATH = "../PCD_CIFAR_paper_network_pre"

# Seed the random number generators through Lightning.
pl.seed_everything(42)

# cuDNN benchmark mode is switched on; no determinism flags are set in this cell.
torch.backends.cudnn.benchmark = True

# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Module-level metric objects shared by the model class defined later in the notebook.
fid = FrechetInceptionDistance(normalize=True).to(device)
IS = InceptionScore(normalize=True).to(device)

import torchvision.transforms.functional as trans_F
plt.rcParams["savefig.bbox"] = "tight"
def show(imgs):
    """Draw one image tensor, or a list of them, side by side in a single row.

    Inputs:
        imgs - a tensor accepted by ``to_pil_image`` or a list of such tensors
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for axis, image in zip(axes[0], images):
        # Detach first so tensors that are part of an autograd graph can be converted.
        pil_image = trans_F.to_pil_image(image.detach())
        axis.imshow(np.asarray(pil_image))
        # Hide ticks and tick labels; only the picture should be visible.
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
BATCH_SIZE = 512
NUM_WORKERS = 4

# Training preprocessing: tensor conversion followed by per-channel (x - 0.5) / 0.5.
# Augmentations are currently disabled:
#     transforms.RandomCrop(32, padding=4),
#     transforms.RandomHorizontalFlip(),
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Test preprocessing is identical to the training one.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# ---- Training split ----
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transforms,
)

# Optional two-class (car vs. dog) subset of the training split, currently disabled:
# indices_car = np.where(np.array(trainset.targets) == 1)[0]
# indices_car = indices_car[:2000]
# indices_dog = np.where(np.array(trainset.targets) == 5)[0]
# indices = np.concatenate((indices_car,indices_dog)).squeeze()
# trainset.data = trainset.data[indices]
# targets = np.array(trainset.targets)[indices]
# targets[targets == 1] = 0
# targets[targets == 5] = 1
# trainset.targets = targets.tolist()

trainloader = data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# ---- Test split ----
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transforms,
)

# Matching two-class subset of the test split, currently disabled:
# indices = np.where((np.array(testset.targets) == 1) | (np.array(testset.targets) == 5))[0]
# testset.data = testset.data[indices]
# targets = np.array(testset.targets)[indices]
# targets[targets == 1] = 0
# targets[targets == 5] = 1
# testset.targets = targets.tolist()

testloader = data.DataLoader(
    testset,
    batch_size=512,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Human-readable CIFAR-10 class names.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
class Paper_net(torch.nn.Module):
    """Convolutional energy network: maps a batch of images to one scalar per image.

    Inputs:
        n_c - number of input channels
        n_f - number of filters of the first convolution (later layers use 2x, 4x, 8x)
        l   - negative slope of the LeakyReLU activations

    For a 32x32 input the three stride-2 convolutions reduce the spatial size
    32 -> 16 -> 8 -> 4 and the final 4x4 convolution (no padding) reduces it to 1x1,
    so the raw network output has shape (batch, 1, 1, 1).
    """

    def __init__(self,n_c = 3, n_f = 64, l = 0.2):
        super().__init__()

        self.network = nn.Sequential(
            nn.Conv2d(n_c,n_f,3,1,1),
            nn.LeakyReLU(l),
            nn.Conv2d(n_f,n_f*2,4,2,1),
            nn.LeakyReLU(l),
            nn.Conv2d(n_f*2,n_f*4,4,2,1),
            nn.LeakyReLU(l),
            nn.Conv2d(n_f*4,n_f*8,4,2,1),
            nn.LeakyReLU(l),
            nn.Conv2d(n_f*8,1,4,1,0))

    def forward(self,x):
        """Return the energy of every image in ``x``.

        Size-1 channel and spatial axes are dropped, but the batch axis is always
        kept: a batch holding a single image yields shape (1,), not a 0-d tensor.
        """
        out = self.network(x)
        if out.dim() < 4:
            # Unbatched input: there is no batch axis to protect.
            return out.squeeze()
        # Squeeze from the last axis backwards so the indices stay valid;
        # squeeze(dim) is a no-op when that axis is larger than one.
        return out.squeeze(3).squeeze(2).squeeze(1)
    
    

class input_args:
    """Simple attribute container for network options.

    Inputs:
        im_size    - stored as ``im_size``
        filter_dim - stored as ``filter_dim``
        norm       - stored as ``norm``
        spec_norm  - stored as ``spec_norm``

    The remaining flags are fixed and cannot be set through the constructor.
    """

    def __init__(self, im_size, filter_dim, norm, spec_norm):
        # Values supplied by the caller.
        self.im_size, self.filter_dim = im_size, filter_dim
        self.norm, self.spec_norm = norm, spec_norm
        # Hard-wired switches.
        self.cond = self.multiscale = False
        self.self_attn = True
        self.sigmoid = self.square_energy = False

class DeepEnergyModel(pl.LightningModule):
    """Lightning module that trains the ``Paper_net`` energy network on CIFAR-10.

    ``training_step`` minimises a contrastive term (mean energy of real images
    minus mean energy of the sampler's walkers) plus an L2 penalty on both
    energies. Negative samples are produced by the external ``Sampler``.
    Validation reports FID and Inception Score of the walkers through the
    module-level ``fid`` / ``IS`` metric objects.

    NOTE(review): relies on the notebook globals ``trainset``, ``device``,
    ``fid`` and ``IS`` being defined before the class is instantiated.
    """

    def __init__(self, img_shape, batch_size, alpha= 1e-2, lr=1e-4, steps= 60, step_size= 1, num_batch = 1,
                 noise_level = 1e-2, sample_size = 1024, sampler_batch = 256, resample_std = 1, **CNN_args):

        """
        Build the network and sampler, cache the training images and warm up the walkers.

        Inputs:
            img_shape - Shape of the images as tensors, e.g. (3, 32, 32) for CIFAR-10
            batch_size - Batch size for drawing data from the training set
                         (only stored as a hyperparameter here)
            alpha - constant which controls the regularization term
            lr - Initial learning rate for ADAM
            steps - Number of ULA steps for every model parameter update
            step_size - ULA step size
            num_batch - The total number of mini-batches drawn
                        and run ULA on in each iteration of parameter update
            noise_level - Noise for ULA
            sample_size - The total number of walkers
            sampler_batch - number of walkers in each mini-batch
            resample_std - critical standard deviation for adaptive resampling
                           (only referenced by the disabled resampling code)
            CNN_args - extra keyword arguments; stored as hyperparameters, otherwise unused
        """

        super().__init__()
        self.save_hyperparameters()
#         pretrained_filename = "CIFAR10_checkpoint.ckpt" # load the pre-trained classifer
#         self.classifier = LitResnet.load_from_checkpoint(pretrained_filename)
#         self.classifier.eval()
        self.args = input_args(32,64,True,False)

        self.cnn = Paper_net()
        self.sampler = Sampler(self.cnn, img_shape=img_shape,sample_size = self.hparams.sample_size)

        self.langevin_steps = 0

        # Preprocessing for the reference copy of the training set (same as the loaders').
        train_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

#         indices_car_reduced = indices_car[:100]
#         indices_dog_reduced = indices_dog[:40]
#         indices_reduced = np.concatenate((indices_car_reduced,indices_dog_reduced)).squeeze()

        self.reduced_train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=train_transforms)

        random_indices = torch.multinomial(torch.ones((self.reduced_train_set.data.shape[0],)),300,replacement = False).to(device)

        # Full training images, used for the cross-entropy estimate below.
        # (Loaded once; the earlier code loaded the same set twice.)
        self.data = self._load_all_images(trainset)
        # Images from which the real reference batches for FID are drawn.
        self.reduced_data = self._load_all_images(self.reduced_train_set)

        # Fixed random subset of real images, kept in the [-1, 1] training range.
        # It is no longer pushed into `fid` here: the metric is reset at the start of
        # every validation epoch and fresh real batches are added in validation_step.
        self.data_uint = self.reduced_data[random_indices]

        print("Data is loaded")

        # NOTE(review): the evaluations below assume Sampler has already placed the
        # network and its walkers on the same device as `self.data` - confirm.
        with torch.no_grad():
            normalization = (-self.cnn(self.sampler.examples)).exp().mean()
            self.sampler.ce = normalization.log() + self._mean_energy(self.data)
            self.sampler.normalization = normalization
            self.sampler.normal_0 = normalization.clone().detach()

        # run Langevin dynamics on the walkers without updating their weights
        # to make sure they are samples from the initial distribution ~ exp(-U_\theta)
        self.sampler.examples = self.sampler.pure_generate_samples(self.cnn,self.sampler.examples,
                                                                   steps=2000,
                                                                   step_size=self.hparams.step_size,
                                                                   noise_level = self.hparams.noise_level)

        self.langevin_steps = 0

        # Initial mini-batch of walkers used as negative samples in the first training step.
        index = torch.multinomial(torch.ones((self.hparams.sample_size,)),self.hparams.sampler_batch,replacement = False).to(device)
        self.fake_imgs = self.sampler.examples[index]
#         self.fake_imgs_weights = self.sampler.log_weights[index]
        print("Training begins")

    @staticmethod
    def _to_unit_range(images):
        """Map images from the [-1, 1] training range to [0, 1], clamping values outside it."""
        return (images.detach() / 2 + 0.5).clamp(0, 1)

    @staticmethod
    def _load_all_images(dataset):
        """Return every image of ``dataset`` as a single tensor on the global ``device``."""
        loader = data.DataLoader(dataset, batch_size=dataset.data.shape[0], shuffle=True,
                                 drop_last=False, num_workers=4, pin_memory=True)
        images, _ = next(iter(loader))
        return images.to(device)

    def _mean_energy(self, images, chunk_size=2048):
        """Mean network output over ``images``, evaluated chunk by chunk to bound memory use."""
        total = 0.0
        count = 0
        for chunk in torch.split(images, chunk_size):
            energies = self.cnn(chunk)
            total = total + energies.sum()
            count += energies.numel()
        return total / count

    def forward(self, x):
        """Return the energy the network assigns to ``x``."""
        return self.cnn(x)

    def configure_optimizers(self):
        """Adam without first-moment momentum plus a linear learning-rate decay over 600 epochs."""

        # Energy models can have issues with momentum as the loss surfaces changes with its parameters.
        # Hence, we set it to 0 by default.

        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, betas=(0.0, 0.999))

        initial_lr = 1
        min_lr = 0
        n_iter = 600

        # Multiplicative factor: decays linearly from initial_lr to min_lr, floored at 1e-10.
        lambda1 = lambda epoch: max((initial_lr - epoch*(initial_lr - min_lr)/n_iter),1e-10) # linear decay
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the contrastive loss on one real batch, then advance the walkers.

        Returns the scalar loss (contrastive term + regularization).
        """

        real_imgs, _ = batch

        # Perturb the real images with Gaussian noise of standard deviation 0.03.
        real_imgs = real_imgs + 0.03 * torch.randn_like(real_imgs)

        self.langevin_steps = self.langevin_steps + self.hparams.steps

        # Predict the energy for all images
        real_out = self.cnn(real_imgs)
        fake_out = self.cnn(self.fake_imgs)

        # Calculate losses
#         weights = self.fake_imgs_weights.clone().exp().detach()

        reg_loss = self.hparams.alpha * ((real_out ** 2).mean() + (fake_out ** 2).mean())
        cdiv_loss = -fake_out.mean() + real_out.mean()

        loss = cdiv_loss + reg_loss

        # track the gradients of the parameters
        # (this runs before the backward pass of the current loss, so it reads
        # whatever gradients are currently stored on the parameters, if any)
        parameters = [p for p in self.cnn.parameters() if p.grad is not None and p.requires_grad]
        if len(parameters) == 0:
            total_norm = 0.0
        else:
            total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in parameters]))

        langevin_steps = self.sampler.langevin_steps

        # Do ULA and weight updates on the walkers
        self.fake_imgs,self.fake_imgs_weights = self.sampler.sample_new_exmps(steps=self.hparams.steps,
                                                  step_size=self.hparams.step_size,
                                                  noise_level = self.hparams.noise_level,
                                                  batch_size = self.hparams.sampler_batch,
                                                  num_batch = self.hparams.num_batch)

#         with torch.no_grad():
#             self.sampler.ce = self.sampler.normalization.log() + self.cnn(self.data).mean()
#             probability_weights = weights/weights.sum()
#             self.sampler.ess = 1/((probability_weights**2).sum())

#         if torch.std(self.sampler.weights/self.sampler.weights.mean()) > self.hparams.resample_std and self.current_epoch > 300:
#             self.sampler.resample_multinomial()

        # Logging
        self.log('loss', loss)
        self.log('loss_regularization', reg_loss)
        self.log('loss_contrastive_divergence', cdiv_loss)

        self.log('energy_avg_real', real_out.mean())
        self.log('energy_avg_fake', fake_out.mean())

#         self.log('largest_log_weight',self.sampler.log_weights.max() )
#         self.log('smallest_log_weight',self.sampler.log_weights.min() )
#         self.log('mean_log_weight',self.sampler.log_weights.mean() )

        self.log('f_norm',total_norm)
        self.log('langevin_steps',langevin_steps)
#         self.log('weight_std',torch.std(self.sampler.weights/self.sampler.weights.mean()))
#         self.log('ess',self.sampler.ess)

#         self.log('cross_entropy',self.sampler.ce)
#         self.log('normalization',self.sampler.normalization)
        return loss

    def on_validation_epoch_start(self):
        """Clear the shared metric state so every validation epoch is scored on its own.

        Without this, the module-level metrics keep features from walkers of all
        earlier epochs, and the Inception Score feature list grows without bound.
        """
        fid.reset()
        IS.reset()

    def validation_step(self, batch, batch_idx):
        """Update FID / Inception Score with the current walkers and log them.

        The validation ``batch`` itself is not used; real reference images are drawn
        from the cached training images instead.
        """

        walkers = self.sampler.examples

        # randint's upper bound is exclusive, so this covers every walker index.
        example_indices = torch.randint(0, walkers.shape[0], (300,))
        example = walkers[example_indices]

        random_indices = torch.multinomial(torch.ones((self.reduced_train_set.data.shape[0],)),300,replacement = False).to(device)
        FID_set = self.reduced_data[random_indices]

        # Both metrics were created with normalize=True and therefore expect floats
        # in [0, 1]; the images here live in the [-1, 1] training range.
        fid.update(self._to_unit_range(FID_set), real=True)
        fid.update(self._to_unit_range(example), real=False)
        FID_score = fid.compute()

        IS.update(self._to_unit_range(walkers))
        IS_score = IS.compute()

        self.log('FID',FID_score)
        self.log("Inception Score_mean",IS_score[0])
        self.log("Inception Score_std",IS_score[1])

In [None]:
class SamplerCallback(pl.Callback):
    """Periodically snapshot the sampler's walkers to disk and to the logger.

    Inputs:
        num_imgs - Number of walkers shown in the logged image grid
        every_n_epochs - Take a snapshot only every N epochs (must be >= 1);
                         keeps disk usage and the logger size manageable
    """

    def __init__(self, num_imgs=128, every_n_epochs=4):
        super().__init__()
        self.num_imgs = num_imgs             # Number of images to plot
        self.every_n_epochs = every_n_epochs # Only save those images every N epochs (otherwise tensorboard gets quite large)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Save all walkers and their weights, and log a grid of randomly chosen walkers."""
        if trainer.current_epoch % self.every_n_epochs != 0:
            return

        print(trainer.current_epoch)

        # Nothing below needs gradients, so the global grad mode is left untouched.
        walkers = pl_module.sampler.examples
        # Drawn with replacement: the grid may contain the same walker more than once.
        indices = torch.randint(0,walkers.shape[0],size = (self.num_imgs,))
        exmp_imgs = walkers[indices].clone().detach()

        save_data = {'images': walkers.clone().detach(),
                     "weights": pl_module.sampler.weights.clone().detach()}
        foldername = "PCD_save_new_all_/" # folder for saving all the data
        os.makedirs(foldername, exist_ok=True)
        filename = foldername + "images_resample" + str(trainer.current_epoch) + "_.pt"
        torch.save(save_data,filename)

        grid = torchvision.utils.make_grid(exmp_imgs, nrow=8, normalize=True, value_range=(-1,1))
        trainer.logger.experiment.add_image("sampler", grid, global_step=trainer.current_epoch)
                

In [None]:
def train_model(**kwargs):
    """Train a ``DeepEnergyModel`` and return the checkpoint with the lowest logged FID.

    Inputs:
        kwargs - forwarded unchanged to the ``DeepEnergyModel`` constructor

    Returns the model restored from the best checkpoint, or the in-memory trained
    model when no checkpoint was written.
    """
    # Use the GPU when one is present, otherwise fall back to the CPU
    # (mirrors how the notebook picks its global `device`).
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "MNIST"),
                         accelerator=accelerator, devices=1,
                         max_epochs=600,
                         log_every_n_steps=10,
                         gradient_clip_val=1000,
                         # profiler="simple",
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor='FID'),
                                    # Snapshot the walkers every 10 epochs.
                                    SamplerCallback(every_n_epochs=10),
                                    LearningRateMonitor("epoch")
                                   ])

    pl.seed_everything(42)
    model = DeepEnergyModel(**kwargs)
    trainer.fit(model, trainloader, testloader)

    best_path = trainer.checkpoint_callback.best_model_path
    if not best_path:
        # No checkpoint on disk: hand back the trained model instead of failing on an empty path.
        return model

    # No testing as we are more interested in other properties
    return DeepEnergyModel.load_from_checkpoint(best_path)

In [None]:
# Start training on 3x32x32 inputs; the batch size is taken from the training loader.
model = train_model(
    img_shape=(3, 32, 32),
    batch_size=trainloader.batch_size,
)