### Train the VAE model

In [2]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
import torch.nn.functional as F
from PIL import Image, ImageFile
import os
import numpy as np
import matplotlib.pyplot as plt
import json
import time
from datetime import datetime
import random
from tqdm.notebook import tqdm  # Import tqdm for notebooks

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Parameters
batch_size = 32
learning_rate = 1e-3
num_epochs = 200
latent_dim = 20           # Size of the latent space
kl_weight = 1.0           # Weight for KL divergence
early_stopping_patience = 10
max_images_per_dir = 25000  # Set the desired number of images per directory

# Dataset directories
dataset_dirs = [
    '../data/ma-boston/parcels',
    '../data/nc-charlotte/parcels',
    '../data/ny-manhattan/parcels',
    '../data/pa-pittsburgh/parcels'
]

# Directories and identifiers
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
identifier = f"VAE_{num_epochs}-epochs_{max_images_per_dir}_dir-samples_{batch_size}-batch_{current_time}"
output_folder = os.path.join('vae-output', identifier)
model_save_path = os.path.join(output_folder, 'vae_model.pth')
training_curves_path = os.path.join(output_folder, 'training_curves.png')
reconstructions_path = os.path.join(output_folder, 'reconstructions.png')
generated_samples_path = os.path.join(output_folder, 'generated_samples.png')

os.makedirs(output_folder, exist_ok=True)

# Transformations
from torchvision.transforms import InterpolationMode

normalize_mean = [0.485, 0.456, 0.406]  # ImageNet mean
normalize_std = [0.229, 0.224, 0.225]   # ImageNet std

transform = transforms.Compose([
    transforms.RandomResizedCrop(64, scale=(0.8, 1.0), ratio=(0.9, 1.1),
                                 interpolation=InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10, interpolation=InterpolationMode.BILINEAR),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=normalize_mean, std=normalize_std),
])

# Dataset
class UnlabeledImageDataset(Dataset):
    def __init__(self, directories, transform=None, max_images_per_dir=None):
        self.image_paths = []
        self.transform = transform

        for dir_path in directories:
            images_in_dir = []
            for root, _, files in os.walk(dir_path):
                for file in files:
                    if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                        images_in_dir.append(os.path.join(root, file))

            # Shuffle and limit the number of images per directory if specified
            if max_images_per_dir is not None:
                random.shuffle(images_in_dir)
                images_in_dir = images_in_dir[:max_images_per_dir]

            self.image_paths.extend(images_in_dir)

        # Shuffle the entire dataset
        random.shuffle(self.image_paths)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            # Handle exceptions for corrupted images
            # print(f"Error loading image {image_path}: {e}")
            return self.__getitem__((idx + 1) % len(self))
        if self.transform:
            image = self.transform(image)
        return image, 0  # Dummy label

# Create dataset and data loaders
dataset = UnlabeledImageDataset(dataset_dirs, transform=transform, max_images_per_dir=max_images_per_dir)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Device configuration with MPS support
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA")
else:
    device = torch.device('cpu')
    print("Using CPU")

# VAE Model Definition
class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),   # 64x32x32
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 128x16x16
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),# 256x8x8
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(256 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(256 * 8 * 8, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 256 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 128x16x16
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 64x32x32
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),    # 3x64x64
            nn.Sigmoid()
        )

    def encode(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # Sample from standard normal
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(-1, 256, 8, 8)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed_x = self.decode(z)
        return reconstructed_x, mu, logvar

# Initialize model, optimizer, and other components
model = VAE(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

# Initialize logs
train_recon_losses = []
train_kl_losses = []
val_recon_losses = []
val_kl_losses = []
train_total_losses = []
val_total_losses = []

best_val_loss = float('inf')
epochs_without_improvement = 0

# Training loop with tqdm progress bars
for epoch in tqdm(range(num_epochs), desc="Training Progress", unit="epoch"):
    model.train()
    total_train_recon_loss = 0
    total_train_kl_loss = 0
    train_loader_tqdm = tqdm(train_loader, desc="Training", leave=False)
    for images, _ in train_loader_tqdm:
        images = images.to(device)
        optimizer.zero_grad()
        recon_images, mu, logvar = model(images)
        recon_loss = F.mse_loss(recon_images, images, reduction='sum') / images.size(0)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / images.size(0)
        loss = recon_loss + kl_weight * kl_loss
        loss.backward()
        optimizer.step()
        total_train_recon_loss += recon_loss.item()
        total_train_kl_loss += kl_loss.item()

        # Update progress bar
        train_loader_tqdm.set_postfix({'Recon Loss': f"{recon_loss.item():.4f}", 'KL Loss': f"{kl_loss.item():.4f}"})

    avg_train_recon_loss = total_train_recon_loss / len(train_loader)
    avg_train_kl_loss = total_train_kl_loss / len(train_loader)
    train_recon_losses.append(avg_train_recon_loss)
    train_kl_losses.append(avg_train_kl_loss)
    train_total_losses.append(avg_train_recon_loss + kl_weight * avg_train_kl_loss)

    model.eval()
    total_val_recon_loss = 0
    total_val_kl_loss = 0
    val_loader_tqdm = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for images, _ in val_loader_tqdm:
            images = images.to(device)
            recon_images, mu, logvar = model(images)
            recon_loss = F.mse_loss(recon_images, images, reduction='sum') / images.size(0)
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / images.size(0)
            total_val_recon_loss += recon_loss.item()
            total_val_kl_loss += kl_loss.item()

            # Update progress bar
            val_loader_tqdm.set_postfix({'Recon Loss': f"{recon_loss.item():.4f}", 'KL Loss': f"{kl_loss.item():.4f}"})

    avg_val_recon_loss = total_val_recon_loss / len(val_loader)
    avg_val_kl_loss = total_val_kl_loss / len(val_loader)
    val_recon_losses.append(avg_val_recon_loss)
    val_kl_losses.append(avg_val_kl_loss)
    val_total_losses.append(avg_val_recon_loss + kl_weight * avg_val_kl_loss)

    # Learning rate scheduling
    scheduler.step(val_total_losses[-1])

    # Early stopping
    if val_total_losses[-1] < best_val_loss:
        best_val_loss = val_total_losses[-1]
        epochs_without_improvement = 0
        # Save the best model
        torch.save(model.state_dict(), model_save_path)
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= early_stopping_patience:
            print("Early stopping triggered.")
            break

    # Update epoch progress bar
    tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_total_losses[-1]:.4f}, Val Loss: {val_total_losses[-1]:.4f}")

# Plotting functions
def plot_loss_curves(train_losses, val_losses, title, ylabel, save_path):
    plt.figure()
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Validation')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(save_path)
    plt.close()

# Plot training curves
plot_loss_curves(train_recon_losses, val_recon_losses, 'Reconstruction Loss', 'Loss',
                 os.path.join(output_folder, 'reconstruction_loss.png'))
plot_loss_curves(train_kl_losses, val_kl_losses, 'KL Divergence', 'Loss',
                 os.path.join(output_folder, 'kl_divergence.png'))
plot_loss_curves(train_total_losses, val_total_losses, 'Total Loss', 'Loss',
                 os.path.join(output_folder, 'total_loss.png'))

# Visualize reconstructions
def visualize_reconstructions(model, data_loader, num_images=8, save_path=None):
    model.eval()
    images, _ = next(iter(data_loader))
    images = images[:num_images].to(device)
    with torch.no_grad():
        recon_images, _, _ = model(images)
    images = images.cpu()
    recon_images = recon_images.cpu()

    # Unnormalize images
    images = images * torch.tensor(normalize_std).view(1, 3, 1, 1) + torch.tensor(normalize_mean).view(1, 3, 1, 1)
    recon_images = recon_images * torch.tensor(normalize_std).view(1, 3, 1, 1) + torch.tensor(normalize_mean).view(1, 3, 1, 1)
    images = images.clamp(0, 1)
    recon_images = recon_images.clamp(0, 1)

    fig, axs = plt.subplots(2, num_images, figsize=(num_images * 2, 4))
    for i in range(num_images):
        axs[0, i].imshow(np.transpose(images[i].numpy(), (1, 2, 0)))
        axs[0, i].axis('off')
        axs[1, i].imshow(np.transpose(recon_images[i].numpy(), (1, 2, 0)))
        axs[1, i].axis('off')
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

visualize_reconstructions(model, val_loader, save_path=reconstructions_path)

# Visualize generated samples
def visualize_generated_samples(model, num_samples=8, save_path=None):
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, model.latent_dim).to(device)
        generated_images = model.decode(z)
    generated_images = generated_images.cpu()

    # Unnormalize images
    generated_images = generated_images * torch.tensor(normalize_std).view(1, 3, 1, 1) + torch.tensor(normalize_mean).view(1, 3, 1, 1)
    generated_images = generated_images.clamp(0, 1)

    fig, axs = plt.subplots(1, num_samples, figsize=(num_samples * 2, 2))
    for i in range(num_samples):
        axs[i].imshow(np.transpose(generated_images[i].numpy(), (1, 2, 0)))
        axs[i].axis('off')
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

visualize_generated_samples(model, save_path=generated_samples_path)

print("Training complete. Model saved to", model_save_path)

Using CUDA


Training Progress:   0%|          | 0/200 [00:00<?, ?epoch/s]

Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 1/200 - Train Loss: 12639.0788, Val Loss: 12505.4613


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 2/200 - Train Loss: 12518.9604, Val Loss: 12467.2361


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 3/200 - Train Loss: 12483.5444, Val Loss: 12510.5605


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 4/200 - Train Loss: 12487.5647, Val Loss: 12433.0449


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 5/200 - Train Loss: 12459.7176, Val Loss: 12458.9594


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 6/200 - Train Loss: 12447.5245, Val Loss: 12459.1400


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 7/200 - Train Loss: 12436.2111, Val Loss: 12435.0233


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 8/200 - Train Loss: 12449.0709, Val Loss: 12424.3017


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 9/200 - Train Loss: 12439.3618, Val Loss: 12443.2243


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 10/200 - Train Loss: 12443.8048, Val Loss: 12441.8975


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 11/200 - Train Loss: 12443.6929, Val Loss: 12419.0503


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 12/200 - Train Loss: 12423.4642, Val Loss: 12440.2321


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 13/200 - Train Loss: 12425.4850, Val Loss: 12438.7656


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 14/200 - Train Loss: 12428.3207, Val Loss: 12421.6742


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 15/200 - Train Loss: 12435.4574, Val Loss: 12421.9704


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 16/200 - Train Loss: 12435.0915, Val Loss: 12436.4968


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 17/200 - Train Loss: 12413.8972, Val Loss: 12431.4629


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 18/200 - Train Loss: 12417.5813, Val Loss: 12430.4214


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 19/200 - Train Loss: 12403.3958, Val Loss: 12402.8818


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 20/200 - Train Loss: 12396.4886, Val Loss: 12360.3048


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 21/200 - Train Loss: 12405.5269, Val Loss: 12421.9917


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 22/200 - Train Loss: 12393.1605, Val Loss: 12368.4676


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 23/200 - Train Loss: 12408.2931, Val Loss: 12358.1458


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 24/200 - Train Loss: 12402.3910, Val Loss: 12401.6812


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 25/200 - Train Loss: 12383.4676, Val Loss: 12429.7169


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 26/200 - Train Loss: 12382.4238, Val Loss: 12392.4319


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 27/200 - Train Loss: 12386.3815, Val Loss: 12393.3405


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 28/200 - Train Loss: 12401.9754, Val Loss: 12414.5901


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 29/200 - Train Loss: 12400.0700, Val Loss: 12387.6867


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 30/200 - Train Loss: 12371.8495, Val Loss: 12343.1286


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 31/200 - Train Loss: 12362.3646, Val Loss: 12390.9134


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 32/200 - Train Loss: 12363.6098, Val Loss: 12391.5645


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 33/200 - Train Loss: 12388.2936, Val Loss: 12361.8142


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 34/200 - Train Loss: 12375.2483, Val Loss: 12409.0420


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 35/200 - Train Loss: 12368.1662, Val Loss: 12415.5164


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 36/200 - Train Loss: 12356.8875, Val Loss: 12337.0481


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 37/200 - Train Loss: 12365.4532, Val Loss: 12381.1787


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 38/200 - Train Loss: 12372.7314, Val Loss: 12388.4049


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 39/200 - Train Loss: 12373.9469, Val Loss: 12348.6118


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 40/200 - Train Loss: 12363.6325, Val Loss: 12421.0897


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 41/200 - Train Loss: 12378.2883, Val Loss: 12370.0018


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 42/200 - Train Loss: 12379.2967, Val Loss: 12373.2619


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 43/200 - Train Loss: 12353.2371, Val Loss: 12375.5092


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 44/200 - Train Loss: 12355.3517, Val Loss: 12398.4944


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 45/200 - Train Loss: 12364.3617, Val Loss: 12363.1681


Training:   0%|          | 0/2500 [00:00<?, ?it/s]

Validation:   0%|          | 0/625 [00:00<?, ?it/s]

Early stopping triggered.
Training complete. Model saved to vae-output/VAE_200-epochs_25000_dir-samples_32-batch_2024-10-06_20-09/vae_model.pth
