# Appa: Unconditional Sample Generation

This notebook demonstrates how to generate unconditional atmospheric samples using a trained Appa model.

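It is based on the [Appa documentation](https://github.com/montefiore-sail/appa/wiki/Generating-Unconditional-Samples) for generating prior (unconditional) trajectories with the blanket mechanism. Note that the generation cell below is a simplified placeholder: it samples each trajectory in a single pass and does not yet split it into overlapping blankets.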

## Setup and Imports


In [None]:
import contextlib
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf

# Make the local appa checkout importable. Set the APPA_ROOT environment
# variable to point elsewhere instead of editing this hard-coded default.
APPA_ROOT = os.environ.get('APPA_ROOT', '/Users/randychase/Documents/PythonWorkspace/cbottle/appa_tio')
if APPA_ROOT not in sys.path:  # keep the cell idempotent on re-run
    sys.path.append(APPA_ROOT)

import appa
from appa.diffusion import create_denoiser, create_schedule
from appa.sampling import DDIMSampler, PCMSampler, DDPMSampler, LMSDiscreteSampler
from appa.save import load_auto_encoder, load_denoiser
from appa.data.datasets import LatentBlanketDataset
from appa.date import create_trajectory_timestamps
from appa.grid import create_icosphere

# Pick the compute device: the GPU when torch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


## Configuration

Configure the generation parameters. You'll need to update the `model_path` to point to your trained model.


In [None]:
# Configuration - UPDATE THESE PATHS TO YOUR MODEL
config = {
    'model_path': '/path/to/your/trained/model',  # Update this path
    'model_target': 'best',  # Options: 'best', 'last'
    'diffusion': {
        'num_steps': 64,  # Number of denoising steps
        'sampler': {
            # Options: 'pc', 'ddpm', 'ddim', 'lms' -- the only types handled by
            # the sampler-setup cell (any other value raises ValueError there).
            'type': 'lms',
            'config': {},  # Extra keyword arguments passed to the sampler constructor
        },
    },
    'trajectory_sizes': [72],  # Trajectory length(s) in hours (1-hour timestep)
    'num_samples_per_date': 2,  # Number of samples to generate per start date
    'start_dates': [  # Format: "YYYY-MM-DD Hh"
        "2000-04-03 0h",
        "2000-04-20 12h",
    ],
    # Overlap between blankets.
    # NOTE(review): not read by any cell in this notebook -- the generation
    # cell samples each trajectory in a single pass.
    'blanket_overlap': 4,
    'precision': 'float16',  # Options: 'float32', 'float16', 'bfloat16'
}

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")


## Load Models

Load the trained autoencoder and denoiser models.


In [None]:
# Fail fast if the placeholder model path in the Configuration cell was not
# updated; otherwise the error would only surface from inside the loaders.
if not Path(config['model_path']).exists():
    raise FileNotFoundError(
        f"Model path does not exist: {config['model_path']!r}. "
        "Update config['model_path'] in the Configuration cell."
    )

# Load the autoencoder, move it to `device` and switch it to evaluation mode
print("Loading autoencoder...")
ae_model, ae_cfg = load_auto_encoder(config['model_path'], target=config['model_target'])
ae_model = ae_model.to(device)
ae_model.eval()
print(f"Autoencoder loaded: {ae_cfg.ae.name}")

# Load the denoiser, move it to `device` and switch it to evaluation mode
print("Loading denoiser...")
denoiser, denoiser_cfg = load_denoiser(config['model_path'], target=config['model_target'])
denoiser = denoiser.to(device)
denoiser.eval()
print(f"Denoiser loaded: {denoiser_cfg.backbone.name}")

# Build the noise schedule from the denoiser's training configuration
schedule = create_schedule(denoiser_cfg.train, device=device)
print(f"Noise schedule: {denoiser_cfg.train.noise_schedule}")


## Setup Sampling

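Instantiate the sampler selected by `config['diffusion']['sampler']['type']`, passing it any extra keyword arguments from `config['diffusion']['sampler']['config']`.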


In [None]:
# Read the sampler settings, then build the selected sampler via a lookup
# table of supported types.
sampler_type = config['diffusion']['sampler']['type']
sampler_config = config['diffusion']['sampler']['config']
num_steps = config['diffusion']['num_steps']

sampler_classes = {
    'pc': PCMSampler,
    'ddpm': DDPMSampler,
    'ddim': DDIMSampler,
    'lms': LMSDiscreteSampler,
}
sampler_cls = sampler_classes.get(sampler_type)
if sampler_cls is None:
    raise ValueError(f"Unknown sampler type: {sampler_type}")
sampler = sampler_cls(denoiser, schedule, **sampler_config)

print(f"Sampler created: {sampler_type}")
print(f"Number of steps: {num_steps}")


## Generate Unconditional Samples

Generate unconditional atmospheric samples for the specified dates and trajectory sizes.


In [None]:
# Parse the configured start dates into (date, hour) tuples.
def parse_start_date(date_str):
    """Split a ``"YYYY-MM-DD Hh"`` string (e.g. ``"2000-04-03 0h"``) into ``(date_part, hour)``.

    The hour suffix is case-insensitive (``"12h"`` and ``"12H"`` are both accepted).

    Raises:
        ValueError: if the string does not have exactly two whitespace-separated
            parts, the second part is not ``<int>h``, or the hour is outside 0-23.
    """
    parts = date_str.split()
    if len(parts) != 2 or not parts[1].lower().endswith('h'):
        raise ValueError(f"Expected a start date like '2000-04-03 0h', got {date_str!r}")
    date_part, hour_part = parts
    # Strip only the trailing 'h' so inputs like "1h2" are rejected, not read as 12
    hour = int(hour_part[:-1])
    if not 0 <= hour <= 23:
        raise ValueError(f"Hour must be in 0-23, got {hour} in {date_str!r}")
    return date_part, hour


start_dates = [parse_start_date(entry) for entry in config['start_dates']]

print(f"Start dates: {start_dates}")
print(f"Trajectory sizes: {config['trajectory_sizes']}")
print(f"Samples per date: {config['num_samples_per_date']}")


In [None]:
# Generate samples for each trajectory size and start date.
# NOTE(review): each trajectory is sampled in a single pass; the blanket
# mechanism (config['blanket_overlap']) is not implemented by this cell.

# Seed torch's global RNG (CPU and CUDA) so the initial noise -- and any
# stochastic sampler steps drawing from the global RNG -- are reproducible.
seed = config.get('seed', 0)  # add a 'seed' entry to `config` to override
torch.manual_seed(seed)

# Honour config['precision'] by running the sampler under CUDA autocast.
# NOTE(review): assumes 'precision' is meant as the autocast dtype for the
# denoiser -- confirm against Appa's own generation script. On CPU, or for
# 'float32', sampling runs in the model's native dtype.
precision_dtypes = {
    'float32': torch.float32,
    'float16': torch.float16,
    'bfloat16': torch.bfloat16,
}
if config['precision'] not in precision_dtypes:
    raise ValueError(
        f"Unknown precision: {config['precision']!r}; "
        f"expected one of {sorted(precision_dtypes)}"
    )
amp_dtype = precision_dtypes[config['precision']]
use_autocast = device.type == 'cuda' and amp_dtype != torch.float32

# Placeholder noise layout (batch, time, latent channels), identical for every sample.
# NOTE(review): adjust to the latent shape your denoiser actually expects.
batch_size = 1
latent_dim = ae_cfg.ae.latent_channels

all_samples = {}

for trajectory_size in config['trajectory_sizes']:
    print(f"\nGenerating samples for trajectory size: {trajectory_size}h")

    trajectory_samples = {}

    for date_str, hour in start_dates:
        print(f"  Date: {date_str} {hour:02d}h")

        # Timestamps for the trajectory, one per hour
        timestamps = create_trajectory_timestamps(
            start_date=date_str,
            start_hour=hour,
            trajectory_size=trajectory_size,
            dt=1,  # 1 hour timestep
        )

        date_samples = []
        for sample_idx in range(config['num_samples_per_date']):
            print(f"    Sample {sample_idx + 1}/{config['num_samples_per_date']}")

            # Random noise as the starting point of the reverse diffusion
            z = torch.randn(batch_size, trajectory_size, latent_dim, device=device)

            precision_ctx = (
                torch.autocast(device_type='cuda', dtype=amp_dtype)
                if use_autocast
                else contextlib.nullcontext()
            )
            with torch.no_grad(), precision_ctx:
                sample = sampler.sample(
                    z,
                    num_steps=num_steps,
                    timestamps=timestamps,
                )

            # Store as float32 on the CPU so downstream decoding and NumPy
            # conversion do not inherit a reduced autocast dtype.
            date_samples.append(sample.float().cpu())

        trajectory_samples[f"{date_str}_{hour:02d}h"] = torch.cat(date_samples, dim=0)

    all_samples[f"{trajectory_size}h"] = trajectory_samples

print("\nGeneration completed!")


## Decode Samples to Physical Space

Decode the generated latent samples back to physical atmospheric variables.


In [None]:
# Decode every generated latent trajectory back to physical space.
# Each stored batch is decoded one sample at a time (slices of size 1);
# the decoded tensors are collected on the CPU and re-stacked per date.
decoded_samples = {}

for size_key, latents_by_date in all_samples.items():
    print(f"Decoding samples for {size_key}...")

    decoded_by_date = {}
    for date_key, latents in latents_by_date.items():
        print(f"  Decoding {date_key}...")
        latents = latents.to(device)
        with torch.no_grad():
            pieces = [
                ae_model.decode(latents[j:j + 1]).cpu()
                for j in range(latents.shape[0])
            ]
        decoded_by_date[date_key] = torch.cat(pieces, dim=0)

    decoded_samples[size_key] = decoded_by_date

print("Decoding completed!")


## Save Results

Save the generated samples for later analysis.


In [None]:
# Write latent samples, decoded samples and the run configuration into a
# local "generated_samples" directory (created if missing).
output_dir = Path("generated_samples")
output_dir.mkdir(exist_ok=True)

latent_path = output_dir / "latent_samples.pt"
decoded_path = output_dir / "decoded_samples.pt"
config_path = output_dir / "generation_config.yaml"

for label, payload, target in (
    ("Latent samples", all_samples, latent_path),
    ("Decoded samples", decoded_samples, decoded_path),
):
    torch.save(payload, target)
    print(f"{label} saved to: {target}")

OmegaConf.save(config, config_path)
print(f"Configuration saved to: {config_path}")


## Visualization

Create some basic visualizations of the generated samples.


In [None]:
# Plot the per-variable mean ± 1 std of the decoded samples: one panel per
# trajectory size, up to the four panels of a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

if len(decoded_samples) > len(axes):
    print(f"Plotting only the first {len(axes)} of {len(decoded_samples)} trajectory sizes.")

for ax, (trajectory_size, trajectory_data) in zip(axes, decoded_samples.items()):
    # Stack all dates' samples along the batch axis; cast to float32 because
    # bfloat16 tensors cannot be converted to NumPy.
    all_trajectory_samples = torch.cat(list(trajectory_data.values()), dim=0).float()

    # Reduce over every axis except axis 2 (variables). For a 3-D tensor this
    # is exactly mean/std over dims (0, 1), i.e. across batch and time.
    # NOTE(review): assumes decoded samples are laid out (batch, time, variable, ...)
    # -- confirm against the decoder output.
    n_vars = all_trajectory_samples.shape[2]
    per_variable = all_trajectory_samples.movedim(2, 0).reshape(n_vars, -1)
    mean_values = per_variable.mean(dim=1)
    std_values = per_variable.std(dim=1)

    var_index = np.arange(n_vars)
    ax.plot(var_index, mean_values.numpy(), label='Mean', alpha=0.7)
    ax.fill_between(var_index,
                    (mean_values - std_values).numpy(),
                    (mean_values + std_values).numpy(),
                    alpha=0.3, label='±1σ')

    ax.set_title(f'Trajectory {trajectory_size} - Sample Statistics')
    ax.set_xlabel('Variable Index')
    ax.set_ylabel('Value')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Hide grid cells that received no trajectory instead of drawing empty axes.
for unused_ax in axes[len(decoded_samples):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

# Print the shape and global statistics of every decoded sample set.
print("\nGenerated Sample Summary:")
for size_key, samples_by_date in decoded_samples.items():
    print(f"\nTrajectory {size_key}:")
    for date_key, samples in samples_by_date.items():
        print(f"  {date_key}: {samples.shape} samples")
        for label, reduce in (
            ("Mean", torch.Tensor.mean),
            ("Std", torch.Tensor.std),
            ("Min", torch.Tensor.min),
            ("Max", torch.Tensor.max),
        ):
            print(f"    {label}: {reduce(samples).item():.4f}")


## Next Steps

The generated samples can now be used for:

1. **Rendering**: Use the rendering scripts to visualize the atmospheric states
2. **Evaluation**: Compare against ground truth data for validation
3. **Conditional Generation**: Use these as starting points for data assimilation
4. **Forecasting**: Use as initial conditions for weather forecasting

For more advanced usage, see the [Appa documentation](https://github.com/montefiore-sail/appa/wiki) for:
- [Forecasting](https://github.com/montefiore-sail/appa/wiki/Forecasting)
- [Reanalysis and Filtering](https://github.com/montefiore-sail/appa/wiki/Reanalysis-and-Filtering)
- [Rendering States](https://github.com/montefiore-sail/appa/wiki/Rendering-States)
