# CubeDiff: Generating High-Quality 360° Panoramas

This notebook demonstrates how to use the CubeDiff model to generate high-quality 360° panoramas from text prompts and normal field-of-view (NFoV) images.

CubeDiff is based on the paper ["CubeDiff: Repurposing Diffusion-Based Image Models for Panorama Generation"](https://arxiv.org/pdf/2501.17162).

In [None]:
# !pip install diffusers==0.24.0 transformers==4.36.2 torch==2.1.2 torchvision==0.16.2 accelerate==0.25.0 \
#     opencv-python==4.8.1.78 matplotlib==3.8.2 tqdm==4.66.1 einops==0.7.0 huggingface_hub==0.19.4 xformers requests pillow

## 1. Setup

First, let's import the necessary libraries and modules.

In [13]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display
from diffusers import StableDiffusionPipeline

# Import CubeDiff modules
from data import (
    CubemapDataset, 
    visualize_cubemap, 
    equirectangular_to_cubemap_batch, 
    cubemap_to_equirectangular_batch
)
from modules import (
    CubemapPositionalEncoding,
    GroupNormalizationSync,
    InflatedAttention,
    OverlappingEdgeProcessor,
    adapt_unet_for_cubemap
)
from model import CubeDiff, CubeDiffPipeline
from trainer import CubeDiffTrainer

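# Reload the local modules so that edits made on disk are picked up without restarting the kernel.
# Names bound earlier with `from ... import ...` still point at the old objects,
# so the ones used later in this notebook are re-imported after the reload.
import importlib
import data as dt
import modules as mds
import model as md
import trainer as trn
for _module in (dt, mds, md, trn):  # reload dependencies before the modules that import them
    importlib.reload(_module)
from data import CubemapDataset, visualize_cubemap, cubemap_to_equirectangular_batch
from model import CubeDiff, CubeDiffPipeline
from trainer import CubeDiffTrainer
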
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## 2. Testing the Cubemap Conversion

Let's first test the equirectangular-to-cubemap conversion and vice versa using the provided utilities.

In [14]:
%%time
# Load a sample equirectangular image
# Replace this with your own sample image path
sample_path = "sample_equirect.jpg"

if os.path.exists(sample_path):
    # Load the image
    equirect_img = np.array(Image.open(sample_path))
    
    # Display the equirectangular image
    plt.figure(figsize=(12, 6))
    plt.imshow(equirect_img)
    plt.title("Original Equirectangular Image")
    plt.axis('off')
    plt.show()
    
    # Convert to cubemap
    from cubediff_utils_v1 import improved_equirect_to_cubemap, optimized_cubemap_to_equirect
    cube_faces = improved_equirect_to_cubemap(equirect_img, face_size=256)
    
    # Create a tensor from the cubemap
    face_order = ['front', 'back', 'left', 'right', 'top', 'bottom']
    cubemap_list = []
    
    for face_name in face_order:
        face = cube_faces[face_name]
        # Convert to tensor
        face_tensor = torch.from_numpy(face).float().permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
        if face_tensor.max() > 1.0:
            face_tensor = face_tensor / 255.0
        cubemap_list.append(face_tensor)
    
    cubemap_tensor = torch.stack(cubemap_list, dim=0)  # (6, C, H, W)
    
    # Visualize the cubemap
    visualize_cubemap(cubemap_tensor, title="Cubemap Faces")
    
    # Convert back to equirectangular
    equirect_recon = optimized_cubemap_to_equirect(cube_faces, *equirect_img.shape[:2])
    
    # Display the reconstructed equirectangular image
    plt.figure(figsize=(12, 6))
    plt.imshow(equirect_recon)
    plt.title("Reconstructed Equirectangular Image")
    plt.axis('off')
    plt.show()
else:
    print(f"Sample image not found at {sample_path}. Please provide a valid path.")

Sample image not found at sample_equirect.jpg. Please provide a valid path.
CPU times: user 87 μs, sys: 0 ns, total: 87 μs
Wall time: 77 μs
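
To make the conversion tested above more concrete, here is a minimal sketch of the geometry behind it: a point on a cube face is turned into a 3D direction, which is then converted to longitude and latitude and finally to equirectangular pixel coordinates. The face orientation table and the function name `face_pixel_to_equirect` are assumptions made for illustration; the utilities in `cubediff_utils_v1` may use a different convention.

```python
import math

# Assumed convention (hypothetical, not taken from the notebook's utilities):
# +z is forward, +x is right, +y is up, and (u, v) in [-1, 1]^2 are face
# coordinates with u pointing right and v pointing down on each face.
FACE_DIRECTIONS = {
    "front":  lambda u, v: (u, -v, 1.0),
    "back":   lambda u, v: (-u, -v, -1.0),
    "left":   lambda u, v: (-1.0, -v, u),
    "right":  lambda u, v: (1.0, -v, -u),
    "top":    lambda u, v: (u, 1.0, v),
    "bottom": lambda u, v: (u, -1.0, -v),
}

def face_pixel_to_equirect(face, u, v, width, height):
    """Map face coordinates (u, v) to fractional equirectangular (col, row)."""
    x, y, z = FACE_DIRECTIONS[face](u, v)
    lon = math.atan2(x, z)                 # longitude in [-pi, pi], 0 = forward
    lat = math.atan2(y, math.hypot(x, z))  # latitude in [-pi/2, pi/2]
    col = (lon / (2 * math.pi) + 0.5) * (width - 1)
    row = (0.5 - lat / math.pi) * (height - 1)
    return col, row

# The centre of the front face lands in the middle of the panorama.
print(face_pixel_to_equirect("front", 0.0, 0.0, 1024, 512))
```

Sampling the equirectangular image at these coordinates for every pixel of every face gives the cubemap; the reverse conversion inverts the same mapping.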


## 3. Initialize the CubeDiff Model

Now let's initialize the CubeDiff model using a pretrained Stable Diffusion model as the base.

In [15]:
%%time
# Initialize the CubeDiff model
model = CubeDiff(
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-base",
    use_fp16=True,  # Use half precision (float16) for faster inference
    device=device
)

Initializing CubeDiff with torch.float16 precision on cuda
Loading pretrained model components...
✓ Loaded tokenizer
✓ Loaded text encoder (frozen)
✓ Loaded VAE (frozen)
✓ Loaded UNet
✓ Loaded scheduler
Adapting UNet for cubemap processing...
Adapting attention layer: down_blocks.0.attentions.0.transformer_blocks.0.attn1
Adapting attention layer: down_blocks.0.attentions.0.transformer_blocks.0.attn2
Adapting attention layer: down_blocks.0.attentions.1.transformer_blocks.0.attn1
Adapting attention layer: down_blocks.0.attentions.1.transformer_blocks.0.attn2
Adapting attention layer: down_blocks.1.attentions.0.transformer_blocks.0.attn1
Adapting attention layer: down_blocks.1.attentions.0.transformer_blocks.0.attn2
Adapting attention layer: down_blocks.1.attentions.1.transformer_blocks.0.attn1
Adapting attention layer: down_blocks.1.attentions.1.transformer_blocks.0.attn2
Adapting attention layer: down_blocks.2.attentions.0.transformer_blocks.0.attn1
Adapting attention layer: down_blocks

## 4. Test Text-to-Panorama Generation

Let's test generating a panorama from a text prompt.

In [16]:
%%time
# Create a pipeline for easier inference
pipeline = CubeDiffPipeline(model, device)

# Text prompt
prompt = "A beautiful sunset over the ocean with mountains in the distance"

# Generate a panorama
print(f"Generating panorama from prompt: {prompt}")
output = pipeline(
    prompt=prompt,
    negative_prompt="blurry, low quality, distorted",
    num_inference_steps=50,
    guidance_scale=7.5,
    output_type="equirectangular",
    height=512,
    width=1024,
)

# Display the output
if isinstance(output, torch.Tensor):
    # Convert to image
    if output.ndim == 4:  # Batch of images
        output = output[0]  # Take first image
    
    # Move to CPU and convert to numpy
    output_np = output.cpu().permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
    
    # Ensure values are in [0, 1]
    if output_np.max() > 1.0:
        output_np = output_np / 255.0
    
    # Display
    plt.figure(figsize=(15, 7.5))
    plt.imshow(output_np)
    plt.title(f"Generated Panorama: {prompt}")
    plt.axis('off')
    plt.show()
else:
    # If output is already an image, display it
    plt.figure(figsize=(15, 7.5))
    plt.imshow(output)
    plt.title(f"Generated Panorama: {prompt}")
    plt.axis('off')
    plt.show()

Generating panorama from prompt: A beautiful sunset over the ocean with mountains in the distance
Dimension mismatch: edge_region torch.Size([1, 4, 4, 128]), adj_edge_region torch.Size([1, 4, 64, 4])
Dimension mismatch: edge_region torch.Size([1, 4, 4, 128]), adj_edge_region torch.Size([1, 4, 64, 4])
Dimension mismatch: edge_region torch.Size([1, 4, 4, 128]), adj_edge_region torch.Size([1, 4, 64, 4])
Dimension mismatch: edge_region torch.Size([1, 4, 4, 128]), adj_edge_region torch.Size([1, 4, 64, 4])
Dimension mismatch: edge_region torch.Size([1, 4, 64, 4]), adj_edge_region torch.Size([1, 4, 4, 128])
Dimension mismatch: edge_region torch.Size([1, 4, 64, 4]), adj_edge_region torch.Size([1, 4, 4, 128])
Dimension mismatch: edge_region torch.Size([1, 4, 64, 4]), adj_edge_region torch.Size([1, 4, 4, 128])
Dimension mismatch: edge_region torch.Size([1, 4, 64, 4]), adj_edge_region torch.Size([1, 4, 4, 128])


RuntimeError: The size of tensor a (49152) must match the size of tensor b (8192) at non-singleton dimension 1
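
The `Dimension mismatch` messages above report edge strips of shape `(4, 128)` and `(64, 4)`. One plausible reading is that each face latent is 64×128, which is what `height=512` and `width=1024` give after Stable Diffusion's 8× VAE downsampling. This is an assumption, since the pipeline's handling of `height` and `width` is not shown here. The sketch below uses hypothetical helper names to show why strips taken from non-square faces cannot be aligned, while strips from square faces can. Whether this is also the cause of the `RuntimeError` is not established.

```python
def edge_strip_shapes(latent_h, latent_w, overlap):
    """Return the (rows, cols) shapes of a horizontal and a vertical edge strip."""
    horizontal = (overlap, latent_w)  # strip along the top or bottom edge
    vertical = (latent_h, overlap)    # strip along the left or right edge
    return horizontal, vertical

def strips_compatible(latent_h, latent_w, overlap):
    """Check whether a vertical strip, rotated by 90 degrees, matches a horizontal one."""
    horizontal, vertical = edge_strip_shapes(latent_h, latent_w, overlap)
    rotated = (vertical[1], vertical[0])
    return horizontal == rotated

print(strips_compatible(64, 128, 4))  # False: shapes (4, 128) vs (4, 64)
print(strips_compatible(64, 64, 4))   # True: square faces line up
```

Section 6 of this notebook requests square `256 × 256` faces with `output_type="cubemap"`, which avoids this particular shape conflict.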

## 5. Test Image-to-Panorama Generation

Now let's test generating a panorama from a normal field-of-view (NFoV) image.

In [None]:
%%time
# Load a sample input image (normal field-of-view)
# Replace this with your own sample image path
input_path = "./images/bridge.jpg"

if os.path.exists(input_path):
    # Load the image
    input_img = Image.open(input_path).convert("RGB")
    
    # Display the input image
    plt.figure(figsize=(8, 8))
    plt.imshow(input_img)
    plt.title("Input NFoV Image")
    plt.axis('off')
    plt.show()
    
    # Convert to tensor
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    input_tensor = transform(input_img).unsqueeze(0).to(device)  # (1, 3, 512, 512)
    
    # Text prompt for conditioning
    prompt = "A wide view of the scene, 360 degrees panorama, high quality"
    
    # Generate a panorama from the input image
    print(f"Generating panorama from input image with prompt: {prompt}")
    output = pipeline(
        prompt=prompt,
        input_image=input_tensor,  # Provide the input image
        condition_face=0,  # Use the front face for conditioning
        num_inference_steps=50,
        guidance_scale=7.5,
        output_type="equirectangular",
        height=512,
        width=1024,
    )
    
    # Display the output
    if isinstance(output, torch.Tensor):
        # Convert to image
        if output.ndim == 4:  # Batch of images
            output = output[0]  # Take first image
        
        # Move to CPU and convert to numpy
        output_np = output.cpu().permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
        
        # Ensure values are in [0, 1]
        if output_np.max() > 1.0:
            output_np = output_np / 255.0
        
        # Display
        plt.figure(figsize=(15, 7.5))
        plt.imshow(output_np)
        plt.title("Generated Panorama from Input Image")
        plt.axis('off')
        plt.show()
    else:
        # If output is already an image, display it
        plt.figure(figsize=(15, 7.5))
        plt.imshow(output)
        plt.title("Generated Panorama from Input Image")
        plt.axis('off')
        plt.show()
else:
    print(f"Input image not found at {input_path}. Please provide a valid path.")

## 6. Visualize the Cubemap Representation

Let's visualize the intermediate cubemap representation of the generated panorama.

In [None]:
%%time
# Generate a panorama in cubemap format
prompt = "A forest with a river and mountains in the distance"

# Generate a panorama with cubemap output
print(f"Generating cubemap from prompt: {prompt}")
cubemap_output = pipeline(
    prompt=prompt,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_type="cubemap",  # Output as cubemap
    height=256,
    width=256,
)

# Visualize the cubemap
if isinstance(cubemap_output, torch.Tensor):
    # If batch of cubemaps, take the first one
    if cubemap_output.ndim == 5:  # (B, 6, C, H, W)
        cubemap_output = cubemap_output[0]  # (6, C, H, W)
    
    # Visualize
    visualize_cubemap(cubemap_output, title=f"Generated Cubemap: {prompt}")
    
    # Convert to equirectangular for comparison
    equirect = cubemap_to_equirectangular_batch(
        cubemap_output.unsqueeze(0).to(device),
        height=512,
        width=1024
    )
    
    # Display equirectangular
    equirect_np = equirect[0].cpu().permute(1, 2, 0).numpy()
    if equirect_np.max() > 1.0:
        equirect_np = equirect_np / 255.0
    
    plt.figure(figsize=(15, 7.5))
    plt.imshow(equirect_np)
    plt.title("Converted to Equirectangular")
    plt.axis('off')
    plt.show()

## 7. Setting Up for Training

If you have a dataset of panoramic images, you can train or fine-tune the CubeDiff model. Here's how to set up the training process.

In [None]:
%%time
# Example of setting up a dataset and dataloader
# Replace these paths with your actual data
panorama_paths = [
    "data/panorama1.jpg",
    "data/panorama2.jpg",
    # Add more paths...
]

caption_paths = [
    "data/caption1.txt",
    "data/caption2.txt",
    # Add more paths...
]

# Create dataset
# Note: Only execute this if you have actual data
if all(os.path.exists(path) for path in panorama_paths) and all(os.path.exists(path) for path in caption_paths):
    # Create train dataset
    train_dataset = CubemapDataset(
        image_paths=panorama_paths,
        caption_paths=caption_paths,
        face_size=128,  # Size of each cubemap face
    )
    
    # Create train dataloader
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
    )
    
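    # Sanity-check the dataset by loading one sample
    # (the structure of a sample depends on CubemapDataset's implementation)
    sample = train_dataset[0]
    print(f"Loaded {len(train_dataset)} samples; first sample type: {type(sample).__name__}")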
    
    # Create optimizer and scheduler
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
    
    # Create trainer
    trainer = CubeDiffTrainer(
        model=model,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        device=device,
        output_dir="output",
    )
    
    print("Training setup complete. You can now run trainer.train() to start training.")
else:
    print("Training dataset not available. Skipping training setup.")
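
To see what the `CosineAnnealingLR` schedule configured above does to the learning rate, the following sketch evaluates the closed-form expression given in the PyTorch documentation, using the same `lr=1e-5` and `T_max=1000`. The function name `cosine_annealing_lr` is made up for this illustration, and `eta_min=0.0` is PyTorch's default rather than something set in this notebook.

```python
import math

def cosine_annealing_lr(step, base_lr=1e-5, t_max=1000, eta_min=0.0):
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

for step in (0, 250, 500, 750, 1000):
    print(f"step {step:4d}: lr = {cosine_annealing_lr(step):.2e}")
```

The rate starts at `1e-5`, passes through half that value at step 500, and reaches zero at step 1000, so `T_max` should roughly match the number of scheduler steps you plan to take.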

## 8. Examining the Model Architecture

Let's examine the architecture of the CubeDiff model in detail.

In [None]:
%%time
# Print model summary
print("CubeDiff Model Architecture Overview:")
print("==================================")
print(f"Text Encoder: {model.text_encoder.__class__.__name__}")
print(f"VAE: {model.vae.__class__.__name__}")
print(f"UNet: {model.unet.__class__.__name__}")
print(f"Scheduler: {model.scheduler.__class__.__name__}")
print("==================================")
print("Custom Modules:")
print("- CubemapPositionalEncoding")
print("- GroupNormalizationSync")
print("- InflatedAttention")
print("- OverlappingEdgeProcessor")
print("==================================")
print(f"Total Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"==================================")

## 9. Understanding the Diffusion Process

Let's visualize the diffusion process step by step to understand how the model generates panoramas.

In [None]:
%%time
# Function to visualize the diffusion process
def visualize_diffusion_process(model, prompt, num_inference_steps=20, output_type="equirectangular"):
    """Visualize the diffusion process step by step."""
    # Set model to eval mode
    model.eval()
    
    # Process prompt
    if isinstance(prompt, str):
        prompt = [prompt]
    
    # Batch size
    batch_size = len(prompt)
    
    # Encode text
    text_embeddings = model.encode_text(prompt)
    uncond_embeddings = model.encode_text(["" for _ in range(batch_size)])
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    
    # Set timesteps
    model.scheduler.set_timesteps(num_inference_steps, device=model.device)
    timesteps = model.scheduler.timesteps
    
    # Prepare initial latents
    latents = model.prepare_latents(
        batch_size=batch_size,
        num_faces=6,
        height=256,
        width=256,
    )
    
    # Apply cubemap-specific processing
    latents = model.cubemap_pos_encoding(latents)
    latents = model.group_norm_sync(latents)
    
    # Initialize output list
    outputs = []
    
    # Denoising loop
    for i, t in enumerate(timesteps):
        # Duplicate for classifier-free guidance
        latent_model_input = torch.cat([latents] * 2)
        
        # Scale according to the scheduler
        latent_model_input = model.scheduler.scale_model_input(latent_model_input, t)
        
        # Reshape for UNet
        b, f, c, h, w = latent_model_input.shape
        reshaped_input = latent_model_input.view(b * f, c, h, w)
        
        # Predict noise
        with torch.no_grad():
            noise_pred = model.unet(reshaped_input, t, encoder_hidden_states=text_embeddings).sample
            noise_pred = noise_pred.view(b, f, c, h, w)
        
        # Apply classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
        
        # Scheduler step
        latents = model.scheduler.step(noise_pred, t, latents).prev_sample
        
        # Apply cubemap-specific processing
        latents = model.cubemap_pos_encoding(latents)
        latents = model.group_norm_sync(latents)
        
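        # Decode the current latents at roughly five evenly spaced steps, plus the final step
        # (max(1, ...) avoids a modulo-by-zero error when num_inference_steps < 5)
        if i % max(1, num_inference_steps // 5) == 0 or i == num_inference_steps - 1: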
            # Decode current state
            current_output = model.decode_latents_to_image(latents)
            
            # Convert to equirectangular if needed
            if output_type == "equirectangular":
                current_output = cubemap_to_equirectangular_batch(
                    current_output,
                    height=512,
                    width=1024
                )
            
            outputs.append({
                "step": i,
                "timestep": t.item(),
                "output": current_output
            })
    
    # Plot all outputs
    num_outputs = len(outputs)
    fig, axes = plt.subplots(
        1, num_outputs, 
        figsize=(5 * num_outputs, 5),
        squeeze=False
    )
    
    for i, output_dict in enumerate(outputs):
        step = output_dict["step"]
        timestep = output_dict["timestep"]
        output = output_dict["output"]
        
        # Convert to numpy
        if output_type == "equirectangular":
            # Take first image from batch
            img = output[0].cpu().permute(1, 2, 0).numpy()
        else:
            # For cubemap visualization
            img = output[0]  # Will handle cubemap visualization separately
        
        # Ensure values are in [0, 1]
        if isinstance(img, np.ndarray) and img.max() > 1.0:
            img = img / 255.0
        
        # Display
        if output_type == "equirectangular":
            axes[0, i].imshow(img)
            axes[0, i].set_title(f"Step {step}\nTimestep {timestep}")
            axes[0, i].axis('off')
        else:
            # For cubemap, display a separate visualization
            if i == 0:
                # Only do the first one here, we'll handle the rest separately
                axes[0, i].imshow(np.zeros((10, 10, 3)))  # Placeholder
                axes[0, i].set_title(f"See separate cubemap visualizations")
                axes[0, i].axis('off')
                
                # Visualize each cubemap separately
                for snapshot in outputs:
                    visualize_cubemap(
                        snapshot["output"][0],
                        title=f"Step {snapshot['step']}, Timestep {snapshot['timestep']}"
                    )
            else:
                axes[0, i].axis('off')  # Hide other axes
    
    plt.tight_layout()
    plt.show()
    
    return outputs

# Now let's visualize the diffusion process
prompt = "A mountain landscape with a lake"

# Run the visualization
try:
    diffusion_steps = visualize_diffusion_process(
        model=model,
        prompt=prompt,
        num_inference_steps=20,
        output_type="equirectangular"
    )
except Exception as e:
    print(f"Error visualizing diffusion process: {e}")
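
The classifier-free guidance step inside the denoising loop above is the line that steers generation towards the prompt. As a minimal sketch, here is the same arithmetic on plain Python lists instead of tensors; the function name `apply_guidance` is hypothetical and used only for this illustration.

```python
def apply_guidance(noise_uncond, noise_text, guidance_scale):
    """Move the unconditional prediction towards the text-conditioned one."""
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]

uncond = [0.0, 1.0]
text = [1.0, 1.0]

print(apply_guidance(uncond, text, 1.0))  # [1.0, 1.0]: the text prediction itself
print(apply_guidance(uncond, text, 7.5))  # [7.5, 1.0]: the difference is amplified
```

A scale of 1.0 reproduces the text-conditioned prediction, while larger values such as the 7.5 used in this notebook exaggerate whatever the prompt changes and leave the agreeing components untouched.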

## 10. Conclusion

In this notebook, we've demonstrated the CubeDiff model for generating high-quality 360° panoramas. The key advantages of this approach include:

1. **Reusing existing diffusion models**: By adapting pretrained text-to-image diffusion models, CubeDiff leverages the capabilities of existing models without training from scratch.

2. **Cubemap representation**: Using a cubemap instead of an equirectangular projection reduces distortion, since each face is an ordinary perspective image, and makes it easier to generate consistent panoramas.

3. **Specialized components for cubemap processing**:
   - Cubemap positional encoding for spatial understanding
   - Synchronized group normalization for consistent colors
   - Overlapping edge processor for seamless transitions
   - Inflated attention for cross-face relationships

4. **Flexible conditioning**: The model supports both text-to-panorama and image-to-panorama generation.
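
To illustrate why synchronized group normalization (item 3 above) helps with consistent colors, the sketch below compares normalizing each face on its own with normalizing all faces using shared statistics. It is a simplified illustration on plain Python lists with hypothetical function names, not the `GroupNormalizationSync` module from this repository.

```python
import math

def normalize_jointly(faces, eps=1e-5):
    """Normalize all faces with one shared mean and variance."""
    values = [x for face in faces for x in face]
    mean = sum(values) / len(values)
    var = sum((x - mean) ** 2 for x in values) / len(values)
    scale = 1.0 / math.sqrt(var + eps)
    return [[(x - mean) * scale for x in face] for face in faces]

def normalize_per_face(faces, eps=1e-5):
    """Normalize every face independently, with its own statistics."""
    return [normalize_jointly([face], eps)[0] for face in faces]

# Two toy "faces": a dark one and a bright one.
faces = [[0.0, 1.0], [10.0, 11.0]]

print(normalize_per_face(faces))  # both faces collapse to the same values
print(normalize_jointly(faces))   # the dark face stays below the bright face
```

With per-face statistics the brightness relationship between faces is lost, whereas shared statistics preserve it, which is the property that matters when the faces have to join into one seamless panorama.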

The implementation is structured in a modular way, making it easy to understand, modify, and extend.