# Classifier based on Autoencoder pretraining

In [1]:
import torch
import json
import os
import numpy as np
from torch import nn
from torch.utils.data import Dataset
from PIL import Image, ImageFile
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.nn import DataParallel
from tqdm import tqdm
from tabulate import tabulate
from IPython.display import clear_output

ImageFile.LOAD_TRUNCATED_IMAGES = True

Define the dataset class and the image preprocessing transform

In [2]:
class VAEDataset(Dataset):
    """Unlabelled image dataset that yields one RGB image per file path.

    Args:
        image_paths: Indexable, sized collection of image file paths.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
            When ``None`` the PIL image itself is returned.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx``, convert it to RGB and transform it.

        Raises whatever ``Image.open`` raises for a missing or unreadable file.
        """
        # Image.open is lazy and may keep the underlying file open; the
        # context manager closes it deterministically. convert() returns a
        # new in-memory image, so the result stays valid after the close.
        with Image.open(self.image_paths[idx]) as source:
            im = source.convert('RGB')
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container-like transform) is still applied.
        if self.transform is not None:
            im = self.transform(im)
        return im

# Preprocessing applied to every image: resize to 32x32, then convert to a tensor.
_preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

Define autoencoder model

In [3]:
class Encoder(nn.Module):
    """Convolutional encoder that maps (batch, 3, H, W) to (batch, 32, H/4, W/4).

    Two identical stages are stacked, 3 -> 16 -> 32 channels. Each stage is
    ``Conv2d(kernel_size=3, stride=2, padding=1)`` -> ``BatchNorm2d`` ->
    ``ReLU`` -> ``Dropout2d`` and halves the spatial size (exactly, for even
    H and W).

    Args:
        dropout_rate: Probability passed to both ``Dropout2d`` layers.
    """

    def __init__(self, dropout_rate=0.2):
        super().__init__()
        channel_plan = (3, 16, 32)
        stages = []
        for n_in, n_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages += [
                nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout_rate),
            ]
        # One flat Sequential, so parameters are named encoder.0.*, encoder.1.*, ...
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the latent feature map for the image batch ``x``."""
        return self.encoder(x)


class Decoder(nn.Module):
    """Transposed-convolution decoder that maps (batch, 32, h, w) to (batch, 3, 4h, 4w).

    The first stage (32 -> 16 channels) is ``ConvTranspose2d`` ->
    ``BatchNorm2d`` -> ``ReLU`` -> ``Dropout2d``; the second stage
    (16 -> 3 channels) is ``ConvTranspose2d`` -> ``Sigmoid``. Both transposed
    convolutions use kernel 3, stride 2, padding 1 and output_padding 1, so
    each one doubles the spatial size.

    Args:
        dropout_rate: Probability passed to the single ``Dropout2d`` layer.
    """

    def __init__(self, dropout_rate=0.2):
        super().__init__()
        upsample_x2 = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        hidden_stage = [
            nn.ConvTranspose2d(in_channels=32, out_channels=16, **upsample_x2),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
        ]
        output_stage = [
            nn.ConvTranspose2d(in_channels=16, out_channels=3, **upsample_x2),
            nn.Sigmoid(),  # squashes reconstructed pixel values into (0, 1)
        ]
        # One flat Sequential, so parameters are named decoder.0.*, decoder.1.*, ...
        self.decoder = nn.Sequential(*hidden_stage, *output_stage)

    def forward(self, x):
        """Return the reconstructed image batch for the latent feature map ``x``."""
        return self.decoder(x)

class Autoencoder(nn.Module):
    """Plain autoencoder: an ``Encoder`` followed by a ``Decoder``.

    Both sub-modules are built with their default arguments and exposed as
    ``self.encoder`` and ``self.decoder``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        return self.decoder(self.encoder(x))

Train autoencoder

In [4]:
# load datasets

paths_folder = 'paths/'


def _read_path_list(filename):
    """Return the JSON-decoded content of ``paths_folder + filename``.

    The file is presumably a JSON array of image paths (it is handed straight
    to ``VAEDataset`` below) -- verify against the files on disk.
    """
    # Open inside a context manager so the handle is closed as soon as the
    # JSON is parsed, instead of whenever the garbage collector gets to it.
    with open(paths_folder + filename) as path_file:
        return json.load(path_file)


# The autoencoder trains on the unlabelled paths and is validated on the
# labelled ones.
train_paths = _read_path_list('train_paths_unlabelled.json')
val_paths = _read_path_list('train_paths_labelled.json')

train_dataset = VAEDataset(train_paths, transform=transform)
val_dataset = VAEDataset(val_paths, transform=transform)


In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# hyperparameters
batch_size = 256
n_epochs = 10
learning_rate = 1e-4
weight_decay = 1e-2

# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an integer; fall back to 0 (load in the main process).
num_workers = os.cpu_count() or 0

# Training batches are shuffled with a fixed-seed generator so the shuffle
# order is reproducible; validation batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
    num_workers=num_workers,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Move the model to the target device first, then wrap it in DataParallel.
model = Autoencoder().to(device)
model = DataParallel(model)

# Reconstruction objective: mean squared error between output and input.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)


In [None]:
train_losses = []
val_losses = []
epoch_saved = 0
best_val_loss = float('inf')

epoch_metrics = []
headers = ["Epoch", "Train Loss", "Val Loss"]

# torch.save / np.save do not create missing parent directories.
os.makedirs("models", exist_ok=True)
os.makedirs("logs", exist_ok=True)


def _run_epoch(loader, training):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    Args:
        loader: Iterable yielding image batches (tensors with a leading
            batch dimension).
        training: When True the model is put in train mode and the optimizer
            is stepped on every batch; when False the model is put in eval
            mode and gradients are disabled.

    Returns:
        The loss averaged over all samples seen, or 0.0 for an empty loader.
    """
    model.train(training)
    loss_sum = 0.0
    n_seen = 0
    with torch.set_grad_enabled(training):
        for imgs in loader:
            imgs = imgs.to(device)
            loss = criterion(model(imgs), imgs)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Weight each batch loss by its size so a smaller final batch
            # does not skew the epoch average (assumes criterion returns a
            # batch mean, as nn.MSELoss does by default).
            n_batch = imgs.size(0)
            loss_sum += loss.item() * n_batch
            n_seen += n_batch
    return loss_sum / max(n_seen, 1)


progress_bar = tqdm(range(n_epochs), desc="Training Progress", unit="epoch", leave=False)
for epoch in progress_bar:
    # Training, then validation. Both are per-sample means, so the two
    # columns are comparable even though the datasets differ in size.
    train_loss = _run_epoch(train_loader, training=True)
    val_loss = _run_epoch(val_loader, training=False)

    # Record the history that is written to logs/ after the loop.
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Save model with best (lowest) validation loss.
    # NOTE: model is the DataParallel wrapper created in the setup cell, so
    # the saved keys carry DataParallel's "module." prefix.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "models/best_vae.pt")
        epoch_saved = epoch + 1

    epoch_metrics.append([
        epoch + 1,
        train_loss,
        val_loss
    ])

    # Update the progress bar description
    progress_bar.set_description(f"Epoch {epoch + 1}/{n_epochs}")

    # Clear and display the table
    clear_output()
    tqdm.write(tabulate(epoch_metrics, headers=headers, floatfmt=".4f"))

# save last epoch model
torch.save(model.state_dict(), "models/last_vae.pt")

# Save training and validation metrics (one value per epoch)
np.save("logs/vae_train_losses.npy", np.array(train_losses))
np.save("logs/vae_val_losses.npy", np.array(val_losses))