### Train the VAE model

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import make_grid
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.notebook import tqdm
import json
import csv
import numpy as np
import logging
import sys
import glob
import re
import shutil
from datetime import datetime
from pathlib import Path
from sklearn.manifold import TSNE

def setup_logging(identifier):
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    log_dir = Path('vae-output') / identifier / 'logs'
    log_dir.mkdir(parents=True, exist_ok=True)
    file_handler = logging.FileHandler(log_dir / f'{identifier}.log')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger

# Parameters
total_samples = 40000
batch_size = 16
num_epochs = 100
learning_rate = 1e-3
commitment_cost = 0.25
hidden_channels = 64
embedding_dim = 32
num_embeddings = 1024
checkpoint_interval = 10
image_size = (512, 512)
normalize_mean = (0.5,)
normalize_std = (0.5,)

# Visualization parameters
visualization_interval = 1
reconstruction_interval = 50
max_saved_checkpoints = 3

current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
identifier = f"vq-vae_{batch_size}-batch_{total_samples}-samples_{embedding_dim}-{num_embeddings}-vector_{num_epochs}-epochs_{current_time}"

logger = setup_logging(identifier)
logger.info(f"Run identifier: {identifier}")

def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        torch.cuda.empty_cache()
        if hasattr(torch.cuda, 'memory.set_per_process_memory_fraction'):
            torch.cuda.memory.set_per_process_memory_fraction(0.95)
        logger.info(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
        logger.info(f"CUDA memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device("mps")
        torch.mps.empty_cache()
        logger.info("Using MPS (Apple Silicon) device")
    else:
        device = torch.device("cpu")
        torch.set_num_threads(os.cpu_count())
        logger.info(f"Using CPU with {os.cpu_count()} threads")
    return device

device = get_device()

# DataLoader configuration
num_workers = 0 if device.type == 'mps' else min(os.cpu_count(), 8)
pin_memory = device.type == 'cuda'

output_dir = os.path.join('vae-output', identifier)
dataset_dirs = [
    '../data/ma-boston/parcels',
    '../data/nc-charlotte/parcels', 
    '../data/ny-manhattan/parcels', 
    '../data/pa-pittsburgh/parcels'  
]

logger.info(f"Output directory: {output_dir}")
os.makedirs(output_dir, exist_ok=True)
vis_dir = os.path.join(output_dir, 'visualizations')
os.makedirs(vis_dir, exist_ok=True)

class LRSchedulerWrapper:
    def __init__(self, scheduler):
        self.scheduler = scheduler
    def step(self, metric=None):
        self.scheduler.step(metric)
        current_lr = self.scheduler.get_last_lr()[0]
        logger.info(f"Learning rate adjusted to: {current_lr:.6f}")
    def get_last_lr(self):
        return self.scheduler.get_last_lr()
    def state_dict(self):
        return self.scheduler.state_dict()

class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_embeddings, embedding_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, embedding_dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return x

class Decoder(nn.Module):
    def __init__(self, embedding_dim, hidden_channels, out_channels):
        super(Decoder, self).__init__()
        self.conv1 = nn.ConvTranspose2d(embedding_dim, hidden_channels, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(hidden_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = torch.tanh(self.conv2(x))
        return x

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, x):
        flattened = x.view(-1, self.embedding_dim)
        distances = torch.cdist(flattened, self.embedding.weight)
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices).view(x.size())
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        quantized = x + (quantized - x).detach()
        return quantized, loss, encoding_indices

class VQVAE(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_embeddings, embedding_dim, commitment_cost):
        super(VQVAE, self).__init__()
        self.encoder = Encoder(in_channels, hidden_channels, num_embeddings, embedding_dim)
        self.decoder = Decoder(embedding_dim, hidden_channels, in_channels)
        self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

    def forward(self, x):
        encoded = self.encoder(x)
        quantized, vq_loss, _ = self.vq_layer(encoded)
        decoded = self.decoder(quantized)
        return decoded, vq_loss

2024-11-14 11:14:50,073 - INFO - Run identifier: vq-vae_16-batch_60000-samples_32-1024-vector_100-epochs_2024-11-14_11-14
2024-11-14 11:14:50,203 - INFO - Using CUDA device: NVIDIA GeForce RTX 3070 Ti
2024-11-14 11:14:50,204 - INFO - CUDA memory allocated: 0.00 MB
2024-11-14 11:14:50,204 - INFO - Output directory: vae-output/vq-vae_16-batch_60000-samples_32-1024-vector_100-epochs_2024-11-14_11-14


In [2]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std)
])

def load_and_sample_images(dataset_dirs, total_samples):
    all_image_paths = []
    for dataset_dir in dataset_dirs:
        for root, _, files in os.walk(dataset_dir):
            for file in files:
                if not file.startswith('.') and not file.startswith('._') and file.endswith(('.png', '.jpg', '.jpeg')):
                    all_image_paths.append(os.path.join(root, file))
    
    logger.info(f"Found {len(all_image_paths)} valid image files")
    sampled_paths = random.sample(all_image_paths, total_samples)
    logger.info(f"Sampled {len(sampled_paths)} images for training")
    return sampled_paths

class SampledImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform
        logger.info(f"Initializing dataset with {len(image_paths)} images")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        try:
            image = Image.open(image_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, 0
        except Exception as e:
            logger.error(f"Error loading image {image_path}: {str(e)}")
            raise

def create_data_loaders(sampled_image_paths, batch_size, transform, device):
    num_train = int(0.9 * len(sampled_image_paths))
    train_paths = sampled_image_paths[:num_train]
    val_paths = sampled_image_paths[num_train:]
    
    logger.info(f"Split dataset into {len(train_paths)} training and {len(val_paths)} validation samples")
    
    train_dataset = SampledImageDataset(train_paths, transform)
    val_dataset = SampledImageDataset(val_paths, transform)
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True if num_workers > 0 else False,
        prefetch_factor=2 if num_workers > 0 else None,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True if num_workers > 0 else False,
        prefetch_factor=2 if num_workers > 0 else None,
        drop_last=True
    )
    
    return train_loader, val_loader

sampled_image_paths = load_and_sample_images(dataset_dirs, total_samples)
train_loader, val_loader = create_data_loaders(sampled_image_paths, batch_size, transform, device)

class TrainingMonitor:
    def __init__(self, output_dir, num_embeddings):
        self.output_dir = output_dir
        self.num_embeddings = num_embeddings
        self.metrics = {
            'train_reconstruction_loss': [],
            'train_vq_loss': [],
            'train_kl_div': [],
            'train_total_loss': [],
            'val_reconstruction_loss': [],
            'val_vq_loss': [],
            'val_kl_div': [],
            'val_total_loss': [],
            'perplexity': [],
            'encoding_usage': []
        }
        
        self.visualization_dir = os.path.join(output_dir, 'visualizations')
        self.reconstruction_dir = os.path.join(self.visualization_dir, 'reconstructions')
        self.metrics_dir = os.path.join(self.visualization_dir, 'metrics')
        self.embedding_dir = os.path.join(self.visualization_dir, 'embeddings')
        
        for directory in [self.visualization_dir, self.reconstruction_dir, 
                         self.metrics_dir, self.embedding_dir]:
            os.makedirs(directory, exist_ok=True)
        
        self.log_path = os.path.join(output_dir, 'detailed_training_log.csv')
        with open(self.log_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['epoch', 'batch', 'train_reconstruction_loss', 'train_vq_loss', 
                           'train_total_loss', 'train_kl_div', 'val_reconstruction_loss',
                           'val_vq_loss', 'val_total_loss', 'val_kl_div', 'perplexity',
                           'encoding_usage'])
        logger.info(f"Training monitor initialized. Logging to {self.log_path}")
    
    def calculate_kl_divergence(self, encoding_indices):
        encoding_hist = torch.histc(encoding_indices.float(), bins=self.num_embeddings, 
                                  min=0, max=self.num_embeddings-1)
        actual_dist = encoding_hist / encoding_hist.sum()
        uniform_dist = torch.ones_like(actual_dist) / self.num_embeddings
        actual_dist = actual_dist + 1e-10
        kl_div = torch.sum(actual_dist * torch.log(actual_dist / uniform_dist))
        return kl_div.item()
    
    def update_metrics(self, recon_loss, vq_loss, encoding_indices, is_training=True):
        total_loss = recon_loss + vq_loss
        kl_div = self.calculate_kl_divergence(encoding_indices)
        
        prefix = 'train_' if is_training else 'val_'
        self.metrics[f'{prefix}reconstruction_loss'].append(recon_loss)
        self.metrics[f'{prefix}vq_loss'].append(vq_loss)
        self.metrics[f'{prefix}total_loss'].append(total_loss)
        self.metrics[f'{prefix}kl_div'].append(kl_div)
        
        if is_training:
            encoding_hist = torch.histc(encoding_indices.float(), bins=self.num_embeddings, 
                                      min=0, max=self.num_embeddings-1)
            prob = encoding_hist / encoding_hist.sum()
            prob = prob[prob > 0]
            perplexity = torch.exp(-torch.sum(prob * torch.log(prob)))
            
            used_encodings = torch.unique(encoding_indices).size(0)
            encoding_usage = used_encodings / self.num_embeddings * 100
            
            self.metrics['perplexity'].append(perplexity.item())
            self.metrics['encoding_usage'].append(encoding_usage)
    
    def should_visualize(self, epoch):
        return epoch % visualization_interval == 0 or epoch == 0

    def should_save_reconstruction(self, batch_idx):
        return batch_idx % reconstruction_interval == 0

    def clean_old_checkpoints(self, output_dir, current_epoch):
        checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint_epoch_*.pth'))
        if len(checkpoints) > max_saved_checkpoints:
            checkpoints.sort(key=lambda x: int(re.search(r'epoch_(\d+)', x).group(1)))
            for checkpoint in checkpoints[:-max_saved_checkpoints]:
                try:
                    os.remove(checkpoint)
                    logger.info(f"Removed old checkpoint: {checkpoint}")
                except Exception as e:
                    logger.warning(f"Failed to remove checkpoint {checkpoint}: {str(e)}")

    def plot_training_curves(self, epoch):
        if not self.should_visualize(epoch):
            return

        plt.figure(figsize=(20, 15))
        
        plt.subplot(2, 2, 1)
        plt.plot(self.metrics['train_reconstruction_loss'], label='Train')
        plt.plot(self.metrics['val_reconstruction_loss'], label='Validation')
        plt.title('Reconstruction Loss')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.legend()
        
        plt.subplot(2, 2, 2)
        plt.plot(self.metrics['train_total_loss'], label='Train')
        plt.plot(self.metrics['val_total_loss'], label='Validation')
        plt.title('Total Loss')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.legend()
        
        plt.subplot(2, 2, 3)
        plt.plot(self.metrics['train_kl_div'], label='Train')
        plt.plot(self.metrics['val_kl_div'], label='Validation')
        plt.title('KL Divergence')
        plt.xlabel('Iteration')
        plt.ylabel('KL Divergence')
        plt.legend()
        
        plt.subplot(2, 2, 4)
        plt.plot(self.metrics['perplexity'], label='Perplexity')
        plt.plot(self.metrics['encoding_usage'], label='Codebook Usage %')
        plt.title('Codebook Metrics')
        plt.xlabel('Iteration')
        plt.ylabel('Value')
        plt.legend()
        
        plt.tight_layout()
        save_path = os.path.join(self.metrics_dir, f'training_metrics_epoch_{epoch}.png')
        plt.savefig(save_path)
        plt.close()
        logger.info(f"Saved training curves for epoch {epoch}")

    def visualize_reconstructions(self, original_images, reconstructed_images, epoch, batch_idx):
        if not self.should_save_reconstruction(batch_idx):
            return

        original_images = (original_images + 1) / 2
        reconstructed_images = (reconstructed_images + 1) / 2
        
        n = min(4, original_images.size(0))
        comparison = torch.cat([
            original_images[:n],
            reconstructed_images[:n]
        ])
        
        grid = make_grid(comparison, nrow=n)
        plt.figure(figsize=(15, 5))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title(f'Original (top) vs Reconstructed (bottom) - Epoch {epoch}')
        save_path = os.path.join(self.reconstruction_dir, f'reconstruction_e{epoch}.png')
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        plt.close()

    def visualize_embedding_space(self, model, epoch):
        if not self.should_visualize(epoch):
            return

        logger.info(f"Generating t-SNE visualization for epoch {epoch}")
        embeddings = model.vq_layer.embedding.weight.detach().cpu().numpy()
        
        tsne = TSNE(n_components=2, random_state=42)
        embeddings_2d = tsne.fit_transform(embeddings)
        
        plt.figure(figsize=(10, 10))
        scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], 
                            c=range(len(embeddings_2d)), cmap='viridis', alpha=0.6)
        plt.colorbar(scatter, label='Embedding Index')
        plt.title(f't-SNE Visualization of Embedding Space - Epoch {epoch}')
        save_path = os.path.join(self.embedding_dir, f'embedding_space_epoch_{epoch}.png')
        plt.savefig(save_path)
        plt.close()

    def save_final_visualizations(self, epoch):
        for dir_name, final_name in [
            (self.metrics_dir, 'final_metrics.png'),
            (self.embedding_dir, 'final_embedding_space.png'),
            (self.reconstruction_dir, 'final_reconstruction.png')
        ]:
            files = sorted(glob.glob(os.path.join(dir_name, '*')))
            if files:
                latest_file = files[-1]
                final_path = os.path.join(self.visualization_dir, final_name)
                shutil.copy2(latest_file, final_path)
                logger.info(f"Saved final visualization: {final_name}")
            
    def log_epoch_metrics(self, epoch, train_metrics, val_metrics):
        logger.info(f"Epoch {epoch} metrics:")
        logger.info(f"Train - Recon: {train_metrics['recon']:.4f}, "
                   f"VQ: {train_metrics['vq']:.4f}, "
                   f"Total: {train_metrics['total']:.4f}, "
                   f"KL: {train_metrics['kl']:.4f}")
        logger.info(f"Val   - Recon: {val_metrics['recon']:.4f}, "
                   f"VQ: {val_metrics['vq']:.4f}, "
                   f"Total: {val_metrics['total']:.4f}, "
                   f"KL: {val_metrics['kl']:.4f}")

2024-11-14 11:14:50,384 - INFO - Found 76605 valid image files
2024-11-14 11:14:50,408 - INFO - Sampled 60000 images for training
2024-11-14 11:14:50,410 - INFO - Split dataset into 54000 training and 6000 validation samples
2024-11-14 11:14:50,411 - INFO - Initializing dataset with 54000 images
2024-11-14 11:14:50,411 - INFO - Initializing dataset with 6000 images


In [3]:
def validate(model, dataloader, criterion, device, monitor):
    model.eval()
    val_recon_loss = 0
    val_vq_loss = 0
    val_total_loss = 0
    val_kl_div = 0
    num_batches = 0
    
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            
            reconstructed, vq_loss = model(images)
            recon_loss = criterion(reconstructed, images)
            total_loss = recon_loss + vq_loss
            
            if isinstance(model, torch.nn.DataParallel):
                encoded = model.module.encoder(images)
                _, _, encoding_indices = model.module.vq_layer(encoded)
            else:
                encoded = model.encoder(images)
                _, _, encoding_indices = model.vq_layer(encoded)
            
            monitor.update_metrics(recon_loss.item(), vq_loss.item(), encoding_indices, is_training=False)
            
            val_recon_loss += recon_loss.item()
            val_vq_loss += vq_loss.item()
            val_total_loss += total_loss.item()
            val_kl_div += monitor.calculate_kl_divergence(encoding_indices)
            num_batches += 1
    
    val_metrics = {
        'recon': val_recon_loss / num_batches,
        'vq': val_vq_loss / num_batches,
        'total': val_total_loss / num_batches,
        'kl': val_kl_div / num_batches
    }
    
    return val_metrics

def train_epoch(epoch, model, train_loader, val_loader, optimizer, criterion, device, monitor, scaler=None):
    model.train()
    train_recon_loss = 0
    train_vq_loss = 0
    train_total_loss = 0
    train_kl_div = 0
    num_batches = 0
    
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    elif device.type == 'mps':
        torch.mps.empty_cache()
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    progress_bar.set_description(f"Epoch [{epoch+1}/{num_epochs}] - {identifier}")
    
    for batch_idx, (images, _) in progress_bar:
        images = images.to(device)
        optimizer.zero_grad(set_to_none=True)
        
        if device.type == 'cuda':
            with torch.amp.autocast(device_type='cuda'):
                reconstructed, vq_loss = model(images)
                recon_loss = criterion(reconstructed, images)
                total_loss = recon_loss + vq_loss
                
                encoded = model.module.encoder(images) if isinstance(model, torch.nn.DataParallel) else model.encoder(images)
                _, _, encoding_indices = model.module.vq_layer(encoded) if isinstance(model, torch.nn.DataParallel) else model.vq_layer(encoded)
            
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            reconstructed, vq_loss = model(images)
            recon_loss = criterion(reconstructed, images)
            total_loss = recon_loss + vq_loss
            
            encoded = model.encoder(images)
            _, _, encoding_indices = model.vq_layer(encoded)
            
            total_loss.backward()
            optimizer.step()
        
        monitor.update_metrics(recon_loss.item(), vq_loss.item(), encoding_indices, is_training=True)
        
        train_recon_loss += recon_loss.item()
        train_vq_loss += vq_loss.item()
        train_total_loss += total_loss.item()
        train_kl_div += monitor.calculate_kl_divergence(encoding_indices)
        num_batches += 1
        
        progress_bar.set_postfix({
            'Loss': f"{total_loss.item():.4f}",
            'Recon': f"{recon_loss.item():.4f}",
            'VQ': f"{vq_loss.item():.4f}"
        })
        
        if monitor.should_save_reconstruction(batch_idx):
            monitor.visualize_reconstructions(images, reconstructed, epoch + 1, batch_idx)
    
    train_metrics = {
        'recon': train_recon_loss / num_batches,
        'vq': train_vq_loss / num_batches,
        'total': train_total_loss / num_batches,
        'kl': train_kl_div / num_batches
    }
    
    val_metrics = validate(model, val_loader, criterion, device, monitor)
    monitor.log_epoch_metrics(epoch + 1, train_metrics, val_metrics)
    
    return train_metrics, val_metrics

def slerp(p0, p1, t):
    if t == 0:
        return p0
    if t == 1:
        return p1
        
    dot = torch.sum(p0 * p1) / (torch.norm(p0) * torch.norm(p1))
    dot = torch.clamp(dot, -1.0, 1.0)
    
    omega = torch.acos(dot)
    
    if omega == 0:
        return (1.0 - t) * p0 + t * p1
        
    so = torch.sin(omega)
    return torch.sin((1.0 - t) * omega) / so * p0 + torch.sin(t * omega) / so * p1

def create_slerp_interpolation(model, output_path, transform, device):
    logger.info("Starting SLERP interpolation generation")
    logger.info(f"Output path: {output_path}")
    
    image_paths = [
        "../data/results/ma-boston_200250_fake_B.png",
        "../data/results/nc-charlotte_200250_fake_B.png",
        "../data/results/ny-manhattan_200250_fake_B.png",
        "../data/results/pa-pittsburgh_200250_fake_B.png"
    ]
    weights = [0.25, 0.35, 0.15, 0.25]
    
    logger.info("City weights:")
    for city, weight in zip(['Boston', 'Charlotte', 'Manhattan', 'Pittsburgh'], weights):
        logger.info(f"  {city}: {weight:.2f}")
    
    assert abs(sum(weights) - 1.0) < 1e-6, "Weights must sum to 1"
    
    images = []
    logger.info("Loading input images:")
    for path in image_paths:
        logger.info(f"  Loading {os.path.basename(path)}")
        try:
            img = Image.open(path).convert('RGB')
            img_tensor = transform(img)
            images.append(img_tensor)
        except Exception as e:
            logger.error(f"Error loading image {path}: {str(e)}")
            raise
    
    model.eval()
    with torch.no_grad():
        latents = []
        for img in images:
            img = img.unsqueeze(0).to(device)
            if isinstance(model, torch.nn.DataParallel):
                encoded = model.module.encoder(img)
                quantized, _, _ = model.module.vq_layer(encoded)
            else:
                encoded = model.encoder(img)
                quantized, _, _ = model.vq_layer(encoded)
            latents.append(quantized.view(1, -1))
        
        weight_pairs = [(0, 1), (2, 3)]
        pair_weights = [weights[0] + weights[1], weights[2] + weights[3]]
        
        pair_results = []
        for (idx1, idx2), pair_weight in zip(weight_pairs, pair_weights):
            w1 = weights[idx1] / pair_weight
            w2 = weights[idx2] / pair_weight
            pair_result = slerp(latents[idx1], latents[idx2], w2)
            pair_results.append((pair_result, pair_weight))
        
        total_weight = sum(w for _, w in pair_results)
        relative_weight = pair_results[1][1] / total_weight
        final_latent = slerp(pair_results[0][0], pair_results[1][0], relative_weight)
        
        original_shape = quantized.shape
        final_latent = final_latent.view(original_shape)
        if isinstance(model, torch.nn.DataParallel):
            final_image = model.module.decoder(final_latent)
        else:
            final_image = model.decoder(final_latent)
    
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    img = final_image[0].cpu().numpy().transpose(1, 2, 0)
    img = (img + 1) / 2
    img = np.clip(img, 0, 1)
    
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=300)
    plt.close()
    logger.info(f"Saved interpolated image to {output_path}")
    
    orig_dir = os.path.join(os.path.dirname(output_path), 'originals_slerp')
    os.makedirs(orig_dir, exist_ok=True)
    
    cities = ['boston', 'charlotte', 'manhattan', 'pittsburgh']
    for img_tensor, city, weight in zip(images, cities, weights):
        img = img_tensor.numpy().transpose(1, 2, 0)
        img = (img + 1) / 2
        img = np.clip(img, 0, 1)
        
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        plt.title(f'{city} (weight: {weight:.2f})')
        plt.axis('off')
        save_path = os.path.join(orig_dir, f'original_{city}.png')
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=300)
        plt.close()
        logger.info(f"Saved original image for {city} to {save_path}")

    logger.info("SLERP interpolation generation completed")

In [4]:
# Initialize model, optimizer, criterion and device-specific optimizations
model = VQVAE(in_channels=3, hidden_channels=hidden_channels, num_embeddings=num_embeddings,
              embedding_dim=embedding_dim, commitment_cost=commitment_cost).to(device)

if device.type == 'cuda':
    model = torch.nn.DataParallel(model) if torch.cuda.device_count() > 1 else model
    scaler = torch.amp.GradScaler()
    model = model.cuda()
    logger.info(f"Model using CUDA with {torch.cuda.device_count()} GPU(s)")
    logger.info("AMP (Automatic Mixed Precision) enabled")
elif device.type == 'mps':
    scaler = None
    logger.info("Model using MPS (Apple Silicon)")
else:
    scaler = None
    logger.info("Model using CPU")

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss().to(device)

base_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.5, 
    patience=5
)

scheduler = LRSchedulerWrapper(base_scheduler)
monitor = TrainingMonitor(output_dir, num_embeddings)

training_params = {
    "identifier": identifier,
    "total_samples": total_samples,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "learning_rate": learning_rate,
    "commitment_cost": commitment_cost,
    "hidden_channels": hidden_channels,
    "embedding_dim": embedding_dim,
    "num_embeddings": num_embeddings,
    "checkpoint_interval": checkpoint_interval,
    "image_size": image_size,
    "device": device.type,
    "num_workers": num_workers,
    "pin_memory": pin_memory
}

params_path = os.path.join(output_dir, f'{identifier}_training_params.json')
with open(params_path, 'w') as f:
    json.dump(training_params, f, indent=4)
logger.info(f"Training parameters saved to {params_path}")

logger.info("Starting training...")
logger.info(f"Training parameters: {training_params}")

try:
    for epoch in range(num_epochs):
        train_metrics, val_metrics = train_epoch(
            epoch, model, train_loader, val_loader, optimizer, criterion, 
            device, monitor, scaler if device.type == 'cuda' else None
        )
        
        scheduler.step(val_metrics['total'])
        
        if device.type == 'cuda':
            memory_allocated = torch.cuda.memory_allocated() / 1024**2
            memory_cached = torch.cuda.memory_reserved() / 1024**2
            logger.info(f"GPU memory allocated: {memory_allocated:.2f} MB")
            logger.info(f"GPU memory cached: {memory_cached:.2f} MB")
            torch.cuda.empty_cache()
        elif device.type == 'mps':
            torch.mps.empty_cache()
        
        if monitor.should_visualize(epoch):
            monitor.plot_training_curves(epoch + 1)
            monitor.visualize_embedding_space(model, epoch + 1)
        
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(output_dir, f'checkpoint_epoch_{epoch + 1}.pth')
            checkpoint_data = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.scheduler.state_dict(),
                'train_metrics': train_metrics,
                'val_metrics': val_metrics,
                'scaler': scaler.state_dict() if device.type == 'cuda' else None
            }
            torch.save(checkpoint_data, checkpoint_path)
            monitor.clean_old_checkpoints(output_dir, epoch + 1)

except Exception as e:
    logger.error(f"Training interrupted: {str(e)}")
    raise

finally:
    monitor.save_final_visualizations(num_epochs)
    final_model_path = os.path.join(output_dir, f'final_model_{identifier}.pth')
    
    try:
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.scheduler.state_dict(),
            'final_train_metrics': train_metrics if 'train_metrics' in locals() else None,
            'final_val_metrics': val_metrics if 'val_metrics' in locals() else None,
            'training_params': training_params,
            'scaler': scaler.state_dict() if device.type == 'cuda' else None
        }, final_model_path)
        logger.info(f"Saved final model to {final_model_path}")
    except Exception as e:
        logger.error(f"Error saving final model: {str(e)}")

try:
    logger.info("Generating SLERP interpolation...")
    output_path = os.path.join(output_dir, 'slerp_interpolation.png')
    create_slerp_interpolation(model, output_path, transform, device)
except Exception as e:
    logger.error(f"Error generating SLERP interpolation: {str(e)}")

logger.info("\nTraining Summary:")
logger.info(f"Total epochs completed: {num_epochs}")
logger.info(f"Final learning rate: {optimizer.param_groups[0]['lr']:.6f}")
logger.info(f"Output directory: {output_dir}")
if device.type == 'cuda':
    logger.info(f"Final GPU memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
logger.info("Training completed successfully")

2024-11-14 11:14:50,602 - INFO - Model using CUDA with 1 GPU(s)
2024-11-14 11:14:50,603 - INFO - AMP (Automatic Mixed Precision) enabled
2024-11-14 11:14:50,604 - INFO - Training monitor initialized. Logging to vae-output/vq-vae_16-batch_60000-samples_32-1024-vector_100-epochs_2024-11-14_11-14/detailed_training_log.csv
2024-11-14 11:14:50,605 - INFO - Training parameters saved to vae-output/vq-vae_16-batch_60000-samples_32-1024-vector_100-epochs_2024-11-14_11-14/vq-vae_16-batch_60000-samples_32-1024-vector_100-epochs_2024-11-14_11-14_training_params.json
2024-11-14 11:14:50,606 - INFO - Starting training...
2024-11-14 11:14:50,606 - INFO - Training parameters: {'identifier': 'vq-vae_16-batch_60000-samples_32-1024-vector_100-epochs_2024-11-14_11-14', 'total_samples': 60000, 'batch_size': 16, 'num_epochs': 100, 'learning_rate': 0.001, 'commitment_cost': 0.25, 'hidden_channels': 64, 'embedding_dim': 32, 'num_embeddings': 1024, 'checkpoint_interval': 10, 'image_size': (512, 512), 'device':

  0%|          | 0/3375 [00:00<?, ?it/s]

2024-11-14 11:16:47,753 - ERROR - Error loading image ../data/nc-charlotte/parcels/parcels_19487.jpg: image file is truncated (58 bytes not processed)
2024-11-14 11:16:48,678 - ERROR - Training interrupted: Caught OSError in DataLoader worker process 2.
Original Traceback (most recent call last):
  File "/home/ls/sites/re-blocking/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ls/sites/re-blocking/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_792347/1501792018.py", line 33, in __getitem__
    image = Image.open(image_path).convert('RGB')
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ls/sites/re-blocking/.venv/lib/python3.12/site-packages/PIL/Image.py

OSError: Caught OSError in DataLoader worker process 2.
Original Traceback (most recent call last):
  File "/home/ls/sites/re-blocking/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ls/sites/re-blocking/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_792347/1501792018.py", line 33, in __getitem__
    image = Image.open(image_path).convert('RGB')
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ls/sites/re-blocking/.venv/lib/python3.12/site-packages/PIL/Image.py", line 995, in convert
    self.load()
  File "/home/ls/sites/re-blocking/.venv/lib/python3.12/site-packages/PIL/ImageFile.py", line 290, in load
    raise OSError(msg)
OSError: image file is truncated (58 bytes not processed)
