In [3]:
import torch
from torch.nn import functional as F
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import json
import gc

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# NOTE(review): model_id and max_image_dimension are written to
# parameters.json below but are not referenced anywhere else in this cell.
model_id = "taming-transformers/vqgan_imagenet_f16_16384"
image_paths = [
    "../data/results/ma-boston_200250_fake_B.png",
    "../data/results/nc-charlotte_200250_fake_B.png",
    "../data/results/ny-manhattan_200250_fake_B.png",
    "../data/results/pa-pittsburgh_200250_fake_B.png",
]
num_steps = 50                  # number of interpolation frames to generate
output_image_size = (128, 128)  # (height, width) every input is resized to
max_image_dimension = 512       # recorded for provenance only (see note above)
batch_size = 2                  # images encoded per forward pass

# Everything produced by this run lives under a folder named after the run.
identifier = f"vqgan_{num_steps}-steps"
output_dir = os.path.join("vqgan-output", identifier)
params_path = os.path.join(output_dir, "parameters.json")

# Snapshot of the settings above, persisted next to the generated images.
params = dict(
    identifier=identifier,
    model_id=model_id,
    num_steps=num_steps,
    output_image_size=output_image_size,
    max_image_dimension=max_image_dimension,
    batch_size=batch_size,
)

# The directory must exist before the snapshot can be written into it.
os.makedirs(output_dir, exist_ok=True)
with open(params_path, "w") as params_file:
    params_file.write(json.dumps(params, indent=4))

print(f"Parameters saved to {params_path}")

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def slerp(val, low, high, eps=1e-6):
    """Spherical linear interpolation between ``low`` and ``high``.

    The angle between the two inputs is measured along dimension 1, so for a
    4-D tensor every spatial position is interpolated independently using its
    own dim-1 vector.

    Args:
        val: Interpolation fraction; 0 returns ``low``, 1 returns ``high``.
        low: Start tensor with at least two dimensions.
        high: End tensor, same shape as ``low``.
        eps: Numerical floor used both for the norms and for deciding when
            ``sin(omega)`` is too small to divide by.

    Returns:
        A tensor shaped like ``low``. Where the two directions are (anti-)
        parallel -- so the spherical formula would divide by zero -- plain
        linear interpolation is used instead.
    """
    # Guard the normalisation so an all-zero vector does not produce NaN.
    low_unit = low / torch.norm(low, dim=1, keepdim=True).clamp_min(eps)
    high_unit = high / torch.norm(high, dim=1, keepdim=True).clamp_min(eps)

    # Rounding can push the cosine marginally outside [-1, 1], where acos
    # returns NaN, so clamp before taking the angle.
    cos_omega = (low_unit * high_unit).sum(1, keepdim=True).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    # sin(omega) ~ 0 means the endpoints are (anti-)parallel: substitute a
    # harmless divisor there and select the linear result for those entries.
    degenerate = so.abs() < eps
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    spherical = (
        (torch.sin((1.0 - val) * omega) / safe_so) * low
        + (torch.sin(val * omega) / safe_so) * high
    )
    linear = (1.0 - val) * low + val * high
    return torch.where(degenerate, linear, spherical)

class VQModel(torch.nn.Module):
    """Small convolutional encoder/decoder pair.

    The encoder maps a 3-channel image to a 128-channel feature map through
    one stride-1 and two stride-2 convolutions; the decoder mirrors it with
    two stride-2 transposed convolutions followed by a stride-1 convolution
    back to 3 channels.

    NOTE(review): despite the class name, no codebook or quantisation step is
    defined here, and nothing in this class loads pretrained weights.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        width = 128  # channel count used by every hidden layer

        # Image -> latent: the two stride-2 convolutions each shrink the
        # spatial size.
        encoder_layers = [
            nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent -> image: the two stride-2 transposed convolutions each
        # enlarge the spatial size before projecting to 3 channels.
        decoder_layers = [
            nn.ConvTranspose2d(width, width, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(width, width, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Run ``x`` through the encoder stack and return the feature map."""
        features = self.encoder(x)
        return features

    def decode(self, z):
        """Run ``z`` through the decoder stack and return the 3-channel output."""
        reconstruction = self.decoder(z)
        return reconstruction

def load_vqgan_model():
    """Build a ``VQModel`` on the active ``device`` and return it in eval mode.

    NOTE(review): no checkpoint is loaded here, so the returned network has
    whatever weights ``VQModel()`` was constructed with.
    """
    network = VQModel()
    network = network.to(device)
    network.eval()
    return network

def preprocess_image(image_path):
    """Load an image file and turn it into a normalised, batched tensor.

    The image is converted to RGB, centre-cropped to a square whose side is
    the shorter edge, resized to ``output_image_size``, converted to a tensor
    and normalised with mean 0.5 / std 0.5 per channel.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A tensor with a leading batch dimension of size 1.
    """
    # Image.open keeps the underlying file open; convert() yields an
    # independent in-memory copy, so the handle can be released right away.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    s = min(image.size)

    # The previous leading Resize(s) was dropped: s already equals the shorter
    # edge, so resizing the shorter edge to s left the image unchanged.
    transform = Compose([
        CenterCrop(s),
        Resize(output_image_size),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    return transform(image).unsqueeze(0)

def decode_latents(model, latents):
    """Move ``latents`` to the active ``device`` and return ``model.decode`` of them."""
    return model.decode(latents.to(device))

def process_in_batches(model, images, batch_size):
    """Encode ``images`` in chunks of ``batch_size`` and collect the results.

    Args:
        model: Object exposing ``encode(tensor)``.
        images: List of tensors that can be concatenated along dim 0.
        batch_size: Number of list entries encoded per forward pass.

    Returns:
        A list with one CPU tensor per chunk, in input order.
    """
    encoded_chunks = []
    for offset in range(0, len(images), batch_size):
        chunk = images[offset:offset + batch_size]
        stacked = torch.cat(chunk).to(device)
        with torch.no_grad():
            encoded = model.encode(stacked)
        # Keep results on the CPU so device memory only holds the current chunk.
        encoded_chunks.append(encoded.cpu())
        del stacked
        torch.cuda.empty_cache()
    return encoded_chunks

# End-to-end run: build the model, encode every input image, interpolate
# between the first and last latent, decode and save each frame, then save
# an overview figure of all frames.
try:
    # Load the VQGAN model
    with tqdm(total=1, desc="Loading VQGAN model") as pbar:
        model = load_vqgan_model()
        pbar.update(1)

    # Load and preprocess images
    images = []
    with tqdm(total=len(image_paths), desc="Loading and preprocessing images") as pbar:
        for image_path in image_paths:
            image = preprocess_image(image_path)
            images.append(image)
            pbar.update(1)

    # Encode images to latents (one tensor per batch, then one stacked tensor)
    latent_batches = process_in_batches(model, images, batch_size)
    latents = torch.cat(latent_batches)

    print(f"Shapes of latents: {latents.shape}")

    # Interpolate between latents
    print("Interpolating between latents...")
    alphas = np.linspace(0, 1, num_steps)
    interpolated_images = []

    # Only the first and last latent are used as endpoints; they do not change
    # between steps, so build them once instead of on every iteration.
    start_latent = latents[0].unsqueeze(0)
    end_latent = latents[-1].unsqueeze(0)

    with tqdm(total=num_steps, desc="Interpolating and Decoding") as pbar:
        for alpha in alphas:
            interpolated_latent = slerp(alpha, start_latent, end_latent)
            with torch.no_grad():
                # decode_latents moves the latent to the device itself
                decoded_image = decode_latents(model, interpolated_latent)
                # Map the decoder output from [-1, 1] to [0, 1]
                decoded_image = (decoded_image.clamp(-1, 1) + 1) / 2
                # Convert to CPU and then to an HWC numpy array
                decoded_image = decoded_image.cpu().squeeze(0).permute(1, 2, 0).numpy()

            # Convert to PIL Image
            decoded_image_pil = Image.fromarray((decoded_image * 255).astype(np.uint8))

            interpolated_images.append(decoded_image_pil)

            # Save the interpolated image (file names are 1-based)
            output_path = os.path.join(output_dir, f"interpolated_{len(interpolated_images)}.png")
            decoded_image_pil.save(output_path, quality=95)

            pbar.update(1)

            # Clear some memory
            del interpolated_latent, decoded_image
            torch.cuda.empty_cache()
            gc.collect()

    print(f"Interpolation complete. {num_steps} images generated and saved in {output_dir}")

    # Plot results and save the plot as an image file
    print("Plotting results...")
    # squeeze=False always yields a 2-D array of axes, so this also works when
    # num_steps == 1 (otherwise a single bare Axes object is returned).
    fig, axes = plt.subplots(1, num_steps, figsize=(20, 4), squeeze=False)
    for ax, img in zip(axes.ravel(), interpolated_images):
        ax.imshow(img)
        ax.axis('off')

    # Save the plot to a file
    plot_path = os.path.join(output_dir, "interpolation_steps.png")
    fig.savefig(plot_path, bbox_inches='tight')
    plt.close(fig)  # Close the figure to free up memory
    print(f"Plot saved to {plot_path}")

except Exception as e:
    # Report the failure, then re-raise so the notebook shows the full
    # traceback and the cell is marked as failed instead of appearing to
    # finish normally.
    print(f"An error occurred: {e}")
    raise

finally:
    # Clean up (runs on success and on failure)
    torch.cuda.empty_cache()
    gc.collect()

Parameters saved to vqgan-output/vqgan_50-steps/parameters.json
Using device: cuda


Loading VQGAN model: 100%|██████████| 1/1 [00:00<00:00, 48.04it/s]
Loading and preprocessing images: 100%|██████████| 4/4 [00:00<00:00, 236.34it/s]


Shapes of latents: torch.Size([4, 128, 32, 32])
Interpolating between latents...


Interpolating and Decoding: 100%|██████████| 50/50 [00:06<00:00,  7.95it/s]


Interpolation complete. 50 images generated and saved in vqgan-output/vqgan_50-steps
Plotting results...
Plot saved to vqgan-output/vqgan_50-steps/interpolation_steps.png


: 