# In [None]:
# Stable Diffusion Image Interpolation (Weighted based on Classifier Predictions) for the Building-to-Parcel Workflow
# Leonard Schrage, l.schrage@northeastern.edu / lschrage@mit.edu, 2024-25

import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import json
from pathlib import Path
import re
import warnings

# Silence FutureWarning messages that originate from the transformers package
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Identifier of the Stable Diffusion checkpoint handed to `from_pretrained`
model_id = "runwayml/stable-diffusion-v1-5"
# One image directory per city model; each is searched for "*_fake_B*" files
model_outputs = {
    "Boston": "/home/ls/sites/re-blocking/data/results/ma-boston-p2p-200-150-v100/test_latest/images",
    "Charlotte": "/home/ls/sites/re-blocking/data/results/nc-charlotte-200-150-v100/test_latest/images",
    "Manhattan": "/home/ls/sites/re-blocking/data/results/ny-manhattan-p2p-200-150-v100/test_latest/images",
    "Pittsburgh": "/home/ls/sites/re-blocking/data/results/pa-pittsburgh-p2p-500-150-v100/test_latest/images"
}
# JSON file with the per-image classifier probabilities used as blend weights
predictions_file = "softmax-output/city-predictions/run_20250319_131844/brooklyn/predictions.json"
# Directory that receives the blended images
output_base_dir = "ensemble-output/stable-diffusion"
# NOTE(review): not referenced anywhere in the visible part of this script
inference_steps = 300
# Final size every decoded image is resized to before it is saved
output_image_size = (512, 512)
# Upper bound for the longer side of an input image before VAE encoding
max_image_dimension = 512
output_format = "jpg"  # Options: "png" or "jpg"
image_quality = 90  # For JPG (ignored for PNG)

# Create the output directory up front (no error if it is already there)
os.makedirs(output_base_dir, exist_ok=True)

# Pick the compute device: MPS first, then CUDA, otherwise fall back to CPU
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

def load_and_preprocess_image(image_path):
    """Load an image and convert it into a normalized float16 batch tensor.

    The image is converted to RGB and shrunk (never enlarged, except for the
    8-pixel minimum below) so that its longer side is at most
    ``max_image_dimension``. Both sides are rounded down to a multiple of 8,
    as required by the VAE, and are never allowed to drop below 8.

    Args:
        image_path: Path of the image file to load.

    Returns:
        A tuple ``(image_tensor, new_size, original_size)``:
        ``image_tensor`` is a ``(1, 3, H, W)`` float16 tensor on ``device``
        with values scaled from [0, 1] to [-1, 1]; ``new_size`` and
        ``original_size`` are ``(width, height)`` tuples (PIL convention).
    """
    def round_down_to_8(value):
        # Keep at least 8 px so a tiny or extremely elongated image cannot
        # end up with a zero-sized side.
        return max(8, value - (value % 8))

    # The context manager closes the file handle as soon as the RGB copy exists.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # PIL reports sizes as (width, height)
    original_size = image.size
    width, height = original_size
    aspect_ratio = width / height

    # Cap the longer side, then derive the other side from the aspect ratio
    if aspect_ratio > 1:
        new_width = round_down_to_8(min(width, max_image_dimension))
        new_height = round_down_to_8(int(new_width / aspect_ratio))
    else:
        new_height = round_down_to_8(min(height, max_image_dimension))
        new_width = round_down_to_8(int(new_height * aspect_ratio))

    new_size = (new_width, new_height)

    transform = Compose([
        # torchvision's Resize expects (height, width), the reverse of PIL's
        # (width, height); passing new_size directly would transpose the
        # target dimensions of every non-square image.
        Resize((new_height, new_width)),
        ToTensor(),
        Normalize([0.5], [0.5])
    ])

    # Add the batch dimension, move to the compute device, cast to half precision
    image_tensor = transform(image).unsqueeze(0).to(device).to(torch.float16)

    return image_tensor, new_size, original_size

def slerp(val, low, high, dot_threshold=0.9995):
    """Spherical linear interpolation between two batches of latents.

    Each batch element is flattened to a vector, interpolated along the arc
    between ``low`` and ``high``, and reshaped back. The arithmetic is done in
    float32 and the result is cast back to ``low``'s dtype, so half-precision
    inputs do not lose accuracy in the norm / acos / sin steps.

    Args:
        val: Interpolation fraction; 0 returns ``low``, 1 returns ``high``.
        low: Start tensor of shape ``(batch, ...)``.
        high: End tensor with the same shape as ``low``.
        dot_threshold: When the absolute cosine between the two flattened
            vectors exceeds this value they are treated as (anti-)parallel and
            plain linear interpolation is used, because ``sin(omega)`` is then
            (close to) zero and the slerp formula would divide by it.

    Returns:
        A tensor with the shape and dtype of ``low``.
    """
    flat_low = low.reshape(low.shape[0], -1).float()
    flat_high = high.reshape(high.shape[0], -1).float()

    # Clamp the norms so an all-zero latent cannot cause a division by zero.
    tiny = torch.finfo(torch.float32).tiny
    unit_low = flat_low / flat_low.norm(dim=1, keepdim=True).clamp_min(tiny)
    unit_high = flat_high / flat_high.norm(dim=1, keepdim=True).clamp_min(tiny)

    cos_omega = (unit_low * unit_high).sum(dim=1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    sin_omega = torch.sin(omega)

    # Per batch element: fall back to lerp where the arc is degenerate.
    near_parallel = cos_omega.abs() > dot_threshold
    safe_sin = torch.where(near_parallel, torch.ones_like(sin_omega), sin_omega)
    ones = torch.ones_like(omega)
    coeff_low = torch.where(
        near_parallel, (1.0 - val) * ones, torch.sin((1.0 - val) * omega) / safe_sin
    )
    coeff_high = torch.where(
        near_parallel, val * ones, torch.sin(val * omega) / safe_sin
    )

    res = coeff_low.unsqueeze(1) * flat_low + coeff_high.unsqueeze(1) * flat_high
    return res.reshape(low.shape).to(low.dtype)

def generate_weighted_image(pipe, image_latents, weights, output_path):
    """Blend several latents according to their weights, decode and save the result.

    The latents are folded together one at a time with ``slerp``. At step
    ``i`` the interpolation fraction is ``weights[i] / sum(weights[0..i])``,
    i.e. the share of the new latent in everything blended so far. With this
    running fraction every latent (including the first) ends up contributing
    in proportion to its weight; feeding ``weights[i]`` directly into the
    chain would ignore ``weights[0]`` and skew the shares.

    Args:
        pipe: Pipeline object whose ``vae.decode`` is used for decoding.
        image_latents: Non-empty sequence of latent tensors of identical
            shape, already multiplied by the 0.18215 latent scaling factor
            (it is divided out again here before decoding).
        weights: One non-negative weight per latent, in the same order.
            Weights that are zero or negative contribute nothing.
        output_path: Destination file path for the decoded image.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If ``image_latents`` is empty or fewer weights than
            latents are supplied.
    """
    if not image_latents:
        raise ValueError("image_latents must contain at least one latent")
    if len(weights) < len(image_latents):
        raise ValueError(
            f"Expected {len(image_latents)} weights, got {len(weights)}"
        )

    with torch.no_grad():
        # Start with the first latent and the weight mass it represents
        weighted_latent = image_latents[0].clone()
        cumulative_weight = max(float(weights[0]), 0.0)

        for latent, weight in zip(image_latents[1:], weights[1:]):
            weight = float(weight)
            if weight <= 0:
                # Nothing to mix in; also avoids 0/0 while no weight has accumulated.
                continue
            cumulative_weight += weight
            # Share of this latent in the blend of everything seen so far
            fraction = weight / cumulative_weight
            weighted_latent = slerp(fraction, weighted_latent, latent)

        # Undo the latent scaling applied at encoding time
        weighted_latent = 1 / 0.18215 * weighted_latent
        # Decode the blended latent back to image space
        decoded_image = pipe.vae.decode(weighted_latent).sample
        # Map from [-1, 1] to [0, 1]
        decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)
        # Channels-last numpy array of the first (only) batch element
        decoded_image = decoded_image.cpu().permute(0, 2, 3, 1).numpy()[0]

    # Convert to an 8-bit PIL image
    decoded_image_pil = Image.fromarray((decoded_image * 255).astype(np.uint8))
    # Resize the output image to the configured output size
    decoded_image_pil = decoded_image_pil.resize(output_image_size, Image.LANCZOS)

    # Save in the configured format (quality only applies to JPEG)
    if output_format == "jpg":
        decoded_image_pil.save(output_path, "JPEG", quality=image_quality)
    else:
        decoded_image_pil.save(output_path, "PNG")

    return output_path

def extract_image_number(filename):
    """Return the image number embedded in *filename*, left-padded to 6 digits.

    The filename is checked against a fixed list of patterns in priority
    order and the first one that matches wins. When none matches, a
    diagnostic line is printed and ``None`` is returned.
    """
    # Ordered from most to least specific; earlier entries take precedence
    candidate_regexes = (
        r'_(\d+)_',           # e.g. _000123_
        r'parcels_(\d+)',      # e.g. parcels_000123
        r'(\d+)\.jpg',         # e.g. 000123.jpg
        r'(\d+)\.png',         # e.g. 000123.png
    )

    # Lazily evaluated, so no further pattern is tried after the first hit
    attempts = (re.search(regex, filename) for regex in candidate_regexes)
    first_hit = next((hit for hit in attempts if hit), None)

    if first_hit is None:
        # Report the offending filename to ease debugging; the caller gets None
        print(f"Could not extract number from {filename}")
        return None

    # Pad with leading zeros to at least 6 characters
    return first_hit.group(1).zfill(6)

def create_mapping(predictions):
    """Build a lookup from image number to per-city classifier probabilities.

    Each prediction entry is expected to carry an ``"image_path"`` and a
    ``"probabilities"`` dict keyed by city name. Entries whose filename yields
    no image number are left out; a later entry with the same number replaces
    an earlier one.
    """
    city_names = ("Boston", "Charlotte", "Manhattan", "Pittsburgh")

    # Show the first few filenames as a debugging aid
    if predictions:
        print("Example filenames in predictions:")
        for sample in predictions[:5]:
            print(f"  {Path(sample['image_path']).name}")

    mapping = {}
    for entry in predictions:
        number = extract_image_number(Path(entry["image_path"]).name)
        if not number:
            continue

        probabilities = entry["probabilities"]
        mapping[number] = {city: probabilities[city] for city in city_names}

    return mapping

# Main execution
# Loads the classifier predictions and the Stable Diffusion pipeline, then for
# every "*_fake_B*" image number found in the first model directory blends the
# four city outputs in latent space and writes the result to output_base_dir.
try:
    # Load predictions
    with open(predictions_file, 'r') as f:
        predictions = json.load(f)

    print(f"Loaded {len(predictions)} image predictions")

    # Create mapping from image numbers to weights
    number_to_weights = create_mapping(predictions)
    print(f"Created mapping for {len(number_to_weights)} images")

    # Load the pre-trained Stable Diffusion model with DDIM scheduler
    with tqdm(total=1, desc="Loading Stable Diffusion model") as pbar:
        scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
        pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
        pipe.tokenizer.clean_up_tokenization_spaces = True
        pipe = pipe.to(device)
        if device.type == "cuda":
            pipe.enable_attention_slicing()
        pbar.update(1)

    # Get a list of all image numbers by checking one model directory
    sample_model_dir = list(model_outputs.values())[0]
    image_numbers = []

    for file in os.listdir(sample_model_dir):
        if "_fake_B" in file:
            image_num = extract_image_number(file)
            if image_num:
                image_numbers.append(image_num)

    image_numbers.sort()
    print(f"Found {len(image_numbers)} image numbers to process")

    cities_in_order = ["Boston", "Charlotte", "Manhattan", "Pittsburgh"]

    # Index every city's directory once instead of re-listing it for each
    # image. A file is registered under the digits of each "_<digits>_fake_B"
    # occurrence in its name, which is exactly what a substring test for
    # f"_{image_num}_fake_B" would accept. setdefault keeps the first file in
    # listing order when several files share a number.
    fake_b_pattern = re.compile(r'_(\d+)_fake_B')
    fake_images_by_city = {}
    for city in cities_in_order:
        model_dir = model_outputs[city]
        city_index = {}
        for file in os.listdir(model_dir):
            for match in fake_b_pattern.finditer(file):
                city_index.setdefault(match.group(1), os.path.join(model_dir, file))
        fake_images_by_city[city] = city_index

    # Process each image number; tqdm advances once per iteration, including
    # iterations that are skipped or fail.
    for image_num in tqdm(image_numbers, desc="Processing images"):
        try:
            # Skip if no weights found for this image
            if image_num not in number_to_weights:
                print(f"No weights found for image {image_num}, skipping")
                continue

            # Get weights for this image
            image_weights = number_to_weights[image_num]

            # Resolve all input files first so nothing is encoded for an
            # image that will be skipped anyway.
            fake_image_paths = {}
            for city in cities_in_order:
                fake_image_path = fake_images_by_city[city].get(image_num)
                if fake_image_path is None:
                    print(f"Warning: Could not find fake image for {city}, image number {image_num}")
                else:
                    fake_image_paths[city] = fake_image_path

            # Skip if we don't have all model outputs
            if len(fake_image_paths) < len(cities_in_order):
                print(f"Missing some model outputs for image {image_num}, skipping")
                continue

            # Load and encode all model outputs for this image
            image_latents = []
            weights_in_order = []
            for city in cities_in_order:
                image_tensor, _, _ = load_and_preprocess_image(fake_image_paths[city])

                # Sample a latent from the encoder output and apply the
                # 0.18215 scaling that generate_weighted_image divides out
                # again before decoding.
                with torch.no_grad():
                    latent = pipe.vae.encode(image_tensor).latent_dist.sample() * 0.18215
                image_latents.append(latent)
                weights_in_order.append(image_weights[city])

            # Generate the ensemble image
            output_filename = f"ensemble_stable-diffusion_{image_num}.{output_format}"
            output_path = os.path.join(output_base_dir, output_filename)

            generate_weighted_image(pipe, image_latents, weights_in_order, output_path)

        except Exception as e:
            # Best effort: report the failure and carry on with the next image
            print(f"Error processing image {image_num}: {str(e)}")

    print(f"All images processed. Results saved to {output_base_dir}")

except Exception as e:
    # Top-level boundary: setup failures (missing files, model download, ...)
    # are reported instead of raising out of the notebook cell.
    print(f"An error occurred: {e}")