In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm  # Use tqdm instead of tqdm.notebook if not in a notebook environment
import json
import warnings

# Parameters
vqvae_model_path = "models/vq-vae_16-80000.pth"  # Update with your model path
image_paths = [
    "../data/results/ma-boston_200250_fake_B.png",
    "../data/results/nc-charlotte_200250_fake_B.png",
    "../data/results/ny-manhattan_200250_fake_B.png",
    "../data/results/pa-pittsburgh_200250_fake_B.png"
]
num_steps = 50  # Increase for smoother interpolation
identifier = f"vqvae_4-images_{num_steps}-steps"  # Unique identifier for this run
output_dir = os.path.join("vqvae-output", identifier)
output_image_size = (512, 512)  # Width and height of output image
max_image_dimension = 1024  # Maximum dimension for resizing images

# Create a dictionary with all relevant parameters
params = {
    "identifier": identifier,
    "vqvae_model_path": vqvae_model_path,
    "num_steps": num_steps,
    "output_image_size": output_image_size,
    "max_image_dimension": max_image_dimension
}

# Define the path to save the JSON file
params_path = os.path.join(output_dir, "parameters.json")

# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)

# Save the parameters to a JSON file
with open(params_path, 'w') as f:
    json.dump(params, f, indent=4)

print(f"Parameters saved to {params_path}")

# Set device
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def slerp(val, low, high):
    """Spherical linear interpolation for 4D tensors."""
    low_2d = low.view(low.shape[0], -1)
    high_2d = high.view(high.shape[0], -1)
    low_2d_norm = low_2d / torch.norm(low_2d, dim=1, keepdim=True)
    high_2d_norm = high_2d / torch.norm(high_2d, dim=1, keepdim=True)
    omega = torch.acos((low_2d_norm * high_2d_norm).sum(1).clamp(-1, 1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega).unsqueeze(1) / so.unsqueeze(1)) * low_2d + \
          (torch.sin(val * omega).unsqueeze(1) / so.unsqueeze(1)) * high_2d
    return res.view(low.shape)

def load_and_preprocess_image(image_path):
    # Open the image and convert to RGB
    image = Image.open(image_path).convert("RGB")
    
    # Get the original size
    original_size = image.size
    
    # Calculate the aspect ratio
    aspect_ratio = original_size[0] / original_size[1]
    
    # Determine the new size (ensuring it's divisible by 8)
    if aspect_ratio > 1:
        new_width = min(original_size[0], max_image_dimension)
        new_width = new_width - (new_width % 8)
        new_height = int(new_width / aspect_ratio)
        new_height = new_height - (new_height % 8)
    else:
        new_height = min(original_size[1], max_image_dimension)
        new_height = new_height - (new_height % 8)
        new_width = int(new_height * aspect_ratio)
        new_width = new_width - (new_width % 8)
    
    new_size = (new_width, new_height)
    
    # Create the transform
    transform = Compose([
        Resize(new_size),
        ToTensor(),
        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Adjusted for 3 channels
    ])
    
    # Apply the transform
    image_tensor = transform(image).unsqueeze(0).to(device).to(torch.float32)  # Use float32 for VQ-VAE
    
    return image_tensor, new_size, original_size

# VQ-VAE Model Definitions (from your provided code)
class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, embedding_dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return x

class Decoder(nn.Module):
    def __init__(self, embedding_dim, hidden_channels, out_channels):
        super(Decoder, self).__init__()
        self.conv1 = nn.ConvTranspose2d(embedding_dim, hidden_channels, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(hidden_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = torch.tanh(self.conv2(x))
        return x

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, x):
        # x: [batch_size, embedding_dim, height, width]
        # Flatten the input
        flat_x = x.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)
        # Compute distances
        distances = torch.cdist(flat_x, self.embedding.weight)
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1)
        # Quantize
        quantized = self.embedding(encoding_indices).view(x.shape)
        # Losses
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        # Straight-through estimator
        quantized = x + (quantized - x).detach()
        return quantized, loss, encoding_indices

class VQVAE(nn.Module):
    def __init__(self, in_channels=3, hidden_channels=64, num_embeddings=1024, embedding_dim=64, commitment_cost=0.25):
        super(VQVAE, self).__init__()
        self.encoder = Encoder(in_channels, hidden_channels, embedding_dim)
        self.decoder = Decoder(embedding_dim, hidden_channels, in_channels)
        self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

    def encode(self, x):
        encoded = self.encoder(x)
        quantized, vq_loss, _ = self.vq_layer(encoded)
        return quantized

    def decode(self, quantized):
        decoded = self.decoder(quantized)
        return decoded

    def forward(self, x):
        encoded = self.encoder(x)
        quantized, vq_loss, _ = self.vq_layer(encoded)
        decoded = self.decoder(quantized)
        return decoded, vq_loss

def load_vqvae_model(model_path):
    # Load your VQ-VAE model
    checkpoint = torch.load(model_path, map_location=device)
    model = VQVAE(
        in_channels=3,
        hidden_channels=64,
        num_embeddings=1024,
        embedding_dim=64,
        commitment_cost=0.25
    )
    model.load_state_dict(checkpoint)  # Load the state dict directly
    model = model.to(device)
    model.eval()
    return model

try:
    # Suppress the FutureWarning from torch.load
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Load the pre-trained VQ-VAE model
    with tqdm(total=1, desc="Loading VQ-VAE model") as pbar:
        vqvae_model = load_vqvae_model(vqvae_model_path)
        pbar.update(1)

    # Load and preprocess images
    images = []
    sizes = []
    original_sizes = []
    with tqdm(total=len(image_paths), desc="Loading and preprocessing images") as pbar:
        for image_path in image_paths:
            image, size, original_size = load_and_preprocess_image(image_path)
            images.append(image)
            sizes.append(size)
            original_sizes.append(original_size)
            pbar.update(1)

    print(f"Processed sizes: {sizes}")
    print(f"Original sizes: {original_sizes}")

    # Encode images to latents using VQ-VAE
    latents = []
    with tqdm(total=len(images), desc="Encoding images to latent representations") as pbar:
        with torch.no_grad():
            for image in images:
                latent = vqvae_model.encode(image)
                latents.append(latent)
                pbar.update(1)

    print(f"Shapes of latents: {[latent.shape for latent in latents]}")

    # Interpolate between latents
    print("Interpolating between latents...")
    alphas = np.linspace(0, 1, num_steps)
    interpolated_images = []

    with tqdm(total=num_steps, desc="Interpolating and Decoding") as pbar:
        for alpha in alphas:
            # Interpolate between the first and last latent
            interpolated_latent = slerp(alpha, latents[0], latents[-1])
            with torch.no_grad():
                # Decode the latents
                decoded_image = vqvae_model.decode(interpolated_latent)
                # Denormalize the decoded image
                decoded_image = (decoded_image + 1) / 2  # Shift from [-1, 1] to [0, 1]
                decoded_image = decoded_image.clamp(0, 1)
                # Convert to CPU and then to numpy array
                decoded_image = decoded_image.cpu().permute(0, 2, 3, 1).numpy()[0]

            # Convert to PIL Image
            decoded_image_pil = Image.fromarray((decoded_image * 255).astype(np.uint8))
            # Resize the output image to the desired output size
            decoded_image_pil = decoded_image_pil.resize(output_image_size, Image.LANCZOS)

            interpolated_images.append(decoded_image_pil)

            # Save the interpolated image
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, f"interpolated_{len(interpolated_images)}.png")
            decoded_image_pil.save(output_path, quality=95)

            pbar.update(1)

    print(f"Interpolation complete. {num_steps} images generated and saved in {output_dir}")

    # Plot results and save the plot as an image file
    print("Plotting results...")
    fig, axes = plt.subplots(1, len(interpolated_images), figsize=(20, 4))
    for ax, img in zip(axes, interpolated_images):
        ax.imshow(img)
        ax.axis('off')

    # Save the plot to a file
    plot_path = os.path.join(output_dir, "interpolation_steps.png")
    fig.savefig(plot_path, bbox_inches='tight')
    plt.show()
    print(f"Plot saved to {plot_path}")

except Exception as e:
    print(f"An error occurred: {e}")

Parameters saved to vqvae-output/vqvae_4-images_50-steps/parameters.json
Using device: cuda


Loading VQ-VAE model:   0%|          | 0/1 [00:00<?, ?it/s]

An error occurred: Error(s) in loading state_dict for VQVAE:
	Missing key(s) in state_dict: "encoder.conv1.weight", "encoder.conv1.bias", "encoder.conv2.weight", "encoder.conv2.bias", "decoder.conv1.weight", "decoder.conv1.bias", "decoder.conv2.weight", "decoder.conv2.bias", "vq_layer.embedding.weight". 
	Unexpected key(s) in state_dict: "fc_mu.weight", "fc_mu.bias", "fc_logvar.weight", "fc_logvar.bias", "decoder_input.weight", "decoder_input.bias", "encoder.0.weight", "encoder.0.bias", "encoder.2.weight", "encoder.2.bias", "encoder.4.weight", "encoder.4.bias", "decoder.0.weight", "decoder.0.bias", "decoder.2.weight", "decoder.2.bias", "decoder.4.weight", "decoder.4.bias". 



