# CubeDiff: Inference and Evaluation

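This notebook demonstrates the CubeDiff inference pipeline and evaluates its outputs:

1. Load the trained model (falling back to the base Stable Diffusion weights if no checkpoint is found)
2. Generate panoramas from text prompts
3. Visualize the individual cubemap faces and the resulting equirectangular panorama
4. Evaluate consistency and stitching quality across cube-face boundaries
5. Compare with a naive stitching baseline

Generating panoramas from a single input image is not yet covered in this notebook.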

In [None]:
import os
import sys
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
import json
import cv2

# Make the repository root (one directory above this notebook) importable, so the
# local `inference` and `data` packages below resolve without installation.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

# Import custom modules
from inference.pipeline import CubeDiffPipeline
from data.preprocessing import equirect_to_cubemap, cubemap_to_equirect

## 1. Load Trained Model

In [None]:
# Base Stable Diffusion weights plus the (optional) fine-tuned CubeDiff checkpoint
pretrained_model_name = "runwayml/stable-diffusion-v1-5"
checkpoint_path = "../outputs/cubediff_mini/final_model/model.pt"

# Without a trained checkpoint (e.g. before training has been run), fall back to the base model
if os.path.exists(checkpoint_path):
    print(f"Using trained model from {checkpoint_path}")
else:
    print("Trained model not found, using base Stable Diffusion model.")
    checkpoint_path = None

# Build the CubeDiff inference pipeline, on the GPU when one is available
pipeline = CubeDiffPipeline(
    pretrained_model_name=pretrained_model_name,
    checkpoint_path=checkpoint_path,
    device=("cuda" if torch.cuda.is_available() else "cpu"),
)

## 2. Generate Panoramas from Text Prompts

In [None]:
# Define test prompts
test_prompts = [
    "A beautiful sunset over a mountain range, with vibrant orange and purple sky",
    "A cozy living room with a fireplace, comfortable furniture, and large windows",
    "A lush tropical beach with palm trees, crystal clear water, and white sand",
    "A futuristic city skyline at night with neon lights and flying vehicles",
    "A dense forest with tall trees, sunlight filtering through the leaves, and a small stream"
]

# Output directory for the generated samples; created once up front instead of once per prompt
samples_dir = "../outputs/samples"
os.makedirs(samples_dir, exist_ok=True)

# Generate, display and save one panorama per prompt
for i, prompt in enumerate(test_prompts):
    print(f"Generating panorama for prompt: {prompt}")

    panorama = pipeline.generate(
        prompt=prompt,
        negative_prompt="low quality, blurry, distorted",
        num_inference_steps=30,  # Use fewer steps for faster inference during testing
        guidance_scale=7.5,
        height=512,  # NOTE(review): presumably per-face resolution — confirm against CubeDiffPipeline
        width=512,
        output_type="pil"
    )

    fig, ax = plt.subplots(figsize=(15, 5))
    ax.imshow(np.array(panorama))
    ax.set_title(f"Prompt: {prompt}")
    ax.axis('off')
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # one figure is created per prompt; release it once displayed

    panorama.save(os.path.join(samples_dir, f"panorama_{i}.jpg"))

## 3. Visualize Cubemap Faces

In [None]:
# Select a prompt to visualize individual cube faces
test_prompt = "A beautiful mountain landscape with a lake and forest in the foreground"

# Generate undecoded cubemap latents so each face can be decoded and inspected separately.
# NOTE(review): the indexing below assumes latents are laid out as (batch, face, C, h, w) — confirm against CubeDiffPipeline.
latents = pipeline.generate(
    prompt=test_prompt,
    negative_prompt="low quality, blurry, distorted",
    num_inference_steps=30,
    guidance_scale=7.5,
    height=512,
    width=512,
    output_type="latent"
)

# Face labels, assumed to match the order of the pipeline's face axis — TODO confirm.
face_names = ['front', 'right', 'back', 'left', 'top', 'bottom']

# Latent scaling factor: prefer the VAE's own config value (diffusers `AutoencoderKL.config.scaling_factor`)
# and fall back to the Stable Diffusion v1.x constant when the VAE does not expose one.
vae_scaling_factor = getattr(getattr(pipeline.vae, "config", None), "scaling_factor", 0.18215)

cube_faces = []
with torch.no_grad():
    for face_idx in range(len(face_names)):
        face_latent = latents[0, face_idx].unsqueeze(0)  # add the batch dimension back for the VAE
        face_image = pipeline.vae.decode(face_latent / vae_scaling_factor).sample
        face_image = (face_image / 2 + 0.5).clamp(0, 1)  # map decoder output from [-1, 1] to [0, 1]
        # Upcast before leaving torch: a half-precision VAE would otherwise yield float16 arrays,
        # which downstream helpers (e.g. any OpenCV remapping) and the MSE evaluation handle poorly.
        face_image = face_image[0].float().cpu().permute(1, 2, 0).numpy()
        cube_faces.append(face_image)

# Show the six decoded faces in a 2x3 grid (row-major), each labelled with its face name
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
for ax, face_name, face_img in zip(axs.ravel(), face_names, cube_faces):
    ax.imshow(face_img)
    ax.set_title(face_name)
    ax.axis('off')
plt.suptitle(f"Cubemap Faces for: {test_prompt}")
plt.tight_layout()
plt.show()

# Project the cubemap to an equirectangular image and display it.
# NOTE(review): 1024, 2048 are presumably output height and width — confirm with data.preprocessing.
equirect = cubemap_to_equirect(np.array(cube_faces), 1024, 2048)

fig_equirect, ax_equirect = plt.subplots(figsize=(15, 5))
ax_equirect.imshow(equirect)
ax_equirect.set_title(f"Equirectangular Panorama: {test_prompt}")
ax_equirect.axis('off')
plt.tight_layout()
plt.show()

## 4. Evaluate Face Consistency and Stitching Quality

In [None]:
# Helper functions to check consistency across face boundaries

def _edge_pixels(face, edge):
    """Return the one-pixel border of ``face`` on side ``edge``.

    Args:
        face: Image array of shape (H, W) or (H, W, C).
        edge: 0=top, 1=right, 2=bottom, 3=left.

    Returns:
        The border pixels. Rows are read left-to-right and columns top-to-bottom.

    Raises:
        ValueError: If ``edge`` is not one of 0, 1, 2, 3.
    """
    if edge == 0:
        return face[0]
    if edge == 1:
        return face[:, -1]
    if edge == 2:
        return face[-1]
    if edge == 3:
        return face[:, 0]
    raise ValueError(f"edge must be 0 (top), 1 (right), 2 (bottom) or 3 (left); got {edge!r}")


def evaluate_face_consistency(cube_faces, adjacent_pairs=None):
    """Evaluate the consistency between adjacent cube faces.

    For each pair of touching faces, the border pixels along the shared edge are compared.
    The score is their mean squared error (lower is better).

    Args:
        cube_faces: Sequence of face images, each (H, W) or (H, W, C). By default the indices
            are assumed to follow 0=front, 1=right, 2=back, 3=left, 4=top, 5=bottom.
        adjacent_pairs: Optional list of ``(face1_idx, face2_idx, face1_edge, face2_edge)`` or
            ``(face1_idx, face2_idx, face1_edge, face2_edge, reverse)`` tuples. Edges use
            0=top, 1=right, 2=bottom, 3=left. ``reverse`` says whether face2's edge must be
            read in the opposite direction to line up with face1's edge. When it is omitted,
            a heuristic is used: reverse only when a row edge meets a column edge. That is
            not correct for every cube edge, so pass ``reverse`` explicitly when adding pairs.
            Defaults to the four horizontal seams plus front-top and front-bottom.

    Returns:
        List with one MSE score per pair, in the order of ``adjacent_pairs``.

    Raises:
        ValueError: If an edge code is not in 0..3.
    """
    if adjacent_pairs is None:
        adjacent_pairs = [
            (0, 1, 1, 3),  # front-right
            (1, 2, 1, 3),  # right-back
            (2, 3, 1, 3),  # back-left
            (3, 0, 1, 3),  # left-front
            (0, 4, 0, 2),  # front-top
            (0, 5, 2, 0),  # front-bottom
        ]

    scores = []
    for pair in adjacent_pairs:
        face1_idx, face2_idx, face1_edge, face2_edge = pair[:4]
        if len(pair) > 4:
            reverse = bool(pair[4])
        else:
            reverse = (face1_edge + face2_edge) % 2 == 1  # row edge meeting a column edge

        # Compare in float64 so unsigned-integer images (e.g. uint8) don't wrap around on subtraction
        edge1 = _edge_pixels(np.asarray(cube_faces[face1_idx]), face1_edge).astype(np.float64)
        edge2 = _edge_pixels(np.asarray(cube_faces[face2_idx]), face2_edge).astype(np.float64)

        if reverse:
            # Reverse only the pixel order along the edge; np.flip without an axis would
            # also reverse the channel (e.g. RGB) axis.
            edge2 = edge2[::-1]

        scores.append(np.mean((edge1 - edge2) ** 2))

    return scores

# Score the cubemap decoded above: one MSE per checked face boundary, then report the mean
consistency_scores = evaluate_face_consistency(cube_faces)

report_lines = ["Face boundary consistency scores (MSE, lower is better):"]
report_lines += [
    f"Edge {edge_idx}: {edge_score:.6f}"
    for edge_idx, edge_score in enumerate(consistency_scores)
]
print("\n".join(report_lines))

print(f"\nAverage consistency score: {np.mean(consistency_scores):.6f}")

## 5. Compare with a Naive Stitching Baseline

In [None]:
# For baseline comparison, we'll use a simple approach of generating multiple images and stitching them
def generate_baseline_panorama(prompt, pipe=None, model_name="runwayml/stable-diffusion-v1-5",
                               directions=("front view of", "right view of", "back view of", "left view of"),
                               guidance_scale=7.5, num_inference_steps=None, negative_prompt=None):
    """Generate a baseline panorama by stitching individual images.

    One image is generated per entry in ``directions``, with the direction phrase prefixed to
    ``prompt``. The images are concatenated horizontally with no overlap, blending or spherical
    projection, so visible seams are expected. This is only a qualitative reference point.

    Args:
        prompt: Text prompt describing the scene.
        pipe: Optional already-loaded ``StableDiffusionPipeline``, or any callable with the same
            ``pipe(prompt, **kwargs).images`` interface. Passing one avoids reloading the
            weights on every call. When ``None``, ``model_name`` is loaded for this call and
            released again afterwards.
        model_name: Model id to load when ``pipe`` is ``None``.
        directions: Phrases prefixed to the prompt, one generated image each, left to right.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        num_inference_steps: Forwarded to the pipeline only if given; otherwise the pipeline
            default is used.
        negative_prompt: Forwarded to the pipeline only if given.

    Returns:
        NumPy array of the images concatenated along axis 1 (width).

    Raises:
        ValueError: If ``directions`` is empty.
    """
    if not directions:
        raise ValueError("directions must contain at least one entry")

    owns_pipe = pipe is None
    if owns_pipe:
        use_cuda = torch.cuda.is_available()
        pipe = StableDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_cuda else torch.float32
        ).to("cuda" if use_cuda else "cpu")

    # Forward optional settings only when explicitly set, so pipeline defaults apply otherwise
    call_kwargs = {"guidance_scale": guidance_scale}
    if num_inference_steps is not None:
        call_kwargs["num_inference_steps"] = num_inference_steps
    if negative_prompt is not None:
        call_kwargs["negative_prompt"] = negative_prompt

    try:
        images = [
            np.array(pipe(f"{direction} {prompt}", **call_kwargs).images[0])
            for direction in directions
        ]
    finally:
        if owns_pipe:
            # Release the temporary pipeline so it doesn't stay resident next to the CubeDiff model
            import gc
            del pipe
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Stitch images horizontally (very naive approach, just for demonstration)
    return np.concatenate(images, axis=1)

# Generate and compare panoramas
test_prompt = "A mountain landscape with snow-capped peaks and green valleys"

# Generate CubeDiff panorama
cubediff_panorama = pipeline.generate(
    prompt=test_prompt,
    negative_prompt="low quality, blurry, distorted",
    num_inference_steps=30,
    guidance_scale=7.5,
    height=512,
    width=512,
    output_type="np"
)

# Generate baseline panorama
baseline_panorama = generate_baseline_panorama(test_prompt)

# Visualize and compare
plt.figure(figsize=(15, 10))

plt.subplot(2, 1, 1)
plt.imshow(cubediff_panorama)
plt.title("CubeDiff Panorama")
plt.axis('off')

plt.subplot(2, 1, 2)
plt.imshow(baseline_panorama)
plt.title("Baseline Panorama (Naive Stitching)")
plt.axis('off')

plt.suptitle(f"Comparison for prompt: {test_prompt}")
plt.tight_layout()
plt.show()