In [1]:
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
## Check MPS availability on Apple Silicon
import torch

# Report whether PyTorch's MPS backend can be used on this machine.
print(
    "MPS backend is available."
    if torch.backends.mps.is_available()
    else "MPS backend is not available."
)

In [None]:
# Root folder holding the generated images. It defaults to the original
# Google Drive location; set the MILLION_NEIGHBORHOODS_IMAGE_DIR environment
# variable to run the notebook on a machine where the data lives elsewhere.
IMAGE_ROOT = os.environ.get(
    "MILLION_NEIGHBORHOODS_IMAGE_DIR",
    "/Users/ls/Library/CloudStorage/GoogleDrive-l.schrage@northeastern.edu/Shared drives/Drawing Participation/Million Neighborhoods/Generated Images",
)

# The two images whose latent representations are interpolated further down.
image1_path = f"{IMAGE_ROOT}/ma-boston/parcels/parcels_50.jpg"
image2_path = f"{IMAGE_ROOT}/pa-pittsburgh/parcels/parcels_263.jpg"

In [None]:
# Pick the compute device: MPS when PyTorch reports it as available, CPU otherwise.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Fetch the pre-trained Stable Diffusion pipeline and move it onto that device.
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)

# Attention slicing is switched on to reduce memory usage.
pipe.enable_attention_slicing()

# Preprocessing applied to every input image:
#   1. resize to 256x256 (a reduced input size, chosen to save memory),
#   2. convert the PIL image to a tensor,
#   3. normalize with mean 0.5 and std 0.5.
resize_step = Resize((256, 256))
to_tensor_step = ToTensor()
normalize_step = Normalize([0.5], [0.5])
transform = Compose([resize_step, to_tensor_step, normalize_step])

# Function to load and transform images
def load_image(image_path):
    """Load an image file and return it as a batched tensor on ``device``.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The output of the module-level ``transform`` applied to the RGB
        version of the image, with a leading batch dimension added via
        ``unsqueeze(0)`` and moved to the module-level ``device``.
    """
    # Open inside a context manager so the underlying file handle is closed
    # promptly instead of lingering until garbage collection. convert()
    # returns a new, fully loaded image, so it stays valid after the block.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    return transform(rgb_image).unsqueeze(0).to(device)

# Load both input images as float tensors on the selected device
print("Loading and transforming images...")
image1 = load_image(image1_path).float()
image2 = load_image(image2_path).float()

# Clear the MPS cache to free up memory. This is only meaningful (and only
# safe to call) when the MPS backend is actually the device in use; the
# notebook explicitly supports a CPU fallback, where the call must be skipped.
if device.type == "mps":
    torch.mps.empty_cache()

# Run both images through the VAE encoder and draw one sample from each
# resulting latent distribution. Gradient tracking is disabled for the pass.
print("Encoding images to latent representations...")
with torch.no_grad():
    latents1, latents2 = [
        pipe.vae.encode(batch).latent_dist.sample()
        for batch in (image1, image2)
    ]

# Interpolate between latents
print("Interpolating between latents...")
num_steps = 10
alphas = np.linspace(0, 1, num_steps)  # blend weights from 0 (image 1) to 1 (image 2)
interpolated_images = []

# Decoding is inference only. Without no_grad() the VAE parameters make every
# decoded tensor require grad, which keeps an autograd graph alive for each
# step and makes a later .numpy() call on the result raise a RuntimeError.
with torch.no_grad():
    for idx, alpha in enumerate(alphas):
        print(f"Interpolating step {idx + 1}/{num_steps}...")
        # Use a plain Python float so the blend stays in the latents' dtype
        # rather than involving a NumPy float64 scalar.
        weight = float(alpha)
        interpolated_latent = (1 - weight) * latents1 + weight * latents2
        # decode(...).sample is batched; [0] takes the single image in the batch.
        decoded_image = pipe.vae.decode(interpolated_latent).sample[0]
        interpolated_images.append(decoded_image)

# Plot results
print("Plotting results...")
# squeeze=False keeps `axes` a 2-D array even when num_steps == 1, so the
# loop below works for any number of interpolation steps.
fig, axes = plt.subplots(1, num_steps, figsize=(20, 5), squeeze=False)
for ax, img in zip(axes[0], interpolated_images):
    # detach(): decoded tensors may still carry autograd history, and
    # .numpy() refuses tensors that require grad.
    # permute(1, 2, 0): reorder to the channels-last layout imshow expects.
    # * 0.5 + 0.5: undo the Normalize([0.5], [0.5]) applied to the inputs.
    # clamp(0, 1): the decoder output is not guaranteed to stay in range, and
    # imshow only accepts float RGB values within [0, 1].
    rgb = (img.detach().permute(1, 2, 0).cpu() * 0.5 + 0.5).clamp(0, 1)
    ax.imshow(rgb.numpy())
    ax.axis('off')
plt.show()