In [2]:
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
import numpy as np
import os
from tqdm.auto import tqdm

# Parameters
model_id = "runwayml/stable-diffusion-v1-5"
image_paths = [
    "../data/results/ma-boston_200250_fake_B.png",
    "../data/results/nc-charlotte_200250_fake_B.png",
    "../data/results/ny-manhattan_200250_fake_B.png",
    "../data/results/pa-pittsburgh_200250_fake_B.png"
]
inference_steps = 15000  # increase for better image quality
identifier = f"stable-diffusion-weighted-interpolation_{inference_steps}-inference"  # Unique identifier for this run
output_dir = os.path.join("diffusion-output", identifier)
output_image_size = (512, 512)  # width and height of output image
max_image_dimension = 1024  # Maximum dimension for resizing images

# Example softmax weights (these should sum to 1)
weights = [0.25, 0.35, 0.15, 0.25]  # Replace with your actual weights

# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)

# Set device
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def load_and_preprocess_image(image_path):
    """Load and preprocess an image."""
    # Open the image and convert to RGB
    image = Image.open(image_path).convert("RGB")
    
    # Get the original size
    original_size = image.size
    
    # Calculate the aspect ratio
    aspect_ratio = original_size[0] / original_size[1]
    
    # Determine the new size (ensuring it's divisible by 8 for the VAE)
    if aspect_ratio > 1:
        new_width = min(original_size[0], max_image_dimension)
        new_width = new_width - (new_width % 8)
        new_height = int(new_width / aspect_ratio)
        new_height = new_height - (new_height % 8)
    else:
        new_height = min(original_size[1], max_image_dimension)
        new_height = new_height - (new_height % 8)
        new_width = int(new_height * aspect_ratio)
        new_width = new_width - (new_width % 8)
    
    new_size = (new_width, new_height)
    
    # Create the transform
    transform = Compose([
        Resize(new_size),
        ToTensor(),
        Normalize([0.5], [0.5])
    ])
    
    # Apply the transform
    image_tensor = transform(image).unsqueeze(0).to(device).to(torch.float16)  # Ensure consistent data type
    
    return image_tensor

def slerp(val, low, high):
    """Spherical linear interpolation for latents."""
    low_2d = low.view(low.shape[0], -1)
    high_2d = high.view(high.shape[0], -1)
    low_2d_norm = low_2d / torch.norm(low_2d, dim=1, keepdim=True)
    high_2d_norm = high_2d / torch.norm(high_2d, dim=1, keepdim=True)
    omega = torch.acos((low_2d_norm * high_2d_norm).sum(1).clamp(-1, 1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low_2d + \
          (torch.sin(val * omega) / so).unsqueeze(1) * high_2d
    return res.view(low.shape)

try:
    # Load the pre-trained Stable Diffusion model with DDIM scheduler
    with tqdm(total=1, desc="Loading Stable Diffusion model") as pbar:
        scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
        pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)  # Use mixed precision
        pipe = pipe.to(device)
        if device.type == "cuda":
            pipe.enable_attention_slicing()
        pbar.update(1)

    # Load and preprocess images
    images = []
    sizes = []
    original_sizes = []
    with tqdm(total=len(image_paths), desc="Loading and preprocessing images") as pbar:
        for image_path in image_paths:
            image = load_and_preprocess_image(image_path)
            images.append(image)
            pbar.update(1)

    print(f"Processed sizes: {sizes}")
    print(f"Original sizes: {original_sizes}")

    # Encode images to latents
    latents = []
    with tqdm(total=len(images), desc="Encoding images to latent representations") as pbar:
        with torch.no_grad():
            for image in images:
                latent = pipe.vae.encode(image).latent_dist.sample() * 0.18215
                latents.append(latent)
                pbar.update(1)

    print(f"Shapes of latents: {[latent.shape for latent in latents]}")

    # Perform SLERP interpolation between latents based on weights
    print("Performing weighted SLERP interpolation...")
    weighted_latent = latents[0]
    for i in range(1, len(latents)):
        weighted_latent = slerp(weights[i], weighted_latent, latents[i])

    with torch.no_grad():
        # Scale the latent
        weighted_latent = 1 / 0.18215 * weighted_latent
        # Decode the weighted latent
        decoded_image = pipe.vae.decode(weighted_latent).sample
        # Normalize the decoded image
        decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)
        # Convert to CPU and then to numpy array
        decoded_image = decoded_image.cpu().permute(0, 2, 3, 1).numpy()[0]

    # Convert to PIL Image
    decoded_image_pil = Image.fromarray((decoded_image * 255).astype(np.uint8))
    # Resize the output image to the desired output size
    decoded_image_pil = decoded_image_pil.resize(output_image_size, Image.LANCZOS)

    # Save the weighted interpolated image
    output_path = os.path.join(output_dir, "weighted_slerp_interpolated_image.png")
    decoded_image_pil.save(output_path, quality=95)

    print(f"Weighted SLERP interpolation image saved at: {output_path}")

except Exception as e:
    print(f"An error occurred: {e}")

Using device: mps


Loading Stable Diffusion model:   0%|          | 0/1 [00:00<?, ?it/s]

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
Fetching 14 files: 100%|██████████| 14/14 [00:00<00:00, 56.13it/s]
Loading pipeline components...: 100%|██████████| 7/7 [00:14<00:00,  2.00s/it]
Loading Stable Diffusion model: 100%|██████████| 1/1 [00:15<00:00, 15.82s/it]
Loading and preprocessing images: 100%|██████████| 4/4 [00:00<00:00, 10.92it/s]


Processed sizes: []
Original sizes: []


Encoding images to latent representations: 100%|██████████| 4/4 [00:03<00:00,  1.24it/s]


Shapes of latents: [torch.Size([1, 4, 32, 32]), torch.Size([1, 4, 32, 32]), torch.Size([1, 4, 32, 32]), torch.Size([1, 4, 32, 32])]
Performing weighted SLERP interpolation...
Weighted SLERP interpolation image saved at: diffusion-output/stable-diffusion-weighted-interpolation_15000-inference/weighted_slerp_interpolated_image.png
