# In [None]:
# VQ-VAE Image Interpolation (Weighted based on Classifier Predictions) for the Building-to-Parcel Workflow
# Leonard Schrage, l.schrage@northeastern.edu / lschrage@mit.edu, 2024-25

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import json
from pathlib import Path
import re
import warnings

# Parameters with correct paths
# Checkpoint file handed to torch.load by load_vqvae_model.
vqvae_model_path = "/home/ls/sites/re-blocking/ensemble-model/models/vq-vae_16-80000.pth"
# Per-city directories that the main loop scans for "_<image number>_fake_B" files.
model_outputs = {
    "Boston": "/home/ls/sites/re-blocking/data/results/ma-boston-p2p-200-150-v100/test_latest/images",
    "Charlotte": "/home/ls/sites/re-blocking/data/results/nc-charlotte-200-150-v100/test_latest/images",
    "Manhattan": "/home/ls/sites/re-blocking/data/results/ny-manhattan-p2p-200-150-v100/test_latest/images",
    "Pittsburgh": "/home/ls/sites/re-blocking/data/results/pa-pittsburgh-p2p-500-150-v100/test_latest/images"
}
# JSON file of classifier predictions; each entry is read for "image_path" and per-city "probabilities".
predictions_file = "/home/ls/sites/re-blocking/ensemble-model/softmax-output/city-predictions/run_20250319_131844/brooklyn/predictions.json"
# Directory that receives the ensemble images and the "weights_viz" sub-directory.
output_base_dir = "/home/ls/sites/re-blocking/ensemble-model/ensemble-output/vqvae"
# Size passed to PIL's resize for every saved ensemble image.
output_image_size = (512, 512)
# Upper bound applied to the longer image side during preprocessing.
max_image_dimension = 1024
# File extension used when naming the ensemble output images.
output_format = "png"

# Make sure the destination folder is present before anything is written
os.makedirs(output_base_dir, exist_ok=True)

# Pick the compute backend: MPS first, then CUDA, otherwise fall back to CPU
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

def slerp(val, low, high):
    """Spherical linear interpolation between two batches of tensors.

    Every sample (index along the first dimension) is flattened to a vector
    and interpolated independently. Samples whose two vectors are parallel
    or antiparallel, where the spherical formula would divide by (almost)
    zero, fall back to plain linear interpolation.

    Args:
        val: Scalar interpolation factor; 0 yields ``low`` and 1 yields ``high``.
        low: Start tensor of shape ``(batch, ...)``.
        high: End tensor with the same shape as ``low``.

    Returns:
        Interpolated tensor with the same shape as ``low``.
    """
    batch = low.shape[0]
    # reshape (rather than view) so non-contiguous inputs are accepted too
    low_2d = low.reshape(batch, -1)
    high_2d = high.reshape(batch, -1)

    # Clamp the norms so an all-zero sample does not produce NaN directions
    low_2d_norm = low_2d / torch.norm(low_2d, dim=1, keepdim=True).clamp_min(1e-12)
    high_2d_norm = high_2d / torch.norm(high_2d, dim=1, keepdim=True).clamp_min(1e-12)

    omega = torch.acos((low_2d_norm * high_2d_norm).sum(1).clamp(-1, 1))
    so = torch.sin(omega)

    # Decide per sample (not via .item(), which only works for batch size 1)
    # whether the spherical formula is numerically usable.
    degenerate = so.abs() < 1e-6
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    spherical = (torch.sin((1.0 - val) * omega) / safe_so).unsqueeze(1) * low_2d + \
                (torch.sin(val * omega) / safe_so).unsqueeze(1) * high_2d
    linear = (1.0 - val) * low_2d + val * high_2d

    res = torch.where(degenerate.unsqueeze(1), linear, spherical)
    return res.reshape(low.shape)

def load_and_preprocess_image(image_path):
    """Load an image file and turn it into a normalized batch tensor.

    The image is converted to RGB, its longer side is capped at
    ``max_image_dimension``, both sides are rounded down to a multiple of 8
    (never below 8), and the pixel values are normalized with mean 0.5 and
    std 0.5 per channel.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A tuple ``(image_tensor, new_size, original_size)`` where
        ``image_tensor`` is a float32 tensor with a leading batch dimension
        of 1 on ``device``, and both sizes are ``(width, height)`` tuples.
    """
    # The context manager closes the file handle; convert() returns a loaded copy
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # PIL reports size as (width, height)
    original_size = image.size
    aspect_ratio = original_size[0] / original_size[1]

    # Determine the new size (ensuring it's divisible by 8)
    if aspect_ratio > 1:
        new_width = min(original_size[0], max_image_dimension)
        new_width -= new_width % 8
        new_height = int(new_width / aspect_ratio)
        new_height -= new_height % 8
    else:
        new_height = min(original_size[1], max_image_dimension)
        new_height -= new_height % 8
        new_width = int(new_height * aspect_ratio)
        new_width -= new_width % 8

    # Rounding down must not collapse a side to zero for very small/thin images
    new_width = max(new_width, 8)
    new_height = max(new_height, 8)
    new_size = (new_width, new_height)

    transform = Compose([
        # torchvision's Resize expects (height, width), unlike PIL's (width, height)
        Resize((new_height, new_width)),
        ToTensor(),
        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    image_tensor = transform(image).unsqueeze(0).to(device).to(torch.float32)

    return image_tensor, new_size, original_size

def generate_weighted_image(model, image_latents, weights, output_path):
    """Blend latents according to their weights, decode the result and save it.

    The latents are folded in one at a time with SLERP. The interpolation
    factor for latent ``i`` is ``weights[i]`` divided by the sum of all
    weights seen so far, so every weight (including the first) contributes
    and the outcome matches a weighted mean rather than depending on the
    order of the inputs.

    Args:
        model: Object exposing ``decode(latent)`` that returns an image
            tensor in the range [-1, 1].
        image_latents: Non-empty sequence of latent tensors of equal shape.
        weights: One non-negative weight per latent.
        output_path: Destination path of the saved image.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If no latents are given or the number of weights does
            not match the number of latents.
    """
    if len(image_latents) == 0:
        raise ValueError("image_latents must contain at least one latent")
    if len(weights) != len(image_latents):
        raise ValueError(
            f"Expected {len(image_latents)} weights, got {len(weights)}"
        )

    # Start with the first latent and the weight mass it carries
    weighted_latent = image_latents[0].clone()
    accumulated_weight = float(weights[0])

    for latent, weight in zip(image_latents[1:], weights[1:]):
        accumulated_weight += float(weight)
        if accumulated_weight <= 0:
            # No weight mass yet: nothing to blend towards
            continue
        # Running-normalised factor: share of the new latent in the total so far
        weighted_latent = slerp(weight / accumulated_weight, weighted_latent, latent)

    with torch.no_grad():
        decoded_image = model.decode(weighted_latent)
        # Map from [-1, 1] to [0, 1]
        decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)
        # Drop the batch dimension and move channels last for PIL
        decoded_image = decoded_image.cpu().permute(0, 2, 3, 1).numpy()[0]

    decoded_image_pil = Image.fromarray((decoded_image * 255).astype(np.uint8))
    # Resize the output image to the desired output size
    decoded_image_pil = decoded_image_pil.resize(output_image_size, Image.LANCZOS)

    decoded_image_pil.save(output_path)

    return output_path

def extract_image_number(filename):
    """Return the image number embedded in ``filename``, zero-padded to six digits.

    Several naming schemes are tried in a fixed order and the first one that
    matches wins. When none matches, a message is printed and ``None`` is
    returned.
    """
    candidate_patterns = (
        r'_(\d+)_',           # e.g. _000123_
        r'parcels_(\d+)',      # e.g. parcels_000123
        r'(\d+)\.jpg',         # e.g. 000123.jpg
        r'(\d+)\.png',         # e.g. 000123.png
    )

    # Lazily evaluate the searches so later patterns are only tried when needed
    searches = (re.search(expr, filename) for expr in candidate_patterns)
    found = next((hit for hit in searches if hit), None)

    if found is None:
        print(f"Could not extract number from {filename}")
        return None

    # Left-pad with zeros up to six characters
    return found.group(1).zfill(6)

def create_mapping(predictions):
    """Build a dict from image number to the per-city probabilities of that image.

    Each prediction entry is read for its "image_path" and "probabilities"
    keys. Entries whose filename yields no image number are left out.
    """
    # Show up to five sample filenames as a debugging aid
    if predictions:
        print("Example filenames in predictions:")
        for sample in predictions[:5]:
            print(f"  {Path(sample['image_path']).name}")

    city_names = ("Boston", "Charlotte", "Manhattan", "Pittsburgh")
    weights_by_number = {}

    for entry in predictions:
        number = extract_image_number(Path(entry["image_path"]).name)
        if not number:
            continue
        probabilities = entry["probabilities"]
        weights_by_number[number] = {name: probabilities[name] for name in city_names}

    return weights_by_number

# VQ-VAE Model definitions
class Encoder(nn.Module):
    """Two stride-2 convolutions mapping an image to an embedding_dim-channel feature map.

    With kernel 4, stride 2 and padding 1, each layer halves an even spatial size.
    """

    def __init__(self, in_channels, hidden_channels, embedding_dim):
        super().__init__()
        downsample = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, **downsample)
        self.conv2 = nn.Conv2d(hidden_channels, embedding_dim, **downsample)

    def forward(self, x):
        # ReLU only after the first layer; the output layer stays linear
        hidden = F.relu(self.conv1(x))
        return self.conv2(hidden)

class Decoder(nn.Module):
    """Two stride-2 transposed convolutions mapping a latent feature map back to an image.

    With kernel 4, stride 2 and padding 1, each layer doubles the spatial size.
    The output passes through tanh and therefore lies in [-1, 1].
    """

    def __init__(self, embedding_dim, hidden_channels, out_channels):
        super().__init__()
        upsample = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(embedding_dim, hidden_channels, **upsample)
        self.conv2 = nn.ConvTranspose2d(hidden_channels, out_channels, **upsample)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return torch.tanh(self.conv2(hidden))

class VectorQuantizer(nn.Module):
    """Codebook lookup that snaps each latent vector to its nearest embedding.

    ``forward`` returns the quantized tensor (with a straight-through
    gradient), the combined codebook/commitment loss, and the flat indices
    of the selected codebook entries.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Codebook entries start uniformly in [-1/num_embeddings, 1/num_embeddings]
        bound = 1 / self.num_embeddings
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, x):
        # x: [batch_size, embedding_dim, height, width]
        # Move channels last so every row is one embedding_dim-sized vector
        channels_last = x.permute(0, 2, 3, 1).contiguous()
        vectors = channels_last.view(-1, self.embedding_dim)

        # Index of the closest codebook entry for every vector
        nearest = torch.cdist(vectors, self.embedding.weight).argmin(dim=1)

        # NOTE(review): the looked-up rows are in channels-last order, yet they
        # are viewed straight back to x.shape without permuting the channel
        # axis back. This looks unintended, but the checkpoint was presumably
        # trained with exactly this layout, so it is kept as is - confirm
        # against the training code before changing it.
        codes = self.embedding(nearest).view(x.shape)

        # Codebook term pulls embeddings to the encoder output; commitment term does the reverse
        codebook_loss = F.mse_loss(codes, x.detach())
        commitment_loss = F.mse_loss(codes.detach(), x)
        loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: forward value is `codes`, gradient flows to x
        passthrough = x + (codes - x).detach()
        return passthrough, loss, nearest

class VQVAE(nn.Module):
    """Encoder, vector quantizer and decoder bundled into one VQ-VAE module."""

    def __init__(self, in_channels=3, hidden_channels=64, num_embeddings=1024, embedding_dim=64, commitment_cost=0.25):
        super().__init__()
        # Sub-module names and creation order are kept so saved state dicts still load
        self.encoder = Encoder(in_channels, hidden_channels, embedding_dim)
        self.decoder = Decoder(embedding_dim, hidden_channels, in_channels)
        self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

    def encode(self, x):
        """Return the quantized latent of ``x``; loss and indices are discarded."""
        quantized, _, _ = self.vq_layer(self.encoder(x))
        return quantized

    def decode(self, quantized):
        """Run the decoder on a latent, without re-quantizing it."""
        return self.decoder(quantized)

    def forward(self, x):
        """Return the reconstruction of ``x`` together with the quantization loss."""
        quantized, vq_loss, _ = self.vq_layer(self.encoder(x))
        return self.decoder(quantized), vq_loss

def load_vqvae_model(model_path):
    """Load the VQ-VAE weights from ``model_path`` and return the model in eval mode.

    The file is expected to contain a state dict that matches a ``VQVAE``
    built with the hyper-parameters below.

    Args:
        model_path: Path of the checkpoint file.

    Returns:
        The ``VQVAE`` instance, moved to ``device`` and switched to eval mode.

    Raises:
        Exception: Any error raised while reading the file or loading the
            weights is printed and re-raised unchanged.
    """
    try:
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot execute arbitrary code.
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model = VQVAE(
            in_channels=3,
            hidden_channels=64,
            num_embeddings=1024,
            embedding_dim=64,
            commitment_cost=0.25
        )
        model.load_state_dict(checkpoint)
        model = model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# Main execution
def _index_city_outputs(city_dir):
    """Map each image number to the first file in city_dir named "..._<number>_fake_B...".

    Listing the directory once here replaces a fresh os.listdir scan for
    every image and city.
    """
    index = {}
    for file in os.listdir(city_dir):
        for match in re.finditer(r'_(\d+)_fake_B', file):
            # setdefault keeps the first file encountered for a number
            index.setdefault(match.group(1), os.path.join(city_dir, file))
    return index


try:
    # Suppress warnings
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Load predictions
    with open(predictions_file, 'r', encoding="utf-8") as f:
        predictions = json.load(f)

    print(f"Loaded {len(predictions)} image predictions")

    # Create mapping from image numbers to weights
    number_to_weights = create_mapping(predictions)
    print(f"Created mapping for {len(number_to_weights)} images")

    # Load the pre-trained VQ-VAE model
    with tqdm(total=1, desc="Loading VQ-VAE model") as pbar:
        vqvae_model = load_vqvae_model(vqvae_model_path)
        pbar.update(1)

    # Get a list of all image numbers from the predictions mapping
    image_numbers = sorted(number_to_weights)
    print(f"Found {len(image_numbers)} image numbers to process")

    cities = ["Boston", "Charlotte", "Manhattan", "Pittsburgh"]

    # Scan every city's output directory once. An unreadable directory is
    # reported and treated as empty so the remaining work stays best-effort.
    city_files = {}
    for city in cities:
        try:
            city_files[city] = _index_city_outputs(model_outputs[city])
        except (OSError, KeyError) as e:
            print(f"Error listing model outputs for {city}: {e}")
            city_files[city] = {}

    # The visualization directory does not change between images
    weights_viz_dir = os.path.join(output_base_dir, "weights_viz")
    os.makedirs(weights_viz_dir, exist_ok=True)

    # Process each image number; tqdm advances the bar once per iteration,
    # including iterations that are skipped or fail.
    for image_num in tqdm(image_numbers, desc="Processing images"):
        try:
            # Get weights for this image
            image_weights = number_to_weights[image_num]

            # Encode the model output of every city to a latent
            city_latents = {}
            for city in cities:
                try:
                    city_file = city_files[city].get(image_num)
                    if city_file is None:
                        print(f"Warning: Could not find model output for {city}, image {image_num}")
                        continue

                    image_tensor, _, _ = load_and_preprocess_image(city_file)
                    with torch.no_grad():
                        city_latents[city] = vqvae_model.encode(image_tensor)

                except Exception as e:
                    print(f"Error processing {city} model output for image {image_num}: {e}")

            # Skip if we don't have all model outputs
            if len(city_latents) < len(cities):
                print(f"Missing some model outputs for image {image_num}, skipping")
                continue

            # Prepare latents and weights in consistent order
            latents = [city_latents[city] for city in cities]
            weights = [image_weights[city] for city in cities]

            # Generate the ensemble image
            output_filename = f"ensemble_vqvae_{image_num}.{output_format}"
            output_path = os.path.join(output_base_dir, output_filename)

            generate_weighted_image(vqvae_model, latents, weights, output_path)

            # Create visualization of weights
            fig = plt.figure(figsize=(8, 5))
            try:
                plt.bar(cities, weights, color='slateblue')
                plt.ylim(0, 1.0)
                plt.title(f"Classifier Weights for Image {image_num}")
                plt.ylabel("Weight")
                plt.xticks(rotation=45)
                plt.tight_layout()

                weights_path = os.path.join(weights_viz_dir, f"weights_{image_num}.png")
                plt.savefig(weights_path)
            finally:
                # Always release the figure, even when plotting or saving fails
                plt.close(fig)

        except Exception as e:
            print(f"Error processing image {image_num}: {e}")

    print(f"All images processed. Results saved to {output_base_dir}")

except Exception as e:
    print(f"An error occurred: {e}")