# Appa: Unconditional Sample Generation (Clean Version)

This notebook demonstrates how to generate unconditional atmospheric samples using a trained Appa model.

Based on the [Appa documentation](https://github.com/montefiore-sail/appa/wiki/Generating-Unconditional-Samples), this notebook follows the blanket mechanism for generating prior trajectories.

## Setup and Imports


In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from omegaconf import OmegaConf
from einops import rearrange

# Add the appa module to the path
sys.path.append('/home/azureuser/cloudfiles/code/Users/appa_tio')

import appa
from appa.diffusion import create_denoiser, create_schedule
from appa.sampling import DDIMSampler, PCSampler, DDPMSampler, LMSSampler, RewindDDIMSampler, select_sampler
from appa.save import load_auto_encoder, load_denoiser
from appa.data.datasets import LatentBlanketDataset
from appa.date import create_trajectory_timestamps
from appa.grid import create_icosphere

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


## Configuration

Configure the generation parameters. Update the paths to match your downloaded model directories.


In [None]:
# Configuration following the official wiki specifications
config = {
    'ae_model_path': '/home/azureuser/cloudfiles/code/Users/randy.chase/appa_models/autoencoders/workshop/0/latents/workshop/ae',  # Autoencoder path
    'denoiser_model_path': '/home/azureuser/cloudfiles/code/Users/randy.chase/appa_models/autoencoders/workshop/0/latents/workshop/denoisers/workshop/0',  # Denoiser path
    'model_target': 'best',  # Options: 'best', 'last'
    'diffusion': {
        'num_steps': 64,  # Number of denoising steps passed to the sampler
        'sampler': {
            'type': 'lms',  # Options: 'pc', 'ddpm', 'ddim', 'rewind', 'lms'
            'config': {}  # Extra keyword arguments forwarded to the sampler constructor
        }
    },
    'trajectory_sizes': [72],  # Size of trajectory in hours (unpadded)
    'num_samples_per_date': 2,  # Number of samples to generate for each start date
    'start_dates': [
        "2000-04-03 0h",
        "2000-04-20 12h"
    ],
    'blanket_overlap': 4,  # Overlap between blankets (following wiki guidance)
    'precision': 'float16'  # Options: 'float32', 'float16', 'bfloat16'
}

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")


def report_model_path(label, path):
    """Print whether ``path`` is a usable model directory and list its contents.

    Args:
        label: Human-readable name used in the messages (e.g. "Autoencoder").
        path: Filesystem path to inspect.

    Returns:
        True if ``path`` is an existing directory, False otherwise.
    """
    print(f"{label} path: {path}")
    # isdir (not exists) guards the listdir call: a path that points at a
    # regular file would otherwise raise NotADirectoryError.
    if os.path.isdir(path):
        print(f"✓ {label} path exists")
        print("Contents:")
        # Sorted so the listing is stable from run to run.
        for item in sorted(os.listdir(path)):
            print(f"  - {item}")
        return True
    if os.path.exists(path):
        print(f"✗ {label} path exists but is not a directory")
    else:
        print(f"✗ {label} path does not exist")
    return False


# Check if both model paths exist
ae_path = config['ae_model_path']
denoiser_path = config['denoiser_model_path']

print("\nChecking model paths:")
report_model_path("Autoencoder", ae_path)
print()
report_model_path("Denoiser", denoiser_path)


## Load Models

Load the trained autoencoder and denoiser models with smart detection of model file naming conventions.


In [None]:
# Smart model loading that handles different naming conventions.
# Path, OmegaConf, create_denoiser, create_schedule, load_auto_encoder and
# load_denoiser all come from the import cell at the top of the notebook.

print("Loading autoencoder...")
# For Hugging Face models, use "model" not "model_best"
ae_model = load_auto_encoder(
    path=Path(config['ae_model_path']),
    model_name="model",  # This should be "model" for Hugging Face
    device=device,
    eval_mode=True
)
print(f"Autoencoder loaded successfully")

# Get latent channels from the loaded model
print("Getting latent dimensions from loaded model...")
latent_shape = ae_model.latent_shape
print(f"Model latent shape: {latent_shape}")

# Extract latent channels from the shape (the channel axis is last in both layouts)
if len(latent_shape) == 3:  # ConvAE: (h, w, channels)
    latent_channels = latent_shape[2]
elif len(latent_shape) == 2:  # GraphAE: (nodes, channels)
    latent_channels = latent_shape[1]
else:
    raise ValueError(f"Unexpected latent shape: {latent_shape}")

print(f"Latent channels: {latent_channels}")

print("Loading denoiser...")
# For denoiser, check if we have model_best.pth or model.pth
denoiser_path = Path(config['denoiser_model_path'])
model_best_path = denoiser_path / "model_best.pth"
model_path = denoiser_path / "model.pth"

# Fail early, before touching config.yaml, when neither checkpoint layout is present.
if not (model_best_path.exists() or model_path.exists()):
    raise FileNotFoundError(f"No model file found in {denoiser_path}")

# The training config is needed both for the manual loading branch and for the
# noise schedule below, so it is read exactly once.
with open(denoiser_path / "config.yaml", "r") as f:
    denoiser_cfg = OmegaConf.load(f)

if model_best_path.exists():
    print("Found model_best.pth, using standard load_denoiser")
    best = config['model_target'] == 'best'
    denoiser = load_denoiser(
        path=denoiser_path,
        best=best,
        device=device
    )
else:
    print("Found model.pth, loading manually")
    # NOTE(review): the same config object is passed for both positional
    # arguments -- confirm this against create_denoiser's signature.
    denoiser = create_denoiser(denoiser_cfg, denoiser_cfg, device=device)

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True if your PyTorch version
    # supports it).
    checkpoint = torch.load(model_path, map_location=device)
    denoiser.backbone.load_state_dict(checkpoint)
    denoiser.backbone.eval()

print(f"Denoiser loaded successfully")

# Create noise schedule
schedule = create_schedule(denoiser_cfg.train, device=device)
print(f"Noise schedule: {denoiser_cfg.train.noise_schedule}")

# Handle precision following the wiki guidance (generate.py lines 207-208, 226-227)
# Validate first: an arbitrary string that happens to name some other torch
# attribute would otherwise be accepted and silently fall back to default precision.
supported_precisions = ('float32', 'float16', 'bfloat16')
if config['precision'] not in supported_precisions:
    raise ValueError(
        f"Unsupported precision {config['precision']!r}; "
        f"expected one of {supported_precisions}"
    )
precision = getattr(torch, config['precision'])
use_bfloat16 = precision == torch.bfloat16

if use_bfloat16:
    # Process-wide setting: every floating-point tensor created afterwards
    # without an explicit dtype will be bfloat16.
    torch.set_default_dtype(torch.bfloat16)
    print(f"Set default dtype to {torch.get_default_dtype()}")

print(f"Using precision: {config['precision']}")
print(f"Use bfloat16: {use_bfloat16}")


## Setup Sampling

Instantiate the sampler selected in the configuration above.


In [None]:
# Pull the sampler settings out of the notebook configuration
diffusion_cfg = config['diffusion']
sampler_type = diffusion_cfg['sampler']['type']
sampler_config = diffusion_cfg['sampler']['config']
num_steps = diffusion_cfg['num_steps']

# Resolve the sampler class that select_sampler associates with this name
SamplerClass = select_sampler(sampler_type)

# Instantiate it around the loaded denoiser and noise schedule; any extra
# options from the configuration are forwarded as keyword arguments.
base_sampler_kwargs = dict(denoiser=denoiser, schedule=schedule, steps=num_steps)
sampler = SamplerClass(**base_sampler_kwargs, **sampler_config)

for summary_line in (
    f"Sampler created: {sampler_type}",
    f"Number of steps: {num_steps}",
    f"Sampler class: {SamplerClass.__name__}",
):
    print(summary_line)


## Generate Unconditional Samples

Generate unconditional atmospheric samples for the specified dates and trajectory sizes.


In [None]:
# Turn each "YYYY-MM-DD Nh" string from the configuration into a (day, hour) pair
start_dates = []
for raw_start in config['start_dates']:
    # e.g. "2000-04-03 0h" -> ("2000-04-03", 0)
    day_text, hour_text = raw_start.split()
    start_dates.append((day_text, int(hour_text.replace('h', ''))))

for summary_line in (
    f"Start dates: {start_dates}",
    f"Trajectory sizes: {config['trajectory_sizes']}",
    f"Samples per date: {config['num_samples_per_date']}",
):
    print(summary_line)


In [None]:
# Generate samples for each trajectory size and start date
# Following the official generate.py pattern with proper trajectory padding

# Blanket geometry does not depend on the trajectory size or the date, so it is
# computed and validated once, up front.
blanket_size = denoiser_cfg.train.blanket_size
blanket_stride = blanket_size - config['blanket_overlap']
if blanket_stride <= 0:
    raise ValueError(
        f"blanket_overlap ({config['blanket_overlap']}) must be smaller than "
        f"the blanket size ({blanket_size})"
    )


# Sampling is pure inference here; disabling autograd avoids keeping graphs
# alive across the denoising steps (harmless if the sampler already does so).
@torch.no_grad()
def _generate_sample(timestamps, padded_trajectory_size, unpadded_trajectory_size):
    """Draw one batch of latent trajectories for the given timestamps.

    Reads ``ae_model``, ``denoiser``, ``schedule``, ``SamplerClass``,
    ``sampler_config``, ``num_steps``, ``use_bfloat16``, ``device``,
    ``blanket_size`` and ``blanket_stride`` from the notebook namespace.

    Args:
        timestamps: Trajectory timestamps with a leading batch axis; moved to
            ``device`` and passed to the denoiser as ``date``.
        padded_trajectory_size: Number of steps actually sampled.
        unpadded_trajectory_size: Number of leading steps kept in the result.

    Returns:
        The sampler output reshaped to ``(-1, padded_trajectory_size, *latent_shape)``
        and trimmed to the first ``unpadded_trajectory_size`` steps.
    """
    from appa.diffusion import TrajectoryDenoiser, Denoiser
    from functools import partial
    import math

    # Get latent shape and state size (following generate.py lines 221-224);
    # a 3-entry shape has its first two axes merged into one.
    latent_shape = ae_model.latent_shape
    if len(latent_shape) == 3:
        latent_shape = latent_shape[0] * latent_shape[1], latent_shape[2]
    state_size = math.prod(latent_shape)

    # Create Denoiser wrapper first (following generate.py lines 251-253).
    # Uses the notebook's `device` rather than a hard-coded .cuda() so the
    # CPU fallback selected in the setup cell is honoured.
    denoise = Denoiser(denoiser.backbone).to(device)
    if use_bfloat16:
        denoise = denoise.to(torch.bfloat16)

    # Create TrajectoryDenoiser (following generate.py lines 255-262)
    trajectory_denoiser = TrajectoryDenoiser(
        denoise,
        blanket_size=blanket_size,
        blanket_stride=blanket_stride,
        state_size=state_size,
        distributed=False,  # Single GPU for notebook (vs True in distributed version)
        pass_blanket_ids=False,
    )

    # Bind date parameter using partial (following generate.py line 264)
    conditioned_denoiser = partial(trajectory_denoiser, date=timestamps.to(device))

    # Create sampler with conditioned denoiser (following generate.py lines 268-276)
    trajectory_sampler = SamplerClass(
        denoiser=conditioned_denoiser,
        schedule=schedule,
        steps=num_steps,
        silent=False,
        **sampler_config
    )

    # Generate random noise and scale by max sigma (following generate.py lines 277-279)
    x1 = torch.randn(len(timestamps), padded_trajectory_size * state_size, device=device)
    samp_start = (x1 * schedule.sigma_tmax().to(device)).flatten(1)

    # Sample, then restore the (batch, time, *latent_shape) layout
    sample = trajectory_sampler(samp_start).reshape((-1, padded_trajectory_size, *latent_shape))

    # CRITICAL: Trim to unpadded trajectory size (following generate.py line 292)
    return sample[:, :unpadded_trajectory_size]


all_samples = {}

for unpadded_trajectory_size in config['trajectory_sizes']:
    print(f"\nGenerating samples for trajectory size: {unpadded_trajectory_size}h")

    # CRITICAL: trajectory padding logic from generate.py lines 64-68.
    # Grow the trajectory to the smallest length >= max(blanket_size, requested)
    # for which (length - blanket_size) is a multiple of the stride.
    padded_trajectory_size = max(blanket_size, unpadded_trajectory_size)
    remainder = (padded_trajectory_size - blanket_size) % blanket_stride
    if remainder:
        padded_trajectory_size += blanket_stride - remainder

    print(f"  Unpadded size: {unpadded_trajectory_size}h")
    print(f"  Padded size: {padded_trajectory_size}h")
    print(f"  Blanket size: {blanket_size}, stride: {blanket_stride}")

    trajectory_samples = {}

    for date_str, hour in start_dates:
        print(f"  Date: {date_str} {hour:02d}h")

        # Create timestamps for the PADDED trajectory (following generate.py line 247-249)
        timestamps = create_trajectory_timestamps(
            start_day=date_str,
            start_hour=hour,
            traj_size=padded_trajectory_size,
            dt=1  # 1 hour timestep
        )[None]  # Add batch dimension like in generate.py

        # Generate multiple samples for this date
        date_samples = []

        for sample_idx in range(config['num_samples_per_date']):
            print(f"    Sample {sample_idx + 1}/{config['num_samples_per_date']}")

            # Handle precision as in generate.py lines 284-288: only float16 runs
            # under autocast, on whichever device type the notebook selected.
            if precision == torch.float16:
                with torch.autocast(device_type=device.type, dtype=torch.float16):
                    sample = _generate_sample(
                        timestamps, padded_trajectory_size, unpadded_trajectory_size
                    )
            else:
                sample = _generate_sample(
                    timestamps, padded_trajectory_size, unpadded_trajectory_size
                )

            date_samples.append(sample.cpu())

        trajectory_samples[f"{date_str}_{hour:02d}h"] = torch.cat(date_samples, dim=0)

    all_samples[f"{unpadded_trajectory_size}h"] = trajectory_samples

print("\nGeneration completed!")


## Decode Samples to Physical Space

Decode the generated latent samples back to physical atmospheric variables.


In [None]:
# Map every latent trajectory back to physical space, keeping the same
# {trajectory size -> {date key -> tensor}} nesting as all_samples.
decoded_samples = {}

with torch.no_grad():  # decoding is inference only
    for trajectory_size, trajectory_data in all_samples.items():
        print(f"Decoding samples for {trajectory_size}...")

        decoded_trajectory = {}

        for date_key, samples in trajectory_data.items():
            print(f"  Decoding {date_key}...")

            # Latents were stored on the CPU; bring them back to the model's device
            latents = samples.to(device)

            # Decode one sample at a time (batch of size 1) and park each result
            # on the CPU before stitching the batch back together.
            decoded_batch = [
                ae_model.decode(latents[i:i+1]).cpu()
                for i in range(latents.shape[0])
            ]

            decoded_trajectory[date_key] = torch.cat(decoded_batch, dim=0)

        decoded_samples[trajectory_size] = decoded_trajectory

print("Decoding completed!")


## Save Results

Save the generated samples for later analysis.


In [None]:
# Create output directory (relative to the notebook's working directory);
# parents=True keeps this working if the path is later changed to a nested one.
output_dir = Path("generated_samples")
output_dir.mkdir(parents=True, exist_ok=True)

# Save latent samples
latent_path = output_dir / "latent_samples.pt"
torch.save(all_samples, latent_path)
print(f"Latent samples saved to: {latent_path}")

# Save decoded samples
decoded_path = output_dir / "decoded_samples.pt"
torch.save(decoded_samples, decoded_path)
print(f"Decoded samples saved to: {decoded_path}")

# Save configuration. `config` is a plain dict, so it is wrapped in an
# OmegaConf container first: OmegaConf.save is documented for config objects,
# and older OmegaConf releases reject a raw dict.
config_path = output_dir / "generation_config.yaml"
OmegaConf.save(OmegaConf.create(config), config_path)
print(f"Configuration saved to: {config_path}")


## Visualization

Create some basic visualizations of the generated samples.


In [None]:
# Basic visualization of sample statistics: one panel per trajectory size,
# at most four panels.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

panels = list(decoded_samples.items())[:len(axes)]

for ax, (trajectory_size, trajectory_data) in zip(axes, panels):
    # Get all samples for this trajectory size
    all_trajectory_samples = torch.cat(list(trajectory_data.values()), dim=0)

    # Reduced-precision tensors are upcast before the statistics: NumPy has no
    # bfloat16 dtype, so .numpy() would raise on them.
    if all_trajectory_samples.dtype in (torch.float16, torch.bfloat16):
        all_trajectory_samples = all_trajectory_samples.float()

    # Mean and std across the first two axes (batch and time).
    # NOTE(review): the plotting below assumes the result is 1-D -- confirm
    # against the decoder's output layout.
    mean_values = all_trajectory_samples.mean(dim=(0, 1))
    std_values = all_trajectory_samples.std(dim=(0, 1))

    ax.plot(mean_values.numpy(), label='Mean', alpha=0.7)
    ax.fill_between(range(len(mean_values)),
                    (mean_values - std_values).numpy(),
                    (mean_values + std_values).numpy(),
                    alpha=0.3, label='±1σ')

    ax.set_title(f'Trajectory {trajectory_size} - Sample Statistics')
    ax.set_xlabel('Variable Index')
    ax.set_ylabel('Value')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Hide panels that received no data so the figure has no empty frames
for unused_ax in axes[len(panels):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

# Print sample information
print("\nGenerated Sample Summary:")
for trajectory_size, trajectory_data in decoded_samples.items():
    print(f"\nTrajectory {trajectory_size}:")
    for date_key, samples in trajectory_data.items():
        print(f"  {date_key}: {samples.shape} samples")
        print(f"    Mean: {samples.mean().item():.4f}")
        print(f"    Std: {samples.std().item():.4f}")
        print(f"    Min: {samples.min().item():.4f}")
        print(f"    Max: {samples.max().item():.4f}")


## Next Steps

The generated samples can now be used for:

1. **Rendering**: Use the rendering scripts to visualize the atmospheric states
2. **Evaluation**: Compare against ground truth data for validation
3. **Conditional Generation**: Use these as starting points for data assimilation
4. **Forecasting**: Use as initial conditions for weather forecasting

For more advanced usage, see the [Appa documentation](https://github.com/montefiore-sail/appa/wiki) for:
- [Forecasting](https://github.com/montefiore-sail/appa/wiki/Forecasting)
- [Reanalysis and Filtering](https://github.com/montefiore-sail/appa/wiki/Reanalysis-and-Filtering)
- [Rendering States](https://github.com/montefiore-sail/appa/wiki/Rendering-States)
