In [None]:
import torch

# Quick look inside the VQ-VAE checkpoint: print the parameter names it stores.
# Loaded onto the CPU so this cell works on machines without a GPU.
pth_path = '/home/ls/sites/re-blocking/ensemble-model/models/vq-vae_16-80000.pth'
state_dict = torch.load(pth_path, map_location='cpu')
print([*state_dict.keys()])

In [2]:
# VQ-VAE Image Interpolation (Weighted based on Classifier Predictions) for the Building-to-Parcel Workflow
# Leonard Schrage, l.schrage@northeastern.edu / lschrage@mit.edu, 2024-25

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import json
from pathlib import Path
import re
import warnings

###############################################################################
#                         CONFIGURABLE PARAMETERS                             #
###############################################################################

# ------------------ Model & Data Paths ------------------ #
# Every path hangs off a single project root. Set the REBLOCKING_ROOT environment
# variable to run the notebook from a different checkout without editing paths.
project_root = os.environ.get("REBLOCKING_ROOT", "/home/ls/sites/re-blocking")
results_root = os.path.join(project_root, "data", "results")
ensemble_root = os.path.join(project_root, "ensemble-model")

vqvae_model_path = os.path.join(ensemble_root, "models", "vq-vae_16-80000.pth")

# Per-city result folder names (one generator model per city); insertion order
# is preserved in model_outputs.
city_result_runs = {
    "Boston": "ma-boston-p2p-200-150-v100",
    "Charlotte": "nc-charlotte-200-150-v100",
    "Manhattan": "ny-manhattan-p2p-200-150-v100",
    "Pittsburgh": "pa-pittsburgh-p2p-500-150-v100",
}
model_outputs = {
    city: os.path.join(results_root, run_name, "test_latest", "images")
    for city, run_name in city_result_runs.items()
}
predictions_file = os.path.join(
    ensemble_root, "softmax-output", "city-predictions",
    "run_20250319_131844", "brooklyn", "predictions.json",
)
output_base_dir = os.path.join(ensemble_root, "ensemble-output", "vqvae")
weights_viz_dir = os.path.join(ensemble_root, "ensemble-output", "vqvae-weights-visuals")

# ------------------ Output Settings ------------------ #
output_image_size = (512, 512)  # (width, height) of the saved ensemble image
max_image_dimension = 1024      # cap on the longer input side before encoding
output_format = "jpg"           # file extension of the ensemble images
image_quality = 90  # JPEG quality

# ------------------ VQ-VAE Architecture Parameters ------------------ #
# These must match the architecture the checkpoint was trained with.
in_channels = 3
hidden_channels = 64
num_embeddings = 1024
embedding_dim = 64
commitment_cost = 0.25

###############################################################################
#                          SETUP & UTILITY FUNCTIONS                          #
###############################################################################

# Create both output folders up front; exist_ok makes this a no-op on re-runs
os.makedirs(output_base_dir, exist_ok=True)
os.makedirs(weights_viz_dir, exist_ok=True)

# Device preference: Apple MPS, then CUDA, otherwise fall back to the CPU
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

def slerp(val, low, high, eps=1e-8):
    """Spherical linear interpolation between two batches of tensors.

    Each tensor is flattened per batch item (dim 0). The angle between the two
    directions is measured on the normalised rows, and the raw (unnormalised)
    rows are combined with the usual slerp coefficients.

    Args:
        val: Interpolation factor; 0 returns ``low``, 1 returns ``high``.
        low: Start tensor with a leading batch dimension.
        high: End tensor with the same shape as ``low``.
        eps: Threshold below which a row norm or sin(omega) counts as zero.
            Such rows (parallel / anti-parallel directions, or an all-zero
            vector) fall back to plain linear interpolation.

    Returns:
        Tensor with the same shape as ``low``.
    """
    # reshape (not view) so non-contiguous inputs such as permuted tensors work too
    low_2d = low.reshape(low.shape[0], -1)
    high_2d = high.reshape(high.shape[0], -1)

    low_len = torch.norm(low_2d, dim=1, keepdim=True)
    high_len = torch.norm(high_2d, dim=1, keepdim=True)
    # clamp_min prevents 0/0 = NaN for all-zero rows; those rows are lerped below
    low_dir = low_2d / low_len.clamp_min(eps)
    high_dir = high_2d / high_len.clamp_min(eps)

    # Cosine of the angle, clamped so acos never sees values outside [-1, 1]
    cos_omega = torch.clamp((low_dir * high_dir).sum(dim=1), -1.0, 1.0)
    omega = torch.acos(cos_omega)
    sin_omega = torch.sin(omega)

    # Rows where the slerp formula is ill-conditioned or undefined
    degenerate = (sin_omega < eps) | (low_len.squeeze(1) < eps) | (high_len.squeeze(1) < eps)
    regular = ~degenerate
    result = torch.zeros_like(low_2d)

    if degenerate.any():
        result[degenerate] = (1.0 - val) * low_2d[degenerate] + val * high_2d[degenerate]

    if regular.any():
        omega_r = omega[regular].unsqueeze(1)
        sin_r = sin_omega[regular].unsqueeze(1)
        result[regular] = (torch.sin((1.0 - val) * omega_r) / sin_r) * low_2d[regular] + \
            (torch.sin(val * omega_r) / sin_r) * high_2d[regular]

    return result.reshape(low.shape)

def load_and_preprocess_image(image_path):
    """Load an image file and turn it into a normalised model-input tensor.

    The longer side is capped at ``max_image_dimension``. Both sides are
    rounded down to a multiple of 8 (never below 8) while roughly keeping the
    aspect ratio.

    Returns:
        (image_tensor, new_size, original_size):
        - image_tensor: (1, 3, H, W) float32 tensor on ``device``, with values
          mapped from [0, 1] to [-1, 1] by Normalize(0.5, 0.5).
        - new_size, original_size: PIL-style (width, height) tuples.
    """
    # Context manager closes the underlying file; convert() returns an independent copy
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    original_size = image.size
    original_width, original_height = original_size
    aspect_ratio = original_width / original_height

    def _round_down_to_8(value):
        # Multiple of 8, but at least 8 so very thin images never collapse to 0 px
        value = int(value)
        return max(8, value - value % 8)

    if aspect_ratio > 1:
        new_width = _round_down_to_8(min(original_width, max_image_dimension))
        new_height = _round_down_to_8(new_width / aspect_ratio)
    else:
        new_height = _round_down_to_8(min(original_height, max_image_dimension))
        new_width = _round_down_to_8(new_height * aspect_ratio)

    new_size = (new_width, new_height)
    transform = Compose([
        # torchvision's Resize takes (height, width); new_size is PIL-style (width, height)
        Resize((new_height, new_width)),
        ToTensor(),
        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device=device, dtype=torch.float32)
    return image_tensor, new_size, original_size

def re_quantize(latent, codebook):
    """Snap every spatial latent vector to its nearest codebook entry.

    Args:
        latent: Continuous latent of shape (B, C, H, W).
        codebook: Embedding matrix of shape (num_embeddings, C).

    Returns:
        Tensor of shape (B, C, H, W). At each spatial position it holds the
        codebook row with the smallest Euclidean distance to the input vector
        there.
    """
    B, C, H, W = latent.shape
    flat_latent = latent.reshape(B, C, H * W).permute(0, 2, 1).contiguous()  # (B, N, C)
    # torch.cdist (also used by VectorQuantizer.forward) needs O(B*N*K) memory.
    # Broadcasting an explicit difference tensor would need O(B*N*K*C).
    distances = torch.cdist(flat_latent, codebook.unsqueeze(0).expand(B, -1, -1))  # (B, N, K)
    encoding_indices = distances.argmin(dim=-1)      # (B, N)
    quantized_flat = codebook[encoding_indices]      # (B, N, C)
    return quantized_flat.permute(0, 2, 1).reshape(B, C, H, W)

def generate_weighted_image(model, image_latents, weights, output_path):
    """Blend encoder latents by weight, re-quantize, decode and save the image.

    The latents are combined with a running SLERP. After step i the blend
    contains latents 0..i, and latent i enters with factor
    w_i / (w_0 + ... + w_i). With linear interpolation this gives exactly
    sum(w_i * latent_i); with SLERP it is the matching incremental
    approximation.

    The blend is snapped to the model's codebook, decoded, resized to
    ``output_image_size`` and written to ``output_path``. PIL infers the
    format from the file extension; ``image_quality`` applies to JPEG.

    Args:
        model: VQ-VAE exposing ``vq_layer.embedding.weight`` and ``decoder``.
        image_latents: Non-empty list of continuous latents with identical shapes.
        weights: Non-negative weights, one per latent (need not sum to 1).
        output_path: Destination file path.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if there are no latents, the weight count does not match,
            or the weights do not have a positive sum.
    """
    if not image_latents or len(image_latents) != len(weights):
        raise ValueError(
            f"Need one weight per latent (got {len(image_latents)} latents, {len(weights)} weights)"
        )
    total_weight = float(sum(weights))
    if total_weight <= 0:
        raise ValueError("Weights must have a positive sum")
    normalized_weights = [float(w) / total_weight for w in weights]

    blended = image_latents[0]
    accumulated = normalized_weights[0]
    for latent, weight in zip(image_latents[1:], normalized_weights[1:]):
        if weight <= 0:
            continue  # contributes nothing; also avoids 0/0 while nothing has accumulated
        accumulated += weight
        # Use the factor relative to the mass blended so far. Passing w_i directly
        # would over-weight earlier latents: weights [0, .5, .5, 0] would give
        # latent 0 half the mass.
        blended = slerp(weight / accumulated, blended, latent)

    with torch.no_grad():
        # Re-quantize inside no_grad: the codebook is a trainable parameter
        quantized_latent = re_quantize(blended, model.vq_layer.embedding.weight)
        decoded_image = model.decoder(quantized_latent)
        decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        decoded_image = decoded_image.cpu().permute(0, 2, 3, 1).numpy()[0]  # first item, HWC

    decoded_image_pil = Image.fromarray((decoded_image * 255).astype(np.uint8))
    decoded_image_pil = decoded_image_pil.resize(output_image_size, Image.LANCZOS)
    # Infer the format from the extension so output_format="png" really writes PNG;
    # `quality` only matters for lossy formats such as JPEG.
    decoded_image_pil.save(output_path, quality=image_quality)

    return output_path

def extract_image_number(filename):
    """Return the image index embedded in ``filename``, zero-padded to at least 6 digits.

    The regexes are tried in priority order and the first one that matches
    wins. If none match, a message is printed and ``None`` is returned.
    """
    candidate_patterns = (
        r'parcels_(\d+)',
        r'_(\d+)_',
        r'(\d+)_fake_B',
        r'(\d+)\.jpg',
        r'(\d+)\.png',
    )
    # Lazy generator: a later pattern is only tried when every earlier one missed
    hits = (re.search(pattern, filename) for pattern in candidate_patterns)
    first_hit = next((hit for hit in hits if hit is not None), None)
    if first_hit is None:
        print(f"Could not extract number from {filename}")
        return None
    return first_hit.group(1).zfill(6)

def create_mapping(predictions):
    """Map each image number to its per-city classifier probabilities.

    Args:
        predictions: List of records, each with an ``"image_path"`` and a
            ``"probabilities"`` dict keyed by city name.

    Returns:
        Dict from the zero-padded image number (see ``extract_image_number``)
        to a dict with the Boston / Charlotte / Manhattan / Pittsburgh
        probabilities. Records whose filename yields no number are skipped.
        A later record overwrites an earlier one with the same number.
    """
    city_names = ("Boston", "Charlotte", "Manhattan", "Pittsburgh")

    # Echo a few filenames so a naming mismatch is easy to spot in the log
    if predictions:
        print("Example filenames in predictions:")
        for sample in predictions[:5]:
            print(f"  {Path(sample['image_path']).name}")

    number_to_probs = {}
    for entry in predictions:
        number = extract_image_number(Path(entry["image_path"]).name)
        if not number:
            continue
        probs = entry["probabilities"]
        number_to_probs[number] = {city: probs[city] for city in city_names}
    return number_to_probs

###############################################################################
#                           VQ-VAE MODEL DEFINITIONS                           #
###############################################################################

class Encoder(nn.Module):
    """Two strided 4x4 convolutions; each halves the spatial size (for even inputs).

    conv1 maps ``in_channels`` -> ``hidden_channels`` and is followed by ReLU.
    conv2 maps ``hidden_channels`` -> ``embedding_dim`` with no activation and
    yields the continuous latent. The attribute names form the state-dict keys.
    """

    def __init__(self, in_channels, hidden_channels, embedding_dim):
        super().__init__()
        downsample = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, **downsample)
        self.conv2 = nn.Conv2d(hidden_channels, embedding_dim, **downsample)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return self.conv2(hidden)

class Decoder(nn.Module):
    """Mirror of the encoder: two 4x4 transposed convolutions, each doubling H and W.

    conv1 maps ``embedding_dim`` -> ``hidden_channels`` followed by ReLU.
    conv2 maps ``hidden_channels`` -> ``out_channels`` followed by tanh, so the
    output lies in [-1, 1]. The attribute names form the state-dict keys.
    """

    def __init__(self, embedding_dim, hidden_channels, out_channels):
        super().__init__()
        upsample = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(embedding_dim, hidden_channels, **upsample)
        self.conv2 = nn.ConvTranspose2d(hidden_channels, out_channels, **upsample)

    def forward(self, x):
        features = F.relu(self.conv1(x))
        return torch.tanh(self.conv2(features))

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantisation with a learnable codebook.

    Args:
        num_embeddings: Number of codebook vectors (K).
        embedding_dim: Size of each codebook vector; must equal the input's
            channel count C.
        commitment_cost: Weight of the encoder commitment term in the VQ loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Small uniform init in [-1/K, 1/K]
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, x):
        """Quantise ``x`` of shape (B, C, H, W) to its nearest codebook vectors.

        Returns:
            quantized: (B, C, H, W). Its forward value is the selected codebook
                vectors; gradients pass straight through to ``x``.
            loss: codebook loss + commitment_cost * commitment loss.
            encoding_indices: flat (B*H*W,) indices in (b, h, w) row-major order.
        """
        B, C, H, W = x.shape
        flat_x = x.permute(0, 2, 3, 1).reshape(-1, self.embedding_dim)  # (B*H*W, C)
        distances = torch.cdist(flat_x, self.embedding.weight)
        encoding_indices = torch.argmin(distances, dim=1)
        # Codebook rows come back in (B, H, W, C) order, so permute back to
        # channels-first. Viewing them directly as x.shape would scramble
        # channels and spatial positions.
        # NOTE(review): if the checkpoint was trained with the old direct-view
        # layout, confirm against the training code. The generation path in this
        # notebook uses re_quantize + decoder and does not call this method.
        quantized = (
            self.embedding(encoding_indices)
            .view(B, H, W, C)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        quantized = x + (quantized - x).detach()  # Straight-through estimator
        return quantized, loss, encoding_indices

class VQVAE(nn.Module):
    """Encoder -> vector quantiser -> decoder autoencoder.

    The sub-module attribute names (``encoder``, ``decoder``, ``vq_layer``)
    become the state-dict key prefixes, so they must match the checkpoint.
    """

    def __init__(self, in_channels, hidden_channels, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.encoder = Encoder(in_channels, hidden_channels, embedding_dim)
        self.decoder = Decoder(embedding_dim, hidden_channels, in_channels)
        self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

    def encode(self, x):
        """Return the quantised latent of ``x``; the VQ loss is discarded."""
        quantized, _, _ = self.vq_layer(self.encoder(x))
        return quantized

    def decode(self, quantized):
        """Map a latent back to image space."""
        return self.decoder(quantized)

    def forward(self, x):
        """Full reconstruction pass returning ``(reconstruction, vq_loss)``."""
        quantized, vq_loss, _ = self.vq_layer(self.encoder(x))
        return self.decoder(quantized), vq_loss

def load_vqvae_model(model_path):
    """Build a VQVAE from the module-level architecture parameters and load its weights.

    Args:
        model_path: Path to a checkpoint containing the model's state dict,
            either directly or wrapped under a ``"state_dict"`` /
            ``"model_state_dict"`` key.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise (missing file,
        mismatched keys, ...). The error is printed first and then re-raised.
    """
    try:
        # A plain state dict needs no arbitrary unpickling; weights_only=True
        # refuses to execute code embedded in the checkpoint file.
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        # Accept checkpoints that wrap the weights alongside training metadata
        if isinstance(checkpoint, dict):
            for wrapper_key in ("state_dict", "model_state_dict"):
                if wrapper_key in checkpoint:
                    checkpoint = checkpoint[wrapper_key]
                    break
        model = VQVAE(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            commitment_cost=commitment_cost
        )
        model.load_state_dict(checkpoint)
        model = model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

###############################################################################
#                                 MAIN SCRIPT                                  #
###############################################################################

def _find_model_output(city_dir, filenames, image_num):
    """Return the path of the file in ``city_dir`` belonging to ``image_num``, or None.

    A filename matches when it contains ``_<image_num>_`` or ``_<image_num>.``.
    When several files match, one containing ``fake_B`` is preferred.
    NOTE(review): pix2pix-style result folders usually also hold ``real_A`` /
    ``real_B`` images for each index, so without this preference the input or
    ground truth could be picked instead of the generated image. Confirm
    against the actual folder contents.
    """
    candidates = [
        name for name in filenames
        if f"_{image_num}_" in name or f"_{image_num}." in name
    ]
    if not candidates:
        return None
    chosen = next((name for name in candidates if "fake_B" in name), candidates[0])
    return os.path.join(city_dir, chosen)


try:
    # Hide FutureWarnings (library deprecation notices) so the progress log stays readable
    warnings.filterwarnings("ignore", category=FutureWarning)

    with open(predictions_file, 'r') as f:
        predictions = json.load(f)
    print(f"Loaded {len(predictions)} image predictions")

    number_to_weights = create_mapping(predictions)
    print(f"Created mapping for {len(number_to_weights)} images")

    with tqdm(total=1, desc="Loading VQ-VAE model") as pbar:
        vqvae_model = load_vqvae_model(vqvae_model_path)
        pbar.update(1)

    cities = ["Boston", "Charlotte", "Manhattan", "Pittsburgh"]
    # List each city's output folder once, not once per image per city.
    # Sorted so the chosen file does not depend on filesystem order.
    # A missing folder now aborts the run up front.
    city_filenames = {city: sorted(os.listdir(model_outputs[city])) for city in cities}

    image_numbers = sorted(number_to_weights)
    print(f"Found {len(image_numbers)} image numbers to process")

    with tqdm(total=len(image_numbers), desc="Processing images") as pbar:
        for image_num in image_numbers:
            try:
                image_weights = number_to_weights[image_num]
                city_latents = {}

                for city in cities:
                    try:
                        city_file = _find_model_output(model_outputs[city], city_filenames[city], image_num)
                        if city_file is None:
                            print(f"Warning: Could not find model output for {city}, image {image_num}")
                            continue
                        image_tensor, _, _ = load_and_preprocess_image(city_file)
                        # Continuous latent straight from the encoder (before quantisation)
                        with torch.no_grad():
                            city_latents[city] = vqvae_model.encoder(image_tensor)
                    except Exception as e:
                        print(f"Error processing {city} model output for image {image_num}: {e}")

                if len(city_latents) < len(cities):
                    print(f"Missing some model outputs for image {image_num}, skipping")
                    continue  # the progress bar is advanced in `finally`

                latents = [city_latents[city] for city in cities]
                weights = [image_weights[city] for city in cities]

                output_filename = f"ensemble_vqvae_{image_num}.{output_format}"
                output_path = os.path.join(output_base_dir, output_filename)
                generate_weighted_image(vqvae_model, latents, weights, output_path)

                # Bar chart of the classifier weights; always closed so figures cannot pile up
                fig, ax = plt.subplots(figsize=(8, 5))
                try:
                    ax.bar(cities, weights, color='slateblue')
                    ax.set_ylim(0, 1.0)
                    ax.set_title(f"Classifier Weights for Image {image_num}")
                    ax.set_ylabel("Weight")
                    ax.tick_params(axis='x', labelrotation=45)
                    fig.tight_layout()
                    fig.savefig(os.path.join(weights_viz_dir, f"weights_{image_num}.png"))
                finally:
                    plt.close(fig)

            except Exception as e:
                # Best effort per image: log and move on to the next one
                print(f"Error processing image {image_num}: {e}")
            finally:
                pbar.update(1)

    print(f"All images processed. Results saved to {output_base_dir}")
    print(f"Weight visualizations saved to {weights_viz_dir}")

except Exception as e:
    print(f"An error occurred: {e}")
    # Re-raise so the notebook cell shows the traceback instead of looking successful
    raise

Using device: cuda
Loaded 1000 image predictions
Example filenames in predictions:
  parcels_000252.jpg
  parcels_000556.jpg
  parcels_000804.jpg
  parcels_000474.jpg
  parcels_000729.jpg
Created mapping for 1000 images


Loading VQ-VAE model: 100%|██████████| 1/1 [00:00<00:00, 70.27it/s]


Found 1000 image numbers to process


Processing images:  20%|█▉        | 197/1000 [00:30<01:39,  8.08it/s]

Missing some model outputs for image 000195, skipping


Processing images: 100%|██████████| 1000/1000 [02:35<00:00,  6.44it/s]

All images processed. Results saved to /home/ls/sites/re-blocking/ensemble-model/ensemble-output/vqvae
Weight visualizations saved to /home/ls/sites/re-blocking/ensemble-model/ensemble-output/vqvae-weights-visuals



