### Train the VQ-VAE model

In [2]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import json
import csv
import logging
from datetime import datetime

# Set up logging
# INFO level so the progress messages emitted throughout this cell are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Parameters
total_samples = 5000  # max number of images sampled for training; kept low for quick tests, intended to be 50k+ for real runs
batch_size = 256
num_epochs = 5
learning_rate = 1e-3  # passed to the Adam optimizer
commitment_cost = 0.25  # weight of the commitment term in the vector-quantization loss
hidden_channels = 128  # width of the intermediate conv layer in the encoder and decoder
embedding_dim = 128  # channels of the encoder output, i.e. the size of each codebook vector
num_embeddings = 2048  # number of entries in the codebook
checkpoint_interval = 50  # save a checkpoint every N epochs (no checkpoint is written when this exceeds num_epochs)
image_size = (16, 16)  # Resize to a fixed size
# Mean/std of 0.5 rescale ToTensor's [0, 1] output to [-1, 1], which matches
# the decoder's tanh output range.
normalize_mean = (0.5,)
normalize_std = (0.5,)

# Create identifier with current date and time
# The identifier encodes the main hyperparameters plus the start time (minute
# resolution). It records the requested `total_samples`, which can be larger
# than the number of images actually sampled.
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
identifier = f"vq-vae_{batch_size}-batch_{total_samples}-samples_{embedding_dim}-{num_embeddings}-vector_{num_epochs}-epochs_{current_time}"
logger.info(f"Run identifier: {identifier}")

# Directories
output_dir = os.path.join('vae-output', identifier)  # Each run gets its own directory
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Output directory: {output_dir}")

# Root folders that are searched recursively for training images.
dataset_dirs = [
    '../data/ma-boston/parcels',
    '../data/nc-charlotte/parcels', 
    '../data/ny-manhattan/parcels', 
    '../data/pa-pittsburgh/parcels'  
]

# Define transformations
# Applied to every image: resize to `image_size`, convert to a tensor, then
# normalize with the mean/std defined above.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std)
])

# Gather every usable image path below the dataset directories.
# Names starting with '.' or '_' (hidden/system files) are ignored, and only
# PNG/JPEG extensions (compared case-insensitively) are accepted.
all_image_paths = [
    os.path.join(folder, name)
    for source_dir in dataset_dirs
    for folder, _, names in os.walk(source_dir)
    for name in names
    if not name.startswith(('.', '_'))
    and name.lower().endswith(('.png', '.jpg', '.jpeg'))
]

logger.info(f"Found {len(all_image_paths)} valid image files")

# Draw a random subset without replacement, capped at the number of files found.
sample_count = min(total_samples, len(all_image_paths))
sampled_image_paths = random.sample(all_image_paths, sample_count)
logger.info(f"Sampled {len(sampled_image_paths)} images for training")

# Custom dataset with improved error handling
class SampledImageDataset(Dataset):
    """Image dataset backed by an explicit list of file paths.

    Every path is checked once at construction time: files that cannot be
    opened and converted to RGB are logged and left out. Items are returned
    as ``(image, 0)`` where ``0`` is a placeholder label.
    """

    def __init__(self, image_paths, transform=None):
        # Keep only the paths whose image decodes successfully.
        self.image_paths = []
        for candidate in image_paths:
            if self._can_decode(candidate):
                self.image_paths.append(candidate)

        logger.info(f"Successfully loaded {len(self.image_paths)} valid images")
        self.transform = transform

    @staticmethod
    def _can_decode(path):
        """Return True if the file at *path* opens and converts to RGB."""
        try:
            with Image.open(path) as handle:
                handle.convert('RGB')
        except (IOError, OSError, Image.UnidentifiedImageError) as err:
            logger.warning(f"Skipping corrupted or unreadable image {path}: {str(err)}")
            return False
        return True

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            with Image.open(path) as handle:
                picture = handle.convert('RGB')
            sample = self.transform(picture) if self.transform else picture
        except Exception as err:
            # A file can still go bad after the constructor check; substitute
            # a black image so one bad file does not abort the whole batch.
            logger.error(f"Error loading image {path}: {str(err)}")
            placeholder = Image.new('RGB', image_size, 'black')
            sample = self.transform(placeholder) if self.transform else placeholder
        return sample, 0  # 0 is a placeholder label

# Wrap the sampled paths in a dataset and serve them as shuffled mini-batches
sampled_dataset = SampledImageDataset(image_paths=sampled_image_paths, transform=transform)
dataloader = DataLoader(dataset=sampled_dataset, shuffle=True, batch_size=batch_size)

# VQ-VAE Model Definition
class Encoder(nn.Module):
    """Convolutional encoder that halves the spatial size twice.

    Two 4x4, stride-2 convolutions map ``in_channels`` to ``hidden_channels``
    and then to ``embedding_dim``. A ReLU follows the first convolution only,
    so the returned latent is not passed through an activation.
    ``num_embeddings`` is accepted for signature compatibility but is not
    used by this module.
    """

    def __init__(self, in_channels, hidden_channels, num_embeddings, embedding_dim):
        super().__init__()
        downsample = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, **downsample)
        self.conv2 = nn.Conv2d(hidden_channels, embedding_dim, **downsample)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return self.conv2(hidden)

class Decoder(nn.Module):
    """Transposed-convolution decoder that doubles the spatial size twice.

    Two 4x4, stride-2 transposed convolutions map ``embedding_dim`` to
    ``hidden_channels`` and then to ``out_channels``. The output goes through
    ``tanh``, so every value lies in [-1, 1].
    """

    def __init__(self, embedding_dim, hidden_channels, out_channels):
        super().__init__()
        upsample = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(embedding_dim, hidden_channels, **upsample)
        self.conv2 = nn.ConvTranspose2d(hidden_channels, out_channels, **upsample)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return torch.tanh(self.conv2(hidden))

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Each ``embedding_dim``-sized latent vector of the input is replaced by its
    closest codebook entry (Euclidean distance). Gradients reach the encoder
    through a straight-through estimator.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector; must equal the channel
            dimension of the input.
        commitment_cost: Weight of the commitment term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, x):
        """Quantize ``x`` against the codebook.

        Args:
            x: Either a channels-first tensor ``(B, C, ...)`` with
                ``C == embedding_dim`` (e.g. the encoder's ``(B, C, H, W)``
                output), or a tensor of at most two dimensions whose trailing
                dimension is ``embedding_dim``.

        Returns:
            A tuple ``(quantized, loss, encoding_indices)``: the quantized
            tensor with the same shape as ``x``, the scalar VQ loss, and a 1-D
            tensor with one codebook index per latent vector (for
            channels-first input the order is batch-major, then spatial).

        Raises:
            ValueError: If a channels-first input's channel dimension differs
                from ``embedding_dim``.
        """
        channels_first = x.dim() > 2
        if channels_first:
            if x.size(1) != self.embedding_dim:
                raise ValueError(
                    f"Expected {self.embedding_dim} channels, got {x.size(1)}"
                )
            # Put channels last so that each flattened row is exactly one
            # spatial position's latent vector. Flattening the channels-first
            # tensor directly would mix channels and positions within a row.
            latents = x.movedim(1, -1).contiguous()
        else:
            latents = x
        flattened = latents.reshape(-1, self.embedding_dim)

        # The index lookup is not differentiable, so no graph is needed here.
        with torch.no_grad():
            distances = torch.cdist(flattened, self.embedding.weight)
            encoding_indices = torch.argmin(distances, dim=1)

        quantized = self.embedding(encoding_indices).view(latents.shape)
        if channels_first:
            # Restore the caller's channels-first layout.
            quantized = quantized.movedim(-1, 1).contiguous()

        # Commitment term pulls the encoder output towards the codebook;
        # the codebook term pulls the selected entries towards the encoder.
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward uses the quantized values,
        # backward copies the gradient to x unchanged.
        quantized = x + (quantized - x).detach()

        return quantized, loss, encoding_indices

class VQVAE(nn.Module):
    """Encoder -> vector quantizer -> decoder pipeline.

    ``forward`` returns the reconstruction together with the quantizer's
    loss; the codebook indices produced by the quantizer are discarded.
    """

    def __init__(self, in_channels, hidden_channels, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.encoder = Encoder(
            in_channels,
            hidden_channels,
            num_embeddings,
            embedding_dim,
        )
        # The decoder emits as many channels as the encoder consumes.
        self.decoder = Decoder(
            embedding_dim,
            hidden_channels,
            in_channels,
        )
        self.vq_layer = VectorQuantizer(
            num_embeddings,
            embedding_dim,
            commitment_cost,
        )

    def forward(self, x):
        latents = self.encoder(x)
        quantized, vq_loss, _indices = self.vq_layer(latents)
        return self.decoder(quantized), vq_loss

# Pick the compute backend (MPS is checked first, then CUDA, then CPU) and
# decide whether automatic mixed precision (AMP) is used.
if torch.backends.mps.is_available():
    # AMP is left off on MPS, so autocast is later invoked disabled.
    device_name, use_amp = "mps", False
elif torch.cuda.is_available():
    device_name, use_amp = "cuda", True
else:
    device_name, use_amp = "cpu", False

device = torch.device(device_name)
# autocast is only ever enabled for CUDA; every other case passes 'cpu'
# together with enabled=False.
amp_device_type = 'cuda' if use_amp else 'cpu'

logger.info(f"Using device: {device}")
logger.info(f"AMP enabled: {use_amp}")

# A gradient scaler is only created when AMP is active
scaler = None
if use_amp:
    scaler = torch.amp.GradScaler(enabled=use_amp)

# Build the model on the selected device, plus its optimizer and the
# reconstruction criterion
model = VQVAE(
    in_channels=3,
    hidden_channels=hidden_channels,
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    commitment_cost=commitment_cost,
)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Record the run configuration next to the run's other outputs
training_params = dict(
    identifier=identifier,
    total_samples=total_samples,
    batch_size=batch_size,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    commitment_cost=commitment_cost,
    hidden_channels=hidden_channels,
    embedding_dim=embedding_dim,
    num_embeddings=num_embeddings,
    checkpoint_interval=checkpoint_interval,
    image_size=image_size,
    normalize_mean=normalize_mean,
    normalize_std=normalize_std,
    use_amp=use_amp,
    device=str(device),
    dataset_dirs=dataset_dirs,
)

params_path = os.path.join(output_dir, f'{identifier}_training_params.json')
with open(params_path, 'w') as params_file:
    params_file.write(json.dumps(training_params))
logger.info(f"Training parameters saved to {params_path}")

# Start the per-epoch CSV log with just its header row; the training loop
# appends one row per epoch using the same `fieldnames`.
log_path = os.path.join(output_dir, f'{identifier}_training_log.csv')
fieldnames = ['epoch', 'loss', 'recon_loss', 'vq_loss']
with open(log_path, 'w', newline='') as log_file:
    csv.DictWriter(log_file, fieldnames=fieldnames).writeheader()

# Training Loop
# Runs `num_epochs` passes over `dataloader`, appends one CSV row of average
# losses per epoch, writes a checkpoint every `checkpoint_interval` epochs and
# always attempts to save the final model, even after an error or interrupt.
total_iterations = num_epochs * len(dataloader)
progress_bar = tqdm(total=total_iterations, desc=f"Training {identifier}")

# Bind the epoch averages up front so the final save in `finally` still works
# when training stops before the first epoch has finished.
avg_loss = avg_recon_loss = avg_vq_loss = None
training_completed = False

try:
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_recon_loss = 0.0
        running_vq_loss = 0.0
        completed_batches = 0  # batches that made it through the optimizer step

        for batch_idx, (images, _) in enumerate(dataloader):
            try:
                images = images.to(device)

                # Clear gradients before the forward pass: if a previous batch
                # failed after backward(), its gradients must not leak into
                # this step.
                optimizer.zero_grad(set_to_none=True)

                # Use appropriate autocast based on device
                with torch.amp.autocast(device_type=amp_device_type, enabled=use_amp):
                    reconstructed, vq_loss = model(images)
                    recon_loss = criterion(reconstructed, images)
                    loss = recon_loss + vq_loss

                # Optimization step with or without AMP
                if use_amp and scaler is not None:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()

                # Update running losses
                running_loss += loss.item()
                running_recon_loss += recon_loss.item()
                running_vq_loss += vq_loss.item()
                completed_batches += 1

                progress_bar.update(1)
                progress_bar.set_description(
                    f"Epoch [{epoch + 1}/{num_epochs}] - {identifier}"
                )

            except Exception as e:
                # Best effort: log the failure and move on to the next batch.
                logger.error(f"Error in batch {batch_idx}: {str(e)}")

        # An epoch without a single successful batch means something is
        # systematically wrong (empty dataset, every batch failing); abort
        # instead of logging meaningless averages.
        if completed_batches == 0:
            raise RuntimeError(f"No batches completed in epoch {epoch + 1}")

        # Calculate average losses over the batches that actually contributed
        avg_loss = running_loss / completed_batches
        avg_recon_loss = running_recon_loss / completed_batches
        avg_vq_loss = running_vq_loss / completed_batches

        progress_bar.set_postfix({
            'Loss': f'{avg_loss:.4f}',
            'Recon': f'{avg_recon_loss:.4f}',
            'VQ': f'{avg_vq_loss:.4f}'
        })

        # Log the losses
        with open(log_path, 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({
                'epoch': epoch + 1,
                'loss': avg_loss,
                'recon_loss': avg_recon_loss,
                'vq_loss': avg_vq_loss
            })

        # Save checkpoint
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(output_dir, f'{identifier}_checkpoint_epoch_{epoch + 1}.pth')
            checkpoint_dict = {
                'identifier': identifier,
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'recon_loss': avg_recon_loss,
                'vq_loss': avg_vq_loss
            }
            if use_amp and scaler is not None:
                checkpoint_dict['scaler_state_dict'] = scaler.state_dict()

            torch.save(checkpoint_dict, checkpoint_path)
            logger.info(f"Checkpoint saved to {checkpoint_path}")

    training_completed = True

except KeyboardInterrupt:
    logger.info(f"Training interrupted by user for {identifier}")
except Exception as e:
    # logger.exception keeps the traceback, which the message alone loses.
    logger.exception(f"Training error for {identifier}: {str(e)}")
finally:
    progress_bar.close()

    # Save the final model
    # The loss fields are None when no epoch finished before training stopped.
    try:
        model_output_dir = os.path.join(output_dir, 'models')
        os.makedirs(model_output_dir, exist_ok=True)
        model_save_path = os.path.join(model_output_dir, f"{identifier}_final.pth")
        torch.save({
            'identifier': identifier,
            'model_state_dict': model.state_dict(),
            'training_params': training_params,
            'final_loss': avg_loss,
            'final_recon_loss': avg_recon_loss,
            'final_vq_loss': avg_vq_loss
        }, model_save_path)
        logger.info(f"Final model saved to {model_save_path}")
    except Exception as e:
        logger.error(f"Error saving final model for {identifier}: {str(e)}")

# Only report completion when every epoch actually ran.
if training_completed:
    logger.info(f"Training completed for {identifier}")
else:
    logger.warning(f"Training did not run to completion for {identifier}")

INFO:__main__:Run identifier: vq-vae_256-batch_5000-samples_128-2048-vector_5-epochs_2024-11-13_22-59
INFO:__main__:Output directory: vae-output/vq-vae_256-batch_5000-samples_128-2048-vector_5-epochs_2024-11-13_22-59
INFO:__main__:Found 76605 valid image files
INFO:__main__:Sampled 5000 images for training
INFO:__main__:Successfully loaded 5000 valid images
INFO:__main__:Using device: cuda
INFO:__main__:AMP enabled: True
INFO:__main__:Training parameters saved to vae-output/vq-vae_256-batch_5000-samples_128-2048-vector_5-epochs_2024-11-13_22-59/vq-vae_256-batch_5000-samples_128-2048-vector_5-epochs_2024-11-13_22-59_training_params.json


Training vq-vae_256-batch_5000-samples_128-2048-vector_5-epochs_2024-11-13_22-59:   0%|          | 0/100 [00:0…

INFO:__main__:Final model saved to vae-output/vq-vae_256-batch_5000-samples_128-2048-vector_5-epochs_2024-11-13_22-59/models/vq-vae_256-batch_5000-samples_128-2048-vector_5-epochs_2024-11-13_22-59_final.pth
INFO:__main__:Training completed for vq-vae_256-batch_5000-samples_128-2048-vector_5-epochs_2024-11-13_22-59
