### Train the VAE model

In [1]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
import torch.nn.functional as F
from PIL import Image, ImageFile
import os
import numpy as np
import matplotlib.pyplot as plt
import json
import time
from datetime import datetime
import random
from tqdm.notebook import tqdm  # Import tqdm for notebooks

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Parameters
batch_size = 64
learning_rate = 1e-3
num_epochs = 100
latent_dim = 20           # Size of the latent space
kl_weight = 1.0           # Weight for KL divergence
early_stopping_patience = 10
max_images_per_dir = 20000  # Set the desired number of images per directory

# Dataset directories
dataset_dirs = [
    '../data/ma-boston/parcels',
    '../data/nc-charlotte/parcels',
    '../data/ny-manhattan/parcels',
    '../data/pa-pittsburgh/parcels'
]

# Directories and identifiers
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
identifier = f"VAE_{num_epochs}-epochs_{max_images_per_dir}_dir-samples_{batch_size}-batch_{current_time}"
output_folder = os.path.join('vae-output', identifier)
model_save_path = os.path.join(output_folder, 'vae_model.pth')
training_curves_path = os.path.join(output_folder, 'training_curves.png')
reconstructions_path = os.path.join(output_folder, 'reconstructions.png')
generated_samples_path = os.path.join(output_folder, 'generated_samples.png')

os.makedirs(output_folder, exist_ok=True)

# Transformations
from torchvision.transforms import InterpolationMode

normalize_mean = [0.485, 0.456, 0.406]  # ImageNet mean
normalize_std = [0.229, 0.224, 0.225]   # ImageNet std

transform = transforms.Compose([
    transforms.RandomResizedCrop(64, scale=(0.8, 1.0), ratio=(0.9, 1.1),
                                 interpolation=InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10, interpolation=InterpolationMode.BILINEAR),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=normalize_mean, std=normalize_std),
])

# Dataset
class UnlabeledImageDataset(Dataset):
    def __init__(self, directories, transform=None, max_images_per_dir=None):
        self.image_paths = []
        self.transform = transform

        for dir_path in directories:
            images_in_dir = []
            for root, _, files in os.walk(dir_path):
                for file in files:
                    if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                        images_in_dir.append(os.path.join(root, file))

            # Shuffle and limit the number of images per directory if specified
            if max_images_per_dir is not None:
                random.shuffle(images_in_dir)
                images_in_dir = images_in_dir[:max_images_per_dir]

            self.image_paths.extend(images_in_dir)

        # Shuffle the entire dataset
        random.shuffle(self.image_paths)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            # Handle exceptions for corrupted images
            # print(f"Error loading image {image_path}: {e}")
            return self.__getitem__((idx + 1) % len(self))
        if self.transform:
            image = self.transform(image)
        return image, 0  # Dummy label

# Create dataset and data loaders
dataset = UnlabeledImageDataset(dataset_dirs, transform=transform, max_images_per_dir=max_images_per_dir)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Device configuration with MPS support
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA")
else:
    device = torch.device('cpu')
    print("Using CPU")

# VAE Model Definition
class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),   # 64x32x32
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 128x16x16
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),# 256x8x8
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(256 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(256 * 8 * 8, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 256 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 128x16x16
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 64x32x32
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),    # 3x64x64
            nn.Sigmoid()
        )

    def encode(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # Sample from standard normal
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(-1, 256, 8, 8)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed_x = self.decode(z)
        return reconstructed_x, mu, logvar

# Initialize model, optimizer, and other components
model = VAE(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

# Initialize logs
train_recon_losses = []
train_kl_losses = []
val_recon_losses = []
val_kl_losses = []
train_total_losses = []
val_total_losses = []

best_val_loss = float('inf')
epochs_without_improvement = 0

# Training loop with tqdm progress bars
for epoch in tqdm(range(num_epochs), desc="Training Progress", unit="epoch"):
    model.train()
    total_train_recon_loss = 0
    total_train_kl_loss = 0
    train_loader_tqdm = tqdm(train_loader, desc="Training", leave=False)
    for images, _ in train_loader_tqdm:
        images = images.to(device)
        optimizer.zero_grad()
        recon_images, mu, logvar = model(images)
        recon_loss = F.mse_loss(recon_images, images, reduction='sum') / images.size(0)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / images.size(0)
        loss = recon_loss + kl_weight * kl_loss
        loss.backward()
        optimizer.step()
        total_train_recon_loss += recon_loss.item()
        total_train_kl_loss += kl_loss.item()

        # Update progress bar
        train_loader_tqdm.set_postfix({'Recon Loss': f"{recon_loss.item():.4f}", 'KL Loss': f"{kl_loss.item():.4f}"})

    avg_train_recon_loss = total_train_recon_loss / len(train_loader)
    avg_train_kl_loss = total_train_kl_loss / len(train_loader)
    train_recon_losses.append(avg_train_recon_loss)
    train_kl_losses.append(avg_train_kl_loss)
    train_total_losses.append(avg_train_recon_loss + kl_weight * avg_train_kl_loss)

    model.eval()
    total_val_recon_loss = 0
    total_val_kl_loss = 0
    val_loader_tqdm = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for images, _ in val_loader_tqdm:
            images = images.to(device)
            recon_images, mu, logvar = model(images)
            recon_loss = F.mse_loss(recon_images, images, reduction='sum') / images.size(0)
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / images.size(0)
            total_val_recon_loss += recon_loss.item()
            total_val_kl_loss += kl_loss.item()

            # Update progress bar
            val_loader_tqdm.set_postfix({'Recon Loss': f"{recon_loss.item():.4f}", 'KL Loss': f"{kl_loss.item():.4f}"})

    avg_val_recon_loss = total_val_recon_loss / len(val_loader)
    avg_val_kl_loss = total_val_kl_loss / len(val_loader)
    val_recon_losses.append(avg_val_recon_loss)
    val_kl_losses.append(avg_val_kl_loss)
    val_total_losses.append(avg_val_recon_loss + kl_weight * avg_val_kl_loss)

    # Learning rate scheduling
    scheduler.step(val_total_losses[-1])

    # Early stopping
    if val_total_losses[-1] < best_val_loss:
        best_val_loss = val_total_losses[-1]
        epochs_without_improvement = 0
        # Save the best model
        torch.save(model.state_dict(), model_save_path)
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= early_stopping_patience:
            print("Early stopping triggered.")
            break

    # Update epoch progress bar
    tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_total_losses[-1]:.4f}, Val Loss: {val_total_losses[-1]:.4f}")

# Plotting functions
def plot_loss_curves(train_losses, val_losses, title, ylabel, save_path):
    plt.figure()
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Validation')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(save_path)
    plt.close()

# Plot training curves
plot_loss_curves(train_recon_losses, val_recon_losses, 'Reconstruction Loss', 'Loss',
                 os.path.join(output_folder, 'reconstruction_loss.png'))
plot_loss_curves(train_kl_losses, val_kl_losses, 'KL Divergence', 'Loss',
                 os.path.join(output_folder, 'kl_divergence.png'))
plot_loss_curves(train_total_losses, val_total_losses, 'Total Loss', 'Loss',
                 os.path.join(output_folder, 'total_loss.png'))

# Visualize reconstructions
def visualize_reconstructions(model, data_loader, num_images=8, save_path=None):
    model.eval()
    images, _ = next(iter(data_loader))
    images = images[:num_images].to(device)
    with torch.no_grad():
        recon_images, _, _ = model(images)
    images = images.cpu()
    recon_images = recon_images.cpu()

    # Unnormalize images
    images = images * torch.tensor(normalize_std).view(1, 3, 1, 1) + torch.tensor(normalize_mean).view(1, 3, 1, 1)
    recon_images = recon_images * torch.tensor(normalize_std).view(1, 3, 1, 1) + torch.tensor(normalize_mean).view(1, 3, 1, 1)
    images = images.clamp(0, 1)
    recon_images = recon_images.clamp(0, 1)

    fig, axs = plt.subplots(2, num_images, figsize=(num_images * 2, 4))
    for i in range(num_images):
        axs[0, i].imshow(np.transpose(images[i].numpy(), (1, 2, 0)))
        axs[0, i].axis('off')
        axs[1, i].imshow(np.transpose(recon_images[i].numpy(), (1, 2, 0)))
        axs[1, i].axis('off')
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

visualize_reconstructions(model, val_loader, save_path=reconstructions_path)

# Visualize generated samples
def visualize_generated_samples(model, num_samples=8, save_path=None):
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, model.latent_dim).to(device)
        generated_images = model.decode(z)
    generated_images = generated_images.cpu()

    # Unnormalize images
    generated_images = generated_images * torch.tensor(normalize_std).view(1, 3, 1, 1) + torch.tensor(normalize_mean).view(1, 3, 1, 1)
    generated_images = generated_images.clamp(0, 1)

    fig, axs = plt.subplots(1, num_samples, figsize=(num_samples * 2, 2))
    for i in range(num_samples):
        axs[i].imshow(np.transpose(generated_images[i].numpy(), (1, 2, 0)))
        axs[i].axis('off')
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

visualize_generated_samples(model, save_path=generated_samples_path)

print("Training complete. Model saved to", model_save_path)

Using CUDA


Training Progress:   0%|          | 0/100 [00:00<?, ?epoch/s]

Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 1/100 - Train Loss: 14850.8976, Val Loss: 13984.3053


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 2/100 - Train Loss: 13924.8045, Val Loss: 13290.9484


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 3/100 - Train Loss: 12731.9651, Val Loss: 12658.9420


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 4/100 - Train Loss: 12625.3688, Val Loss: 12625.7879


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 5/100 - Train Loss: 12583.2538, Val Loss: 12509.3848


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 6/100 - Train Loss: 12513.0418, Val Loss: 12537.4859


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 7/100 - Train Loss: 12510.5265, Val Loss: 12508.2812


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 8/100 - Train Loss: 12484.7422, Val Loss: 12498.3672


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 9/100 - Train Loss: 12484.4436, Val Loss: 12486.3580


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 10/100 - Train Loss: 12467.9007, Val Loss: 12446.5106


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 11/100 - Train Loss: 12453.4190, Val Loss: 12427.3338


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 12/100 - Train Loss: 12441.0407, Val Loss: 12424.7475


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 13/100 - Train Loss: 12431.4964, Val Loss: 12516.2280


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 14/100 - Train Loss: 12422.9839, Val Loss: 12459.7141


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 15/100 - Train Loss: 12461.3211, Val Loss: 12401.7030


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 16/100 - Train Loss: 12435.3174, Val Loss: 12428.0609


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 17/100 - Train Loss: 12441.8130, Val Loss: 12424.1620


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 18/100 - Train Loss: 12413.3285, Val Loss: 12466.6200


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 19/100 - Train Loss: 12417.3785, Val Loss: 12454.4266


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 20/100 - Train Loss: 12429.1573, Val Loss: 12450.6889


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 21/100 - Train Loss: 12401.6209, Val Loss: 12433.5460


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 22/100 - Train Loss: 12403.3424, Val Loss: 12352.5149


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 23/100 - Train Loss: 12407.2778, Val Loss: 12381.2003


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 24/100 - Train Loss: 12396.6952, Val Loss: 12386.8962


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 25/100 - Train Loss: 12394.8416, Val Loss: 12406.6871


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 26/100 - Train Loss: 12401.9025, Val Loss: 12410.5926


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 27/100 - Train Loss: 12383.2156, Val Loss: 12378.9654


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 28/100 - Train Loss: 12390.0501, Val Loss: 12373.3354


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 29/100 - Train Loss: 12372.8179, Val Loss: 12407.3656


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 30/100 - Train Loss: 12381.0113, Val Loss: 12382.0971


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch 31/100 - Train Loss: 12380.9196, Val Loss: 12403.1745


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Validation:   0%|          | 0/250 [00:00<?, ?it/s]

Early stopping triggered.
Training complete. Model saved to vae-output/VAE_100-epochs_20000_dir-samples_64-batch_2024-10-06_19-24/vae_model.pth
