# MRE-PINN Training on Lightning AI GPU

This notebook is configured to run on Lightning AI with GPU support.

## Setup Instructions:
1. Upload the entire MRE-PINN folder to Lightning AI Studios
2. Select a GPU runtime (e.g., T4, A10G)
3. The MRE-PINN conda environment should have all dependencies installed
4. The notebook will automatically use GPU if available

## Setup Environment

> **Note:** Make sure you've uploaded the entire MRE-PINN repository folder to Lightning AI.

In [None]:
# Setup environment and imports
import sys
import os
import pathlib
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Locate the repository root so the mre_pinn module can be imported
def find_repo_root(start, package_name='mre_pinn', notebook_folder='lightning-ai-training'):
    """Return the directory that contains the `package_name` package.

    The search starts at `start` and walks up through its parents, returning
    the first directory that holds `<package_name>/__init__.py`. This works
    no matter which subdirectory the notebook was launched from.

    If no such directory is found, the original fixed-layout rule is used:
    the parent of `start` when `start` is named `notebook_folder`, otherwise
    `start` itself.

    Parameters
    ----------
    start : str or pathlib.Path
        Directory to start searching from.
    package_name : str
        Name of the package directory that marks the repository root.
    notebook_folder : str
        Folder name used by the fallback rule.

    Returns
    -------
    pathlib.Path
        Absolute path of the detected repository root.
    """
    start = pathlib.Path(start).resolve()
    for candidate in (start, *start.parents):
        if (candidate / package_name / '__init__.py').is_file():
            return candidate
    # Fallback: same rule the notebook used before the upward search was added
    return start.parent if start.name == notebook_folder else start


notebook_dir = pathlib.Path.cwd()
repo_root = find_repo_root(notebook_dir)

# Re-running this cell must not stack duplicate entries on sys.path
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# Configure DeepXDE backend
# Keep this assignment above `import deepxde`: the variable is set before the
# import on purpose. NOTE(review): DeepXDE presumably picks its backend at
# import time, and newer releases appear to prefer the `DDE_BACKEND` spelling
# while still accepting `DDEBACKEND` — confirm against the installed version.
os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

# Import MRE-PINN
# Presumably resolved through the repo_root entry placed on sys.path above
# (unless the package is also installed in the environment).
import mre_pinn

# Pick the compute device once and report what was found
has_gpu = torch.cuda.is_available()
device = 'cuda' if has_gpu else 'cpu'
print(f"Using device: {device}")

if not has_gpu:
    print("Warning: No GPU detected. Training will be slow on CPU.")
else:
    # Device 0 is the one reported; total_memory is in bytes, shown in GiB
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available GPU memory: {gpu_props.total_memory / 2**30:.2f} GiB")

# Show where the repository and the imported package live, as a sanity check
module_location = mre_pinn.__file__
print(f"\nRepository root: {repo_root}")
print(f"MRE-PINN module loaded from: {module_location}")

## Download and Preprocess Data

Download the BIOQIC simulation dataset and convert it to xarray format.

**Note:** The download is skipped automatically when the data already exists (the downloader checks for existing files first), so this step is safe to re-run.

# Data layout, all anchored at the repository root
data_dir = repo_root / 'data' / 'BIOQIC'
download_dir, processed_dir = data_dir / 'downloads', data_dir / 'fem_box'

# Make sure both target folders exist before anything is written to them
for directory in (download_dir, processed_dir):
    directory.mkdir(parents=True, exist_ok=True)

print(f"Data directory: {data_dir}")
print(f"Download directory: {download_dir}")
print(f"Processed directory: {processed_dir}")

# Download -> load -> preprocess -> convert -> save, in that order
bioqic = mre_pinn.data.BIOQICFEMBox(str(download_dir))
bioqic.download()  # Will skip if already exists
bioqic.load_mat()
bioqic.preprocess()

dataset = bioqic.to_dataset()
dataset.save_xarrays(str(processed_dir))

In [None]:
# Create data directories
# Anchor the data folder at repo_root (defined in the setup cell) instead of a
# path relative to the current working directory. A relative 'data/BIOQIC'
# lands in a different place depending on where the notebook was launched,
# which disagrees with the repo_root-based layout used elsewhere in the notebook.
data_dir = repo_root / 'data' / 'BIOQIC'
download_dir = data_dir / 'downloads'
processed_dir = data_dir / 'fem_box'

download_dir.mkdir(parents=True, exist_ok=True)
processed_dir.mkdir(parents=True, exist_ok=True)

# Download and process data
# Pipeline order: fetch the raw files, load them, preprocess, convert to a
# dataset, then write the xarray files that the loading cell reads back.
bioqic = mre_pinn.data.BIOQICFEMBox(str(download_dir))
bioqic.download()
bioqic.load_mat()
bioqic.preprocess()
dataset = bioqic.to_dataset()
dataset.save_xarrays(str(processed_dir))

## Load and Visualize Data

In [None]:
# Load the example for a single excitation frequency
frequency = 90  # Hz
xarray_dir = str(processed_dir)
example = mre_pinn.data.MREExample.load_xarrays(xarray_dir, frequency)

# Print the metadata first, then the descriptive statistics
print("\nMetadata:")
metadata = example.metadata
print(metadata)

print("\nDescriptive Statistics:")
summary = example.describe()
print(summary)

In [None]:
# Visualize wave field (static plot for Lightning AI)
# Displays the 'wave' variable of the example loaded above.
# NOTE(review): ax_height=3 presumably sets the height of each plot axes —
# confirm against MREExample.view.
example.view('wave', ax_height=3)

## Evaluate Baseline Methods

In [None]:
# Run the AHI baseline on the example at the chosen frequency
mre_pinn.baseline.eval_ahi_baseline(example, frequency=frequency)

# Then show the 'mre' and 'direct' variables with the same display options
baseline_view_opts = dict(ax_height=3, polar=True, vmax=20e3)
example.view('mre', 'direct', **baseline_view_opts)

## Configure PINN Model

In [None]:
# Construct PDE
# Builds the wave-equation object registered under the name 'hetero'.
# NOTE(review): `omega` receives `frequency`, which is given in Hz above;
# confirm whether WaveEquation expects Hz or angular frequency here.
pde = mre_pinn.pde.WaveEquation.from_name('hetero', omega=frequency)

In [None]:
# Create PINN architecture
# Hyperparameters are gathered in one place so they are easy to tune
pinn_kwargs = dict(
    omega=frequency,
    n_layers=5,
    n_hidden=128,
    activ_fn='ss',  # sin activation
    polar_input=False,
)
pinn = mre_pinn.model.MREPINN(example, **pinn_kwargs)

# Print the network so its structure is visible in the cell output
print(pinn)

## Compile Model with GPU Support

In [None]:
# Configure training model with GPU device
# PDE-weight schedule settings, grouped separately from the other arguments
pde_schedule = dict(
    pde_warmup_iters=1000,
    pde_step_iters=500,
    pde_step_factor=10,
    pde_init_weight=1e-10,
)
model = mre_pinn.training.MREPINNModel(
    example, pinn, pde,
    loss_weights=[1, 0, 0, 1e-8],
    n_points=1024,
    device=device,  # Automatically use GPU if available
    **pde_schedule,
)

# Adam optimizer with the msae loss from mre_pinn.training.losses
loss_fn = mre_pinn.training.losses.msae_loss
model.compile(optimizer='adam', lr=1e-4, loss=loss_fn)

## Benchmark Performance

In [None]:
# Benchmark model performance
# Set the cuDNN flags first so they are in effect for the benchmark call
cudnn = torch.backends.cudnn
cudnn.benchmark = True  # Enable cudnn benchmarking for better GPU performance
cudnn.enabled = True

n_benchmark_iters = 10
model.benchmark(n_benchmark_iters)

## Setup Test Evaluator

In [None]:
# Create test evaluator for periodic evaluation
evaluator_options = dict(
    test_every=100,
    save_every=1000,
    save_prefix='LIGHTNING_AI',
    interact=False,  # Disable interactive mode for Lightning AI
)
test_eval = mre_pinn.testing.TestEvaluator(**evaluator_options)

# Attach the model by hand and run one evaluation before training starts
test_eval.model = model
test_eval.test()

## Train Model on GPU

This will train the model for 10,000 iterations, using GPU acceleration when a GPU is available.

In [None]:
%%time
# Reset GPU memory stats if available
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

# Train the model
print(f"Starting training on {device}...")
model.train(10000, display_every=100, callbacks=[test_eval])

# Display GPU memory usage
if torch.cuda.is_available():
    peak_memory = torch.cuda.max_memory_allocated() / 2**30
    print(f"\nPeak GPU memory: {peak_memory:.2f} GiB")
    print(f"GPU utilization was efficient!")
else:
    print("\nWarning: Training completed on CPU (no GPU detected)")

## Evaluate Final Results

In [None]:
# Final evaluation
# Same test_eval.test() call that was issued before training; it now runs
# against the model after the training cell has finished.
test_eval.test()

In [None]:
# Print the heading and the evaluator's metrics on separate lines
print("Training Metrics:", test_eval.metrics, sep="\n")

## Save Model Checkpoint

In [None]:
# Save trained model
# NOTE(review): this reads `model.pinn` and `model.optimizer`. The stock
# deepxde.Model appears to expose these as `net` and `opt` — confirm that
# MREPINNModel really defines the attribute names used here.
checkpoint_path = pathlib.Path('checkpoints')
checkpoint_path.mkdir(exist_ok=True)
checkpoint_file = checkpoint_path / 'mre_pinn_lightning_ai.pth'

# Bundle weights, optimizer state, and the run settings needed to reload them
checkpoint = {
    'model_state_dict': model.pinn.state_dict(),
    'optimizer_state_dict': model.optimizer.state_dict(),
    'frequency': frequency,
    'device': device,
}
torch.save(checkpoint, checkpoint_file)

print(f"Model checkpoint saved to: {checkpoint_file}")

## Download Results

Download the trained model and results back to your local machine from Lightning AI.

In [None]:
# Create a zip file of all results
# Bundles the checkpoints/ directory plus files in the working directory that
# match the LIGHTNING_AI_*.png and LIGHTNING_AI_*.pkl patterns.
# NOTE(review): the patterns presumably target files written by the test
# evaluator (save_prefix='LIGHTNING_AI'); confirm it emits .png/.pkl files.
# The message below is printed unconditionally, so it also appears when zip
# reports a problem — check the zip output above it.
!zip -r lightning_ai_results.zip checkpoints/ LIGHTNING_AI_*.png LIGHTNING_AI_*.pkl
print("Results packaged. Download 'lightning_ai_results.zip' from Lightning AI file browser.")