<h1>Beta-Variational Autoencoder with Noise-Contrastive Priors</h1>

<h2>Import required libraries

In [None]:
# Standard Libraries
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# PyTorch and related libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

<h2>Initialize device

In [None]:
# Pick the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Release cached allocator blocks left over from earlier notebook runs.
torch.cuda.empty_cache()

<h2> Initialize constants

In [None]:
IMAGE_SIZE = 28                  # side length, in pixels, of one square image
FLATTEN_SIZE = IMAGE_SIZE ** 2   # number of values in a flattened image
LATENT_SIZE = 100                # dimensionality of the latent vectors
NUM_SAMPLES = 8                  # how many sample images to preview

<h2> Load the dataset

In [None]:
# Custom transform: flatten the image, then quantise its intensities
class FlattenAndNormalize:
    """Callable transform that flattens a tensor and rounds it to integer levels.

    Each value is divided by 27/255 and rounded to the nearest integer.
    Assumes the input is already scaled to [0, 1] (e.g. by ``ToTensor``), in
    which case the output levels range from 0 to 9 -- TODO confirm upstream
    scaling if the pipeline changes.
    """

    def __call__(self, image):
        flat = torch.flatten(image)
        # One quantisation level spans 27/255 of the input range.
        return torch.round(flat / (27/255))

# Preprocessing pipeline: convert to tensor, then flatten and quantise.
transform = transforms.Compose([transforms.ToTensor(), FlattenAndNormalize()])

# Both splits share the same storage root, download flag and preprocessing.
_fashion_mnist_options = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.FashionMNIST(train=True, **_fashion_mnist_options)
test_dataset = datasets.FashionMNIST(train=False, **_fashion_mnist_options)

# Shuffled mini-batch loaders for the two splits.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

<h2>Function to display sample images from the dataset</h2>

In [None]:
def display_images(image_list, rows, columns):
    """Show flattened images in a ``rows`` x ``columns`` grid of subplots.

    Images are laid out in row-major order: the image at position
    ``i * columns + j`` in ``image_list`` is drawn in row ``i``, column ``j``.

    Args:
        image_list: Sequence of array-likes, each reshapeable to
            ``(IMAGE_SIZE, IMAGE_SIZE)``.
        rows: Number of grid rows.
        columns: Number of grid columns.

    If ``image_list`` holds fewer than ``rows * columns`` images, the
    remaining cells are left blank.
    """
    # squeeze=False keeps `grid` two-dimensional even for a single row/column,
    # so grid[i, j] indexing is always valid.
    fig, grid = plt.subplots(rows, columns, squeeze=False)
    for i in range(rows):
        for j in range(columns):
            axis = grid[i, j]
            axis.axis('off')
            index = i * columns + j  # row-major position of this cell
            if index < len(image_list):
                axis.imshow(np.reshape(image_list[index], (IMAGE_SIZE, IMAGE_SIZE)))

# Draw NUM_SAMPLES random training images and preview them in a 2 x 4 grid
random_indices = torch.randint(0, len(train_dataset), (NUM_SAMPLES,))
sample_images = []
for i in random_indices:
    flat_image, _ = train_dataset[i]  # the class label is not needed here
    sample_images.append(flat_image.squeeze().numpy())
display_images(sample_images, 2, 4)

<h2>Defining the Variational Autoencoder class

In [None]:
latent_size = 100


class VariationalAE(nn.Module):
    """Fully connected variational autoencoder for flattened 784-value images.

    The encoder maps an input vector to ``2 * latent_dim`` values that are
    split into the mean and log-variance of a diagonal Gaussian posterior.
    The decoder maps a latent sample back to a 784-value reconstruction.

    Args:
        latent_dim: Size of the latent space. When ``None`` (the default) the
            module-level ``latent_size`` is read once at construction time.
            The value is stored on the instance so that ``forward`` no longer
            depends on a global that may change after the model is built.
    """

    def __init__(self, latent_dim=None):
        super(VariationalAE, self).__init__()

        if latent_dim is None:
            latent_dim = latent_size
        self.latent_size = latent_dim

        # Define the encoder layers
        self.encoder = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * self.latent_size)  # Outputs both mean and log variance
        )

        # Define the decoder layers
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_size, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784)
        )

    def reparameterize(self, mean, log_variance):
        """Sample ``mean + eps * std`` with ``eps ~ N(0, I)``.

        Uses the reparameterization trick so the sample stays differentiable
        with respect to ``mean`` and ``log_variance``.
        """
        std_dev = torch.exp(0.5 * log_variance)
        epsilon = torch.randn_like(std_dev)
        return mean + (epsilon * std_dev)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Tensor of shape ``(batch, 784)``.

        Returns:
            Tuple ``(mean, log_variance, reconstruction)``.
        """
        encoded = self.encoder(x)
        # First half of the encoder output is the mean, second half the log-variance.
        mean, log_variance = torch.split(encoded, self.latent_size, dim=1)
        z = self.reparameterize(mean, log_variance)
        return mean, log_variance, self.decoder(z)

<h2>Training the Variational Autoencoder model</h2>

In [None]:
vae_model = VariationalAE().to(device)
model_parameters = list(vae_model.parameters())

# Hyperparameters
learning_rate = 1e-3
epochs = 10
beta = 0.75  # weight of the KL term in the beta-VAE objective

# Loss function
reconstruction_loss = nn.MSELoss()

# Optimizer
optimizer = torch.optim.SGD(model_parameters, lr = learning_rate, momentum = 0.7)

# Learning rate scheduler.
# The `verbose` flag of ReduceLROnPlateau is deprecated, so learning-rate
# reductions are reported explicitly inside the loop instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.25)

# To keep track of the best model
best_loss = float('inf')
model_save_path = 'vae_model.pth'

# Training loop
for epoch in range(epochs):
    train_loss = 0.0

    # tqdm.notebook progress bar over the mini-batches (labels are unused)
    for data, _ in tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}"):
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass
        mu, logvar, recon_data = vae_model(data)

        # Calculate losses
        # NOTE(review): MSE is averaged over every element while KLD is summed
        # over the whole batch, so the two terms live on very different scales.
        # Left unchanged because rescaling alters the training dynamics --
        # confirm the intended weighting.
        MSE = reconstruction_loss(recon_data, data)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = MSE + beta * KLD

        # Backpropagation
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    avg_train_loss = train_loss / len(train_loader.dataset)

    # Update scheduler with the current epoch's loss and report any LR change
    previous_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_train_loss)
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr < previous_lr:
        print(f"Reducing learning rate from {previous_lr:.4e} to {current_lr:.4e}.")

    # Print epoch results
    print(f"Epoch {epoch + 1}/{epochs}| Loss: {avg_train_loss:.6f}")

    # Save the best model
    if avg_train_loss < best_loss:
        best_loss = avg_train_loss
        torch.save(vae_model.state_dict(), model_save_path)

<h2> Loading the model for Evaluation

In [None]:
# Start from a freshly initialised model; its weights are replaced below when
# a saved checkpoint is available.
vae_model = VariationalAE().to(device)
try:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    saved_state = torch.load("vae_model.pth", map_location=device)
except FileNotFoundError:
    print("Saved model state not found. Initialized a new model instead.")
else:
    vae_model.load_state_dict(saved_state)
    print("Model loaded successfully.")
# The following cells only run inference, so use eval mode in either case.
vae_model.eval()


<h2>Binary Classifier

In [None]:
class BinaryClassifier(torch.nn.Module):
    """Five-layer MLP that maps a latent vector to a probability in (0, 1).

    Layer widths are LATENT_SIZE -> 40 -> 30 -> 20 -> 10 -> 1, with ReLU
    between the linear layers and a sigmoid on the final output.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(LATENT_SIZE, 40)
        self.layer2 = nn.Linear(40, 30)
        self.layer3 = nn.Linear(30, 20)
        self.layer4 = nn.Linear(20, 10)
        self.layer5 = nn.Linear(10, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid output of the network for input ``x``."""
        hidden = x
        # The four hidden layers are each followed by a ReLU.
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = F.relu(layer(hidden))
        return self.activation(self.layer5(hidden))

In [None]:
# Encode one training batch into posterior parameters (no gradients required)
with torch.no_grad():
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        z = vae_model.encoder(data)
        # First half of the encoder output: mean; second half: log-variance.
        infMean, infLogVariance = z.split(LATENT_SIZE, dim=1)
        break  # a single batch is enough for this example

numLatents = len(data)

# Latent samples for the classifier: encoder samples first, then
# standard-normal draws of the same shape
qzx = vae_model.reparameterize(infMean, infLogVariance)
pzx = torch.randn((numLatents, LATENT_SIZE), device=device)

# Targets: 1 for encoder samples, 0 for standard-normal samples
qzxLabel = torch.ones((numLatents, 1), device=device)
pzxLabel = torch.zeros((numLatents, 1), device=device)

# Stack both groups into one training set (encoder samples come first)
trainLabels = torch.cat((qzxLabel, pzxLabel), dim=0)
trainData = torch.cat((qzx, pzx), dim=0)

In [None]:
class customDataset(Dataset):
    """Dataset pairing each entry of ``data`` with the matching entry of ``labels``.

    ``transform`` is stored on the instance but is not applied by
    ``__getitem__``.
    """

    def __init__(self, data, labels, transform = None):
        self.data, self.labels = data, labels
        self.transform = transform

    def __len__(self):
        # The dataset length is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, index):
        # Return the (sample, label) pair stored at `index`.
        return self.data[index], self.labels[index]

In [None]:
# Wrap the latent samples and their labels, then serve them in shuffled batches of 100.
trainDataset = customDataset(trainData, trainLabels)
binaryTrainLoader = DataLoader(trainDataset, batch_size=100, shuffle=True)

In [None]:
bcModel = BinaryClassifier().to(device)
criterion = nn.BCELoss()
optimizer = optim.SGD(bcModel.parameters(), lr = (1e-2), momentum=0.8)

# Learning rate scheduler.
# The `verbose` flag of ReduceLROnPlateau is deprecated, so learning-rate
# reductions are reported explicitly inside the loop instead.
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

# Single source of truth for the epoch count (used by the loop and the plot axis)
numClassifierEpochs = 25

lowestEpochLoss = []
averageEpochLoss = []
epochList = np.arange(0, numClassifierEpochs, dtype = int)

# Not used in this cell; kept in case other cells reference them.
encoderOutputs = []
decoderOutputs = []
for epoch in range(numClassifierEpochs):
    losses = []
    for batchImage, batchLabels in binaryTrainLoader:
        batchImage = batchImage.to(device)
        batchLabels = batchLabels.to(device)
        predictedOutput = bcModel(batchImage)
        loss = criterion(predictedOutput, batchLabels)
        optimizer.zero_grad()
        # The graph is rebuilt on every forward pass, so it does not need to
        # be retained after backward().
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    avg_loss = np.average(losses)
    lowest_loss = np.amin(losses)
    print("Epoch:", epoch, "| Average loss:", np.round(avg_loss, 2), "| Lowest Loss:", np.round(lowest_loss, 2))

    # Update scheduler with the current epoch's average loss and report any LR change
    previous_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_loss)
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr < previous_lr:
        print(f"Reducing learning rate from {previous_lr:.4e} to {current_lr:.4e}.")

    lowestEpochLoss.append(lowest_loss)
    averageEpochLoss.append(avg_loss)

torch.save(bcModel.state_dict(), "bc.pth")

In [None]:
# Plot the per-epoch lowest and average classifier losses.
plt.plot(epochList, lowestEpochLoss, color = 'blue', label = 'Lowest loss')
plt.plot(epochList, averageEpochLoss, color = 'red', label = 'Average loss')
plt.xlabel('Epoch')
plt.ylabel('Loss per epoch')
# Without a legend the `label` arguments above are never displayed.
plt.legend()
plt.show()

In [None]:
# Classifier output for one latent vector drawn uniformly from [0, 1).
bcModel.eval()
uniformProbe = torch.rand((1, LATENT_SIZE), device=device)
bcModel(uniformProbe).item()

In [None]:
# Classifier output for one randomly chosen row among the first len(data)
# rows of trainData.
bcModel.eval()
randomIndices = random.sample(range(len(data)), 1)
sampledLatent = trainData[randomIndices]
bcModel(sampledLatent).item()

In [None]:
# Langevin dynamics in latent space.
# The energy is E(z) = -log r(z) - log p(z), where r(z) = D(z) / (1 - D(z))
# is the density ratio implied by the classifier output D(z) and p(z) is a
# standard normal density. Each step applies
#     z <- z - 0.5 * stepSize * grad E(z) + sqrt(stepSize) * noise.
stepSize = 1e-2
numLangevinSteps = 1000
normalMean = torch.tensor([0.0], device=device)
normalSTD = torch.tensor([1.0], device=device)
priorDistribution = torch.distributions.Normal(normalMean, normalSTD)
noiseScale = stepSize ** 0.5  # loop-invariant, so computed once

bcModel.eval()

# Initial latent vector with shape [1, LATENT_SIZE]
z0 = torch.randn((1, LATENT_SIZE), device=device)

for timeStamp in range(numLangevinSteps):
    # Detach so the autograd graph does not grow across iterations.
    z0 = z0.detach().requires_grad_(True)

    dZ = bcModel(z0)
    # log(d / (1 - d)) is the logit of d; the eps clamp avoids inf/nan when
    # the classifier output saturates at 0 or 1.
    logRatio = torch.logit(dZ, eps=1e-6)

    # Joint log-density of an independent Gaussian: sum over latent
    # dimensions. Averaging over dimensions would shrink the prior gradient
    # by a factor of LATENT_SIZE.
    logPrior = priorDistribution.log_prob(z0).sum(dim=1, keepdim=True)

    energyFunction = -logRatio - logPrior

    grad = torch.autograd.grad(energyFunction.sum(), z0)[0]
    noise = torch.randn((1, LATENT_SIZE), device=device)

    z0 = z0 - (0.5 * stepSize * grad) + (noiseScale * noise)

# Final sample, detached from the autograd graph.
z0 = z0.detach()

In [None]:
# Decode the sampled latent vector and display it as a 28 x 28 image.
vae_model.eval()
output = vae_model.decoder(z0)
decodedImage = output.detach().cpu().numpy().reshape((28, 28))
plt.imshow(decodedImage)