In [None]:
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.notebook import tqdm
from transformers import CLIPImageProcessor
import json
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW

# Custom Dataset Class to Load Data from Multiple Folders
class ImageDataset(Dataset):
    """Dataset of RGB images collected from one or more folders.

    Only the top level of each folder is scanned (no recursion). Files are
    selected by extension, compared case-insensitively, and are stored in
    sorted order within each folder so that an index always maps to the same
    file regardless of the order the filesystem reports entries in.

    Args:
        image_folders: Iterable of directory paths to scan.
        transform: Optional callable applied to each loaded PIL image in
            ``__getitem__``; when ``None`` the PIL image is returned as-is.

    Raises:
        OSError: Propagated from ``os.listdir`` if a folder cannot be read.
    """

    # Extensions accepted as images; matched against the lower-cased filename
    # so that e.g. "photo.JPG" is not silently skipped.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, image_folders, transform=None):
        self.image_paths = []
        self.transform = transform
        for folder in image_folders:
            # os.listdir returns entries in arbitrary order; sort for
            # reproducible indexing across runs and machines.
            self.image_paths.extend(
                os.path.join(folder, filename)
                for filename in sorted(os.listdir(folder))
                if filename.lower().endswith(self.IMAGE_EXTENSIONS)
            )

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` as RGB and apply the transform."""
        image_path = self.image_paths[idx]
        # The context manager closes the underlying file handle promptly;
        # convert() returns an independent in-memory image, so it remains
        # usable after the source file is closed.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

# Parameters
# -- Model and data sources
model_id = "runwayml/stable-diffusion-v1-5"
image_folders = [
    "../data/ma-boston/parcel",
    "../data/nc-charlotte/parcel",
    "../data/ny-manhattan/parcel",
    "../data/pa-pittsburgh/parcel",
]

# -- Training hyperparameters
num_epochs = 3  # fine-tuning for 3 epochs
batch_size = 4
learning_rate = 5e-6
num_steps = 50  # increase for smoother interpolation

# -- Image geometry
max_image_dimension = 1024  # Maximum dimension for resizing images
output_image_size = 512, 512  # width and height of output image

# Directory the fine-tuned model is written to; created up front so that a
# bad path fails before any training time is spent.
output_dir = "fine-tuned-output"
os.makedirs(output_dir, exist_ok=True)

# Image preprocessing: resize to a square of side max_image_dimension,
# convert to a tensor, then normalize with mean 0.5 and std 0.5.
transform = Compose(
    [
        Resize(size=(max_image_dimension, max_image_dimension)),
        ToTensor(),
        Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Dataset over every configured folder, served in shuffled mini-batches.
dataset = ImageDataset(image_folders=image_folders, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Set device: first available backend wins, in the order MPS -> CUDA -> CPU.
# Each entry is (availability check, device name, message to print); the
# checks are called lazily so CUDA is only probed when MPS is unavailable.
_DEVICE_CANDIDATES = (
    (torch.backends.mps.is_available, "mps", "Using Apple MPS for training."),
    (torch.cuda.is_available, "cuda", "Using CUDA for training."),
)

# CPU is the fallback when no accelerator reports itself as available.
device = torch.device("cpu")
_device_message = "Using CPU for training."
for _is_available, _device_name, _message in _DEVICE_CANDIDATES:
    if _is_available():
        device = torch.device(_device_name)
        _device_message = _message
        break
print(_device_message)

# Load pre-trained Stable Diffusion model.
# Weights are loaded in float32: the UNet parameters are updated directly by
# the optimizer below (no gradient scaler / master weights), which is
# numerically unstable in float16, and many float16 ops are unavailable on CPU.
with tqdm(total=1, desc="Loading Stable Diffusion model") as pbar:
    scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, scheduler=scheduler, torch_dtype=torch.float32
    )
    pipe = pipe.to(device)
    if device.type == "cuda":
        pipe.enable_attention_slicing()  # Optimize memory on CUDA
    pbar.update(1)

# Only the UNet is fine-tuned; the VAE and text encoder stay frozen.
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.vae.eval()
pipe.text_encoder.eval()
pipe.unet.train()

# Define optimizer for fine-tuning.
# torch.optim.AdamW is used instead of the deprecated transformers.AdamW;
# apart from lr, torch's defaults apply (including weight_decay=0.01).
optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=learning_rate)

# The dataset yields images only (no captions), so the UNet is conditioned on
# the embedding of the empty prompt. It is constant, so compute it once.
with torch.no_grad():
    empty_prompt_ids = pipe.tokenizer(
        "",
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)
    empty_prompt_embeds = pipe.text_encoder(empty_prompt_ids)[0]

# Latent scaling factor from the VAE config (falls back to the value that was
# previously hard-coded here if the config does not provide one).
latent_scale = getattr(pipe.vae.config, "scaling_factor", 0.18215)
num_train_timesteps = scheduler.config.num_train_timesteps

# Fine-tuning loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    last_loss = None  # stays None if the dataloader yields no batches
    for images in tqdm(dataloader, desc="Fine-tuning"):
        images = images.to(device=device, dtype=pipe.vae.dtype)
        batch = images.shape[0]

        # Encode the images to latents; the VAE is frozen, so no graph needed.
        with torch.no_grad():
            latents = pipe.vae.encode(images).latent_dist.sample() * latent_scale

        # Diffuse each sample to a random integer timestep. add_noise indexes
        # the scheduler's alpha table with these, so they must be long ints.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, num_train_timesteps, (batch,), device=latents.device, dtype=torch.long
        )
        latents_noisy = scheduler.add_noise(latents, noise, timesteps)

        # Regression target depends on what the model is trained to predict.
        if scheduler.config.prediction_type == "v_prediction":
            target = scheduler.get_velocity(latents, noise, timesteps)
        else:
            target = noise

        # Forward pass: the UNet takes (sample, timestep, text conditioning)
        # and returns an output object whose tensor is in `.sample`.
        optimizer.zero_grad()
        model_pred = pipe.unet(
            latents_noisy,
            timesteps,
            encoder_hidden_states=empty_prompt_embeds.expand(batch, -1, -1),
        ).sample
        loss = torch.nn.functional.mse_loss(model_pred.float(), target.float())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        last_loss = loss.item()

    if last_loss is None:
        print(f"Epoch {epoch+1} complete. No batches were processed.")
    else:
        print(f"Epoch {epoch+1} complete. Loss: {last_loss}")

# Save the fine-tuned model
pipe.save_pretrained(output_dir)
print(f"Fine-tuning complete. Model saved to {output_dir}")

# Now, you can use the fine-tuned model for image generation