In [1]:
import os

# Check if the notebook is running on Colab
if 'COLAB_GPU' in os.environ:
    # This block will run only in Google Colab
    IN_COLAB = True
    print("Running on Google Colab. Cloning the repository.")
    !git clone https://github.com/pedro15sousa/energy-based-models-compression.git
    %cd energy-based-models-compression/notebooks
else: 
    # This block will run if not in Google Colab
    IN_COLAB = False
    print("Not running on Google Colab. Assuming local environment.")

Not running on Google Colab. Assuming local environment.


In [2]:
import sys
sys.path.append('..')  # This adds the parent directory (main_folder) to the Python path

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
import torch.utils.data as data
import torch.nn.utils.prune as prune
from torch.nn import functional as F

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning
    import pytorch_lightning as pl
# Callbacks
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# Pytorch Summary
try:
    from torchsummary import summary
except ModuleNotFoundError:
    !pip install --quiet torchsummary
    from torchsummary import summary

import numpy as np
import pandas as pd
import json
import copy

## Imports for plotting
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from matplotlib import cm
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

from metrics.classifier import VGG
from metrics.scores import frechet_inception_distance, inception_score
from EBM import DeepEnergyModel
from energy_funcs.cnn import CNNModel
from energy_funcs.resnet import ResNet18
from sampler import Sampler
from callbacks import InceptionScoreCallback, \
    FIDCallback, SamplerCallback, OutlierCallback, \
    GenerateImagesCallback

import shutil
if IN_COLAB:
    from google.colab import files

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models"

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device: ", device)

pl.seed_everything(43)

  set_matplotlib_formats('svg', 'pdf') # For export
Seed set to 43


Device: cpu
Device:  cpu


43

In [4]:
# Transformations applied on each image => make them a tensor and normalize between -1 and 1
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))
                               ])

# Loading the training dataset. We need to split it into a training and validation part
train_set = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)

# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
# Note that for actually training a model, we will use different data loaders
# with a lower batch size.
train_loader = data.DataLoader(train_set, batch_size=64, shuffle=True,  drop_last=True,  num_workers=2, pin_memory=True)
test_loader  = data.DataLoader(test_set,  batch_size=128, shuffle=False, drop_last=False, num_workers=2)

In [5]:
if os.path.exists('../saved_models/mnist-classifier-1 (1).pth'):
    # Load the best model
    mnist_classifier = VGG()

    if device == 'cuda':
        mnist_classifier.load_state_dict(torch.load('../saved_models/mnist-classifier-1 (1).pth'))
    else:
        mnist_classifier.load_state_dict(torch.load('../saved_models/mnist-classifier-1 (1).pth', map_location=torch.device('cpu')))

    mnist_classifier.to(device)
    print("Model already exists and loaded.")
    summary(mnist_classifier, input_size=(1, 28, 28))
else:
    print("Classifier not found in saved_models. Please run the classifier notebook first.")

Model already exists and loaded.
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 28, 28]           1,280
              ReLU-2          [-1, 128, 28, 28]               0
            Conv2d-3          [-1, 128, 28, 28]         147,584
              ReLU-4          [-1, 128, 28, 28]               0
            Conv2d-5          [-1, 128, 28, 28]         147,584
              ReLU-6          [-1, 128, 28, 28]               0
         MaxPool2d-7          [-1, 128, 14, 14]               0
            Conv2d-8          [-1, 256, 14, 14]         295,168
              ReLU-9          [-1, 256, 14, 14]               0
           Conv2d-10          [-1, 256, 14, 14]         590,080
             ReLU-11          [-1, 256, 14, 14]               0
           Conv2d-12          [-1, 256, 14, 14]         590,080
             ReLU-13          [-1, 256, 14, 14]               0
      

In [6]:
pretrained_filename = os.path.join(CHECKPOINT_PATH, "MNIST_resnet18.ckpt")
teacher_model = DeepEnergyModel.load_from_checkpoint(pretrained_filename)
summary(teacher_model, input_size=(1, 28, 28))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 14, 14]           3,200
             Swish-2           [-1, 64, 14, 14]               0
         MaxPool2d-3             [-1, 64, 7, 7]               0
            Conv2d-4             [-1, 64, 7, 7]          36,928
             Swish-5             [-1, 64, 7, 7]               0
            Conv2d-6             [-1, 64, 7, 7]          36,928
             Swish-7             [-1, 64, 7, 7]               0
BasicResidualBlock-8             [-1, 64, 7, 7]               0
            Conv2d-9             [-1, 64, 7, 7]          36,928
            Swish-10             [-1, 64, 7, 7]               0
           Conv2d-11             [-1, 64, 7, 7]          36,928
            Swish-12             [-1, 64, 7, 7]               0
BasicResidualBlock-13             [-1, 64, 7, 7]               0
           Conv2d-14            [-1, 

In [17]:
class DistilledDeepEnergyModel(DeepEnergyModel):
    def __init__(self, img_shape, batch_size, soft_loss_weight, hard_loss_weight, 
                temperature=2.0, f=CNNModel, **f_args):
        super().__init__(img_shape=img_shape, batch_size=batch_size, f=f, **f_args)
        self.teacher_energy = nn.Linear(10, 1).to(device)
        self.soft_loss_weight = soft_loss_weight
        self.hard_loss_weight = hard_loss_weight
        self.temperature = temperature
        self.cnn = f(**f_args).to(device)
        self.sampler = Sampler(self.cnn, img_shape=img_shape, sample_size=batch_size)

    # def set_teacher_model(self, teacher_model):
    #     self.teacher_model = teacher_model
    #     self.teacher_model.eval()

    def get_mse_loss(self, student_output, teacher_output):
        """
        Calculate the MSE loss as the KL divergence between the soft targets from the teacher
        and the predictions of the student. The temperature T is used to soften the probability distributions.
        """
        loss = F.mse_loss(student_output, teacher_output)
        return loss
    
    def training_step(self, batch, batch_idx):
        loss = super().training_step(batch, batch_idx)

        real_imgs, _ = batch
        fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)
        # Predict energy score for all images
        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)

        student_energy = self.cnn(inp_imgs)
        with torch.no_grad():
            teacher_energy = self.teacher_energy(teacher_model(inp_imgs)).squeeze()

        # Calculate mse loss and add it to the original loss
        mse_loss = self.get_mse_loss(student_energy, teacher_energy)
        total_loss = self.hard_loss_weight*loss + self.soft_loss_weight*mse_loss

        # Logging
        self.log('distillation_loss', mse_loss)
        self.log('total_loss', total_loss)

        return total_loss

    def validation_step(self, batch, batch_idx):
        loss = super().training_step(batch, batch_idx)

        real_imgs, _ = batch
        fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)
        # Predict energy score for all images
        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)

        student_energy = self.cnn(inp_imgs)
        with torch.no_grad():
            teacher_energy = self.teacher_energy(teacher_model(inp_imgs))

        # Calculate mse loss and add it to the original loss
        mse_loss = self.get_mse_loss(student_energy, teacher_energy)
        total_loss = self.hard_loss_weight*loss + self.soft_loss_weight*mse_loss

        self.log('val_distillation_loss', mse_loss)
        self.log('val_total_loss', total_loss)
        return mse_loss

In [8]:
def save_scores(trainer, default_root_dir):
    is_callback = [cb for cb in trainer.callbacks if isinstance(cb, InceptionScoreCallback)][0]
    epoch_is_scores = is_callback.scores
    is_path = os.path.join(default_root_dir, "epoch_is_scores.json")

    with open(is_path, 'w') as f:
        json.dump(epoch_is_scores, f)

    fid_callback = [cb for cb in trainer.callbacks if isinstance(cb, FIDCallback)][0]
    epoch_fid_scores = fid_callback.scores
    fid_path = os.path.join(default_root_dir, "epoch_fid_scores.json")

    with open(fid_path, 'w') as f:
        json.dump(epoch_fid_scores, f)

In [9]:
DRIVE_PATH = "/content/drive/My Drive/EBM_saved_models/"

In [10]:
def train_model(**kwargs):

    default_root_dir = os.path.join(CHECKPOINT_PATH, f"MNIST/student")
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=default_root_dir,
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=60,
                         gradient_clip_val=0.1,
                         callbacks=[ModelCheckpoint(dirpath=DRIVE_PATH, filename='MNIST_resnet_student', save_top_k=-1, every_n_epochs=1),
                                    GenerateImagesCallback(every_n_epochs=3),
                                    SamplerCallback(every_n_epochs=3),
                                    OutlierCallback(),
                                    LearningRateMonitor("epoch"),
                                    InceptionScoreCallback(mnist_classifier),
                                    FIDCallback(mnist_classifier)
                                    ])

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "MNIST_resnet_student.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        student_model = DistilledDeepEnergyModel.load_from_checkpoint(pretrained_filename)
    else:
        print("No pretrained model found. Start training from scratch...")
        pl.seed_everything(42)
        student_model = DistilledDeepEnergyModel(**kwargs)
        student_model.to(device)

    trainer.fit(student_model, train_loader, test_loader)

    student_model = DeepEnergyModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    save_scores(trainer, default_root_dir)
        
    # No testing as we are more interested in other properties
    return student_model

In [11]:
if IN_COLAB:
    %reload_ext tensorboard
    %tensorboard --logdir saved_models/MNIST/lightning_logs

In [18]:
torch.manual_seed(42)
# MNIST dataset images are 28x28 pixels in size and are black and white, so only have one channel
student_model = train_model(
                    img_shape=(1,28,28),
                    batch_size=train_loader.batch_size,
                    soft_loss_weight=0.5,
                    hard_loss_weight=1.0)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Seed set to 42

  | Name           | Type     | Params | In sizes       | Out sizes
-------------------------------------------------------------------------
0 | cnn            | CNNModel | 77.0 K | [1, 1, 28, 28] | [1]      
1 | teacher_energy | Linear   | 11     | ?              | ?        
-------------------------------------------------------------------------
77.0 K    Trainable params
0         Non-trainable params
77.0 K    Total params
0.308     Total estimated model params size (MB)


No pretrained model found. Start training from scratch...



Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/pedrosousa/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/Users/pedrosousa/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


IsADirectoryError: [Errno 21] Is a directory: '/Users/pedrosousa/Documents/Cambridge/Principles of ML Systems/energy-based-models-compression/notebooks'