Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import Flowers102, OxfordIIITPet
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

import pytorch_lightning as pl

import seaborn as sns

from collections import Counter

from PIL import Image

import math
import matplotlib.pyplot as plt



For reproducible results, we will be using a seed.

In [None]:
# Fix the seed of PyTorch's global random number generator.
torch.manual_seed(42)

### Data acquisition and analysis
We will be using the Flowers102 and OxfordIIITPet datasets, which can be found in the torchvision library. The datasets will be downloaded from code if needed.

In [None]:
# Untransformed copies of both datasets, used only for the exploratory analysis.
flowers_dataset = Flowers102("./data", split="train", download=True)
pets_dataset = OxfordIIITPet("./data", split="trainval", download=True)
print(flowers_dataset, pets_dataset, sep="\n")

Let's see the class distribution and the resolution of the pictures in both datasets.


In [None]:
def plot_class_size_distribution(dataset, dataset_name):
    """Plot a histogram of the number of samples per class.

    Args:
        dataset: Iterable yielding ``(sample, label)`` pairs.
        dataset_name: Name inserted into the figure title.
    """
    # Tally the labels; this walks the whole dataset once.
    samples_per_class = Counter(label for _, label in dataset)
    sizes = list(samples_per_class.values())

    plt.figure(figsize=(10, 6))
    sns.histplot(sizes, bins=30, kde=False)
    plt.title(f"Class Size Distribution in {dataset_name} Dataset")
    plt.xlabel("Number of Samples per Class")
    plt.ylabel("Number of Classes")
    plt.show()

In [None]:
# Histogram of per-class sample counts for each dataset.
print("Flowers Dataset Class Distribution")
plot_class_size_distribution(flowers_dataset, "Flowers102")

print("\nPets Dataset Class Distribution")
# Title label corrected: this is the Oxford-IIIT Pet data, not a cars dataset.
plot_class_size_distribution(pets_dataset, "Oxford Pets")

Based on the previous two diagrams, the Flowers102 dataset has no discrepancies in terms of class distribution, and the pets dataset has only small discrepancies, which won't be a problem.

In [None]:
def categorize_resolution(width, height):
    """Return the resolution bucket of an image given its pixel dimensions.

    The bucket is decided by the shorter side, so an image whose width and
    height fall into different ranges (e.g. 500x333) is classified by its
    limiting dimension instead of falling through to the "Very High" bucket.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        One of ``'Low (<256)'``, ``'Medium (256-512)'``, ``'High (512-1024)'``
        or ``'Very High (>=1024)'``.
    """
    shorter_side = min(width, height)
    if shorter_side < 256:
        return 'Low (<256)'
    if shorter_side < 512:
        return 'Medium (256-512)'
    if shorter_side < 1024:
        return 'High (512-1024)'
    return 'Very High (>=1024)'


def plot_image_size_distribution(dataset, dataset_name):
    """Plot a bar chart of how many images fall into each resolution bucket.

    Args:
        dataset: Iterable yielding ``(image, label)`` pairs. ``image`` is
            either an object exposing ``.size`` as ``(width, height)`` or a
            string path that is opened with PIL.
        dataset_name: Name inserted into the figure title.
    """
    # Keys must match the labels returned by categorize_resolution.
    resolution_counts = {
        'Low (<256)': 0,
        'Medium (256-512)': 0,
        'High (512-1024)': 0,
        'Very High (>=1024)': 0
    }
    for img, _ in dataset:
        if isinstance(img, str):
            # Open only long enough to read the size, then release the file handle.
            with Image.open(img) as opened:
                width, height = opened.size
        else:
            width, height = img.size
        resolution_counts[categorize_resolution(width, height)] += 1

    categories = list(resolution_counts)
    counts = list(resolution_counts.values())

    plt.figure(figsize=(10, 6))
    plt.bar(categories, counts, color='skyblue')
    plt.title(f'Image Size Distribution of {dataset_name} by Resolution Category')
    plt.xlabel('Resolution Category')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=15)
    plt.grid(axis='y')
    plt.show()

In [None]:
# Bar chart of image-resolution buckets for each dataset.
print("\nFlowers102 Dataset Image Size Analysis")
plot_image_size_distribution(flowers_dataset, "Flowers102")

print("\nOxford-IIIT Pet Dataset Image Size Analysis")
plot_image_size_distribution(pets_dataset, "Oxford Pets")

All images will be resized to 256x256 resolution, so the datasets can be used as they are.


### Data preparation



We will randomly split the datasets into 3 different parts, based on the following ratios:
- Train: 70%,
- Validation: 15%,
- Test: 15%,

because our datasets don't have more than 10k samples. For the same reason, we will load the datasets into memory in one step rather than using a generator or streaming.

In [None]:
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Randomly split ``dataset`` into train, validation and test subsets.

    The train and validation sizes are ``int(ratio * len(dataset))``. The test
    subset always receives every remaining sample, so no sample is dropped by
    rounding. As a consequence ``test_ratio`` is advisory only: it documents
    the intended share but does not influence the split.

    Args:
        dataset: Sized, indexable dataset to split.
        train_ratio: Fraction of samples for the training subset.
        val_ratio: Fraction of samples for the validation subset.
        test_ratio: Intended fraction for the test subset (not used in the
            computation; the test subset gets the remainder).

    Returns:
        The list ``[train, val, test]`` of subsets produced by
        ``torch.utils.data.random_split``.

    Raises:
        ValueError: If ``train_ratio`` or ``val_ratio`` is negative, or if
            together they request more samples than the dataset holds.
    """
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError("train_ratio and val_ratio must be non-negative")

    dataset_size = len(dataset)
    train_size = int(train_ratio * dataset_size)
    val_size = int(val_ratio * dataset_size)
    test_size = dataset_size - train_size - val_size
    if test_size < 0:
        # Without this check a negative length would reach random_split.
        raise ValueError(
            "train_ratio + val_ratio must not exceed 1 "
            f"(got {train_ratio} + {val_ratio})"
        )

    return random_split(dataset, [train_size, val_size, test_size])

In [None]:
# Shared preprocessing: resize to 256x256, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)


In [None]:
# Reload both datasets, this time with the preprocessing pipeline attached.
flowers_dataset = Flowers102(
    "./data", split="train", download=True, transform=train_transform
)
pets_dataset = OxfordIIITPet(
    "./data", split="trainval", download=True, transform=train_transform
)

In [None]:
# Split each dataset into train / validation / test subsets using split_dataset's default ratios.
flowers_train, flowers_val, flowers_test = split_dataset(flowers_dataset)
pets_train, pets_val, pets_test = split_dataset(pets_dataset)


In [None]:

# Only the training loaders shuffle; validation and test loaders keep a fixed order.
flowers102_train_loader = DataLoader(flowers_train, shuffle=True, batch_size=64)
flowers102_val_loader = DataLoader(flowers_val, batch_size=64)
flowers102_test_loader = DataLoader(flowers_test, batch_size=64)

oxford_pets_train_loader = DataLoader(pets_train, shuffle=True, batch_size=64)
oxford_pets_val_loader = DataLoader(pets_val, batch_size=64)
oxford_pets_test_loader = DataLoader(pets_test, batch_size=64)

In [None]:
# Report the size of every split as a sanity check.
print(f"Flowers102 dataset: {len(flowers_train)} training, {len(flowers_val)} validation, {len(flowers_test)} test samples")
print(f"Oxford-IIIT Pets dataset: {len(pets_train)} training, {len(pets_val)} validation, {len(pets_test)} test samples")


### Baseline model

We are implementing a simple VAE (Variational Autoencoder) to use as a baseline model.


In [None]:

class VAE(pl.LightningModule):
    """Small convolutional variational autoencoder used as the baseline model.

    The encoder is two stride-2 convolutions followed by a flatten; two linear
    heads produce the mean and log-variance of the latent distribution. The
    decoder mirrors the encoder with two stride-2 transposed convolutions and
    ends in ``Tanh``, so reconstructions lie in (-1, 1).

    Args:
        latent_dim: Size of the latent vector.
        learning_rate: Learning rate passed to Adam.
        input_shape: ``(channels, height, width)`` of the input images. Height
            and width must be divisible by 4 so the decoder output has the
            same size as the input.

    Raises:
        ValueError: If the height or width of ``input_shape`` is not
            divisible by 4.
    """

    def __init__(self, latent_dim=2, learning_rate=1e-3, input_shape=(3, 256, 256)):
        super().__init__()
        self.learning_rate = learning_rate
        self.latent_dim = latent_dim
        self.input_shape = tuple(input_shape)

        in_channels, height, width = self.input_shape
        if height % 4 or width % 4:
            # Two stride-2 stages halve each side twice; other sizes would
            # make the reconstruction smaller than the input.
            raise ValueError(
                f"input height and width must be divisible by 4, got {height}x{width}"
            )

        # Encoder
        self.enc = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # (channels, height, width) of the encoder output before flattening.
        # Taking it from a probe pass avoids assuming square feature maps.
        feature_shape = self._get_feature_map_shape(self.input_shape)
        flat_features = math.prod(feature_shape)

        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, feature_shape),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def _get_feature_map_shape(self, shape):
        """Return ``(channels, height, width)`` of the conv output for ``shape``."""
        conv_stack = self.enc[:-1]  # every encoder stage except the final Flatten
        with torch.no_grad():
            # torch.rand (not zeros) keeps the global RNG consumption the same
            # as before, so seeded weight initialisation is unchanged.
            probe = torch.rand(1, *shape)
            return tuple(conv_stack(probe).shape[1:])

    def _get_output_shape(self, shape):
        '''Returns the size of the output tensor from the conv layers.'''
        return math.prod(self._get_feature_map_shape(shape))

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        enc_output = self.enc(x)
        mu = self.fc_mu(enc_output)
        logvar = self.fc_logvar(enc_output)
        z = self.reparameterize(mu, logvar)
        return self.dec(z), mu, logvar

    def _compute_loss(self, batch):
        """Return the summed reconstruction MSE plus the KL term for ``batch``.

        Both terms are summed over the whole batch (not averaged), so the
        magnitude of the loss scales with the batch size.
        """
        x, _ = batch
        x_reconstructed, mu, logvar = self(x)
        kl = 0.5 * torch.sum(mu**2 + torch.exp(logvar) - 1 - logvar)
        reconstruction_loss = F.mse_loss(x_reconstructed, x, reduction="sum")
        return reconstruction_loss + kl

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss`` for one batch."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss`` for one batch."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log ``test_loss`` for one batch."""
        loss = self._compute_loss(batch)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


In [None]:
# Show the installed PyTorch version and whether CUDA is usable.
print(torch.__version__, torch.cuda.is_available(), sep="\n")

Let's train two separate models for the two datasets so we can later compare them using predefined metrics.


In [None]:
# One independent baseline VAE per dataset.
model_flowers = VAE(latent_dim=64)
model_pets = VAE(latent_dim=64)

# accelerator="auto" picks a GPU when one is available and otherwise falls
# back to CPU, instead of failing on machines without CUDA.
flower_trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)
pets_trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)

# Pass the validation loaders too, so validation_step runs and val_loss is logged.
flower_trainer.fit(model_flowers, flowers102_train_loader, flowers102_val_loader)
pets_trainer.fit(model_pets, oxford_pets_train_loader, oxford_pets_val_loader)

Let's plot some of the pictures and their reconstructed counterparts from both datasets.

In [None]:

def unnormalize(img):
    """Map values from the [-1, 1] range to [0, 1] by computing ``(img + 1) / 2``."""
    shifted = img + 1
    return shifted / 2

def plot_reconstructions(model, data_loader, num_images=10):
    """Plot images from the first batch above their reconstructions.

    Args:
        model: Model whose call returns ``(reconstruction, mu, logvar)`` and
            which exposes a ``device`` attribute. It is switched to eval mode.
        data_loader: Loader yielding ``(images, labels)`` batches; only the
            first batch is used.
        num_images: Maximum number of image pairs to show. Fewer are shown
            when the first batch is smaller.
    """
    model.eval()
    images, _ = next(iter(data_loader))
    images = images[:num_images]
    # The batch may hold fewer than num_images items; plot only what exists.
    shown = images.shape[0]

    with torch.no_grad():
        recon_images, _, _ = model(images.to(model.device))

    images = images.cpu()
    recon_images = recon_images.cpu()

    # squeeze=False keeps `axes` two-dimensional even with a single column.
    _, axes = plt.subplots(2, shown, figsize=(15, 4), squeeze=False)
    for col, (original, reconstructed) in enumerate(zip(images, recon_images)):
        # Move channels last for imshow.
        axes[0, col].imshow(unnormalize(original).permute(1, 2, 0).numpy())
        axes[0, col].set_title("Original")
        axes[0, col].axis("off")

        axes[1, col].imshow(unnormalize(reconstructed).permute(1, 2, 0).numpy())
        axes[1, col].set_title("Reconstructed")
        axes[1, col].axis("off")

    plt.tight_layout()
    plt.show()



In [None]:
# Originals vs. reconstructions for the first validation batch of the flowers model.
plot_reconstructions(model_flowers, flowers102_val_loader)

In [None]:
# Originals vs. reconstructions for the first validation batch of the pets model.
plot_reconstructions(model_pets, oxford_pets_val_loader)


### Defining Evaluation Criteria

For evaluating the quality of generated images from the diffusion models, we will use the following metrics:

1. **Fréchet Inception Distance (FID)**: Measures the similarity between generated images and real images by comparing the mean and covariance of features extracted from the Inception network. Lower values are better, indicating closer similarity to real data.

2. **Inception Score (IS)**: Measures the quality of generated images based on their diversity and how "confident" the Inception network is in classifying them into distinct categories. Higher scores indicate better diversity and quality.

We will implement and calculate these metrics after training the diffusion models.


In [None]:
def calculate_fid(real_images_loader, model):
    """Compute the FID between real images and the model's reconstructions.

    Both the loader's images and the model output are expected to be in the
    [-1, 1] range (``Normalize(0.5, 0.5)`` inputs, ``Tanh`` outputs). They are
    mapped to uint8 [0, 255] before being passed to the metric, which is the
    input format ``FrechetInceptionDistance()`` expects with its defaults.

    Args:
        real_images_loader: Loader yielding ``(images, labels)`` batches.
        model: Model whose call returns ``(reconstruction, mu, logvar)``. It
            is moved to the GPU when one is available, otherwise to the CPU.

    Returns:
        The value returned by ``FrechetInceptionDistance.compute()``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    fid_metric = FrechetInceptionDistance().to(device)

    def to_uint8(images):
        # [-1, 1] -> [0, 255]. Multiplying the raw values by 255 would wrap
        # or clip every negative pixel.
        return ((images + 1) / 2 * 255).round().clamp(0, 255).to(torch.uint8)

    was_training = model.training
    model.eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for real_images, _ in real_images_loader:
                real_images = real_images.to(device)
                generated_images, _, _ = model(real_images)
                fid_metric.update(to_uint8(real_images), real=True)
                fid_metric.update(to_uint8(generated_images), real=False)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    return fid_metric.compute()

In [None]:
def calculate_inception_score(real_images_loader, model):
    """Compute the Inception Score of the model's reconstructions.

    The model output is expected to be in the [-1, 1] range (``Tanh``); it is
    mapped to uint8 [0, 255] before being passed to the metric, which is the
    input format ``InceptionScore()`` expects with its defaults.

    Args:
        real_images_loader: Loader yielding ``(images, labels)`` batches that
            are fed through the model.
        model: Model whose call returns ``(reconstruction, mu, logvar)``. It
            is moved to the GPU when one is available, otherwise to the CPU.

    Returns:
        The first element of ``InceptionScore.compute()`` as a Python float.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    inception_score = InceptionScore().to(device)

    was_training = model.training
    model.eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for real_images, _ in real_images_loader:
                generated_images, _, _ = model(real_images.to(device))
                # [-1, 1] -> [0, 255]. Multiplying the raw output by 255 would
                # clip every negative pixel to 0.
                generated_images = (
                    ((generated_images + 1) / 2 * 255)
                    .round()
                    .clamp(0, 255)
                    .to(torch.uint8)
                )
                inception_score.update(generated_images)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    score = inception_score.compute()
    return score[0].item()

In [None]:

# Evaluate the flowers model on its validation split.
fid_value = calculate_fid(flowers102_val_loader, model_flowers)
print("FID:", fid_value)

inception_score_value = calculate_inception_score(flowers102_val_loader, model_flowers)
print("Inception Score:", inception_score_value)

The calculated metrics suggest that the baseline model isn't performing very well. However, this kind of baseline is exactly what we need.

In [None]:

# Evaluate the pets model on its validation split (overwrites the values from the flowers run).
fid_value = calculate_fid(oxford_pets_val_loader, model_pets)
print("FID:", fid_value)

inception_score_value = calculate_inception_score(oxford_pets_val_loader, model_pets)
print("Inception Score:", inception_score_value)


The results are a bit better for this dataset, most likely due to the larger sample size available in this dataset compared to the other one.