In [1]:
import os

# Check if the notebook is running on Colab
if 'COLAB_GPU' in os.environ:
    # This block will run only in Google Colab
    IN_COLAB = True
    print("Running on Google Colab. Cloning the repository.")
    !git clone https://github.com/pedro15sousa/energy-based-models-compression.git
    %cd energy-based-models-compression/notebooks
else: 
    # This block will run if not in Google Colab
    IN_COLAB = False
    print("Not running on Google Colab. Assuming local environment.")

Not running on Google Colab. Assuming local environment.


In [2]:
import sys
# Add the parent directory (main_folder) to the Python import path.
# The membership check keeps the cell idempotent: re-running it does not
# append another duplicate '..' entry to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import MNIST
import torch.utils.data as data

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning
    import pytorch_lightning as pl
# Callbacks
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# Pytorch Summary
try:
    from torchsummary import summary
except ModuleNotFoundError:
    !pip install --quiet torchsummary
    from torchsummary import summary

import numpy as np
import pandas as pd
import json
import copy

## Imports for plotting
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from matplotlib import cm
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

from metrics.classifier import VGG
from EBM import DeepEnergyModel
from energy_funcs.resnet import ResNet18
from energy_funcs.cnn import ConvNet
from sampler import Sampler
from energy_funcs.lenet import LeNet
from callbacks import InceptionScoreCallback, \
    FIDCallback, SamplerCallback, OutlierCallback, \
    GenerateImagesCallback

import shutil
if IN_COLAB:
    from google.colab import files

# Folder where the datasets are stored / downloaded to (e.g. CIFAR10)
DATASET_PATH = "../data"
# Folder holding the pretrained model checkpoints
CHECKPOINT_PATH = "../saved_models"

# Fix the global seed
pl.seed_everything(42)

# Disable cuDNN autotuning and request deterministic kernels so that GPU runs
# (if a GPU is used) are reproducible
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the first CUDA device when one is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

  set_matplotlib_formats('svg', 'pdf') # For export
Seed set to 42


Device:  cpu


In [14]:
# Per-image preprocessing: convert to a tensor, then normalize with mean 0.5
# and std 0.5 so that pixel values lie between -1 and 1.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split (downloaded into DATASET_PATH if necessary)
train_set = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)

# MNIST test split
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# Data loaders over the two splits: the training loader shuffles and drops the
# last incomplete batch, the test loader keeps the dataset order and every sample.
train_loader = data.DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)
test_loader = data.DataLoader(
    test_set,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

In [15]:
# Checkpoint of the pretrained MNIST classifier inside CHECKPOINT_PATH.
CLASSIFIER_CKPT = f"{CHECKPOINT_PATH}/mnist-classifier-1 (1).pth"

if os.path.exists(CLASSIFIER_CKPT):
    # Load the best model
    mnist_classifier = VGG()

    # `device` is a torch.device, so comparing it with the string 'cuda' is
    # never true. Passing it as map_location instead loads the weights directly
    # onto the selected device, whether that is a GPU or the CPU.
    mnist_classifier.load_state_dict(torch.load(CLASSIFIER_CKPT, map_location=device))

    mnist_classifier.to(device)
    print("Model already exists and loaded.")
    # Layer-by-layer summary for a single-channel 28x28 input
    summary(mnist_classifier, input_size=(1, 28, 28))
else:
    print("Classifier not found in saved_models. Please run the classifier notebook first.")

Model already exists and loaded.
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 28, 28]           1,280
              ReLU-2          [-1, 128, 28, 28]               0
            Conv2d-3          [-1, 128, 28, 28]         147,584
              ReLU-4          [-1, 128, 28, 28]               0
            Conv2d-5          [-1, 128, 28, 28]         147,584
              ReLU-6          [-1, 128, 28, 28]               0
         MaxPool2d-7          [-1, 128, 14, 14]               0
            Conv2d-8          [-1, 256, 14, 14]         295,168
              ReLU-9          [-1, 256, 14, 14]               0
           Conv2d-10          [-1, 256, 14, 14]         590,080
             ReLU-11          [-1, 256, 14, 14]               0
           Conv2d-12          [-1, 256, 14, 14]         590,080
             ReLU-13          [-1, 256, 14, 14]               0
      

In [16]:
# Module-level quantization parameters. CustomSwish.forward reads them when it
# quantizes its output to torch.quint8 with torch.quantize_per_tensor.
# NOTE(review): with an unsigned 8-bit dtype and zero_point = 0, negative values
# presumably cannot be represented, while x * sigmoid(x) is negative for x < 0 —
# confirm this is intended.
scale = 2 / 255  # Example scale
zero_point = 0  # Example zero-point, for uint8, zero_point is usually 0 or 128

class Swish(nn.Module):
    """Swish activation: multiplies the input element-wise by its sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` for the input tensor ``x``."""
        gate = torch.sigmoid(x)
        return x * gate

class CustomSwish(nn.Module):
    """Swish activation whose result is returned as a quantized tensor."""

    def forward(self, x):
        """Compute ``x * sigmoid(x)`` and quantize it per-tensor to ``torch.quint8``.

        The quantization parameters are the module-level ``scale`` and
        ``zero_point`` variables, looked up at call time.
        """
        activated = x * torch.sigmoid(x)
        return torch.quantize_per_tensor(activated, scale, zero_point, torch.quint8)
    
class DeQuantSwish(nn.Module):
    """Dequantize the input tensor, then apply ``CustomSwish`` to it.

    ``CustomSwish`` quantizes its result, so the value returned by ``forward``
    is a quantized tensor again.
    """

    def __init__(self):
        super().__init__()
        # Referenced through the torch.ao.quantization namespace, matching the
        # stubs created in ModifiedCNNModel.
        # NOTE(review): forward() calls torch.dequantize directly and never uses
        # this stub; it is kept so the module's registered children stay the same.
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.swish = CustomSwish()

    def forward(self, x):
        """Return ``CustomSwish`` applied to the dequantized ``x``."""
        x = torch.dequantize(x)
        return self.swish(x)

class ModifiedCNNModel(nn.Module):
    """Rebuilds a CNN from deep copies of selected layers of ``original_model``.

    The layers at indices 0, 2, 4, 6, 8, 9 and 11 of
    ``original_model.cnn_layers`` are copied. Every copied conv/linear layer is
    wrapped as ``quant -> layer -> dequant -> swish``, except the last linear
    layer, whose output is only squeezed.

    Assumes ``original_model`` is the quantized model and that the indexed
    entries are the convolution, flatten and linear layers their attribute
    names suggest — TODO confirm against the model definition.
    """

    def __init__(self, original_model):
        super().__init__()
        source_layers = original_model.cnn_layers

        # Attributes are assigned in a fixed order so that submodule
        # registration order is: quant, conv1..conv4, flatten, linear1,
        # linear2, dequant, swish.
        self.quant = torch.ao.quantization.QuantStub()
        copied_layers = (
            ("conv1", 0),
            ("conv2", 2),
            ("conv3", 4),
            ("conv4", 6),
            ("flatten", 8),
            ("linear1", 9),
            ("linear2", 11),
        )
        for attr_name, layer_idx in copied_layers:
            setattr(self, attr_name, copy.deepcopy(source_layers[layer_idx]))
        self.dequant = torch.ao.quantization.DeQuantStub()

        self.swish = Swish()

    def quant_init_input(self, x):
        """Quantize ``x`` per-tensor to ``torch.quint8`` (scale 0.1, zero-point 0).

        The scale and zero-point are example values local to this method; the
        method is not called from ``forward``.
        """
        input_scale, input_zero_point = 0.1, 0
        return torch.quantize_per_tensor(x, input_scale, input_zero_point, torch.quint8)

    def forward(self, x):
        """Run ``x`` through the copied layers and squeeze the last dimension.

        NOTE(review): the same ``quant`` / ``dequant`` stub instances are reused
        at every stage, and the output of ``linear2`` is not passed through
        ``dequant`` — confirm both are intended.
        """
        # Four convolution stages, each: quant -> conv -> dequant -> swish
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.swish(self.dequant(conv(self.quant(x))))

        # Flatten, then the first linear layer with the same wrapping
        x = self.flatten(self.quant(x))
        x = self.swish(self.dequant(self.linear1(x)))

        # Final linear layer on the re-quantized activations
        x = self.linear2(self.quant(x))
        return x.squeeze(dim=-1)

In [None]:
class DeepEnergyModelQuant(pl.LightningModule):
    """Lightning module that trains a ``ModifiedCNNModel`` as an energy model.

    Args:
        pre_trained_cnn: model whose layers are copied into ``ModifiedCNNModel``.
        img_shape: image shape handed to the sampler; also used (with a leading
            batch dimension of 1) for ``example_input_array``.
        batch_size: passed to the sampler as ``sample_size``.
        alpha: weight of the regularization loss.
        lr: learning rate for Adam.
        beta1: first beta coefficient for Adam.
    """

    def __init__(self, pre_trained_cnn, img_shape, batch_size, alpha=0.1, lr=1e-4, beta1=0.0):
        super().__init__()
        self.save_hyperparameters()

        self.cnn = ModifiedCNNModel(pre_trained_cnn)
        self.sampler = Sampler(self.cnn, img_shape=img_shape, sample_size=batch_size)
        self.example_input_array = torch.zeros(1, *img_shape)

    def forward(self, x):
        """Return the network output for ``x``."""
        return self.cnn(x)

    def configure_optimizers(self):
        """Create Adam plus a StepLR scheduler (step size 1, gamma 0.97)."""
        # Momentum can be problematic for energy models because the loss surface
        # changes with the parameters, hence beta1 defaults to 0.
        hp = self.hparams
        optimizer = optim.Adam(self.parameters(), lr=hp.lr, betas=(hp.beta1, 0.999))
        # Exponential decay over epochs
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute and log the regularized contrastive-divergence loss; return it."""
        real_imgs, _ = batch
        # Perturb the real images in place with a little Gaussian noise so the
        # model does not focus on purely "clean" inputs, then clamp to [-1, 1].
        real_imgs.add_(torch.randn_like(real_imgs) * 0.005).clamp_(min=-1.0, max=1.0)

        # Draw generated samples from the sampler
        fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)

        # Score real and generated images in a single pass, then split the result
        scores = self.cnn(torch.cat([real_imgs, fake_imgs], dim=0))
        real_out, fake_out = scores.chunk(2, dim=0)

        # Loss terms
        reg_loss = self.hparams.alpha * (real_out ** 2 + fake_out ** 2).mean()
        cdiv_loss = fake_out.mean() - real_out.mean()
        loss = reg_loss + cdiv_loss

        # Logging
        logged_values = (
            ('loss', loss),
            ('loss_regularization', reg_loss),
            ('loss_contrastive_divergence', cdiv_loss),
            ('metrics_avg_real', real_out.mean()),
            ('metrics_avg_fake', fake_out.mean()),
        )
        for metric_name, metric_value in logged_values:
            self.log(metric_name, metric_value)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the contrastive divergence between uniform-noise and real images."""
        # Validation compares unseen real examples against purely random images;
        # what is worth validating for an energy-based model depends on the use case.
        real_imgs, _ = batch
        fake_imgs = torch.rand_like(real_imgs) * 2 - 1

        scores = self.cnn(torch.cat([real_imgs, fake_imgs], dim=0))
        real_out, fake_out = scores.chunk(2, dim=0)

        cdiv = fake_out.mean() - real_out.mean()
        logged_values = (
            ('val_contrastive_divergence', cdiv),
            ('val_fake_out', fake_out.mean()),
            ('val_real_out', real_out.mean()),
        )
        for metric_name, metric_value in logged_values:
            self.log(metric_name, metric_value)