In [None]:
# Quick Demo toggle: flip to True for a fast, CPU-friendly run with tiny batches
quick_demo = False  # change to True to run a tiny demo
demo_note = f"Quick Demo is {'ON' if quick_demo else 'OFF'}"
print(demo_note)

# SpectraNet Redshift — Example Walkthrough

This notebook is a beginner-friendly example to understand and run SpectraNet for spectroscopic redshift regression.

What you'll do:
- Set up environment and imports

- Configure dataset paths and training hyperparameters

- Peek at a few samples to understand the input format

- Build and summarize the model

- Train/evaluate (optionally with a quick demo mode)

- Run a small inference demo and visualize predictions vs ground truth

> Tip: If you're new or running on CPU, set `quick_demo = True` in the first code cell to use smaller demo batches.

In [None]:
# Setup and imports (GPU-ready and repo-aware)
from IPython.core.interactiveshell import InteractiveShell

# Display the value of EVERY bare expression in a cell, not just the last one.
# Side effect: calls like `model.eval()` echo their full return value, which is why
# results such calls return are assigned to `_` in later cells.
InteractiveShell.ast_node_interactivity = "all"
max_length = 1_000_000
# Raise the plain-text formatter's collection-truncation limit so long sequences print in full.
InteractiveShell.instance().display_formatter.formatters['text/plain'].max_seq_length = int(max_length)
import warnings
from sklearn.exceptions import UndefinedMetricWarning

# Silence sklearn's UndefinedMetricWarning globally for this kernel
# (presumably emitted by metric code inside the training utilities — TODO confirm).
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

# NOTE(review): deque, pd and F are not used by the visible cells; kept for compatibility.
from collections import deque
import os
import sys
# Assumes the notebook runs from a folder directly under the repo root (e.g. notebooks/),
# so the parent of the working directory is the repo root. Also used as the data base dir later.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

# Import SpectraNet redshift model from this repo structure
try:
    from AppleCider.models.SpectraNetRedshift import build_spec_red_model
    from AppleCider.preprocess.data_loader_redshift import create_data_loaders
    from AppleCider.preprocess.utils_redshift import (
        set_seed as train_set_seed,
        get_device as train_get_device,
        train_one_epoch_regression,
        validate_regression,
        build_optimizer,
        build_scheduler,
        early_stopping,
        clean_memory,
    )
except ModuleNotFoundError:
    # Fallback if run from repository root: also put the current working directory on sys.path
    # and retry the same imports.
    # NOTE(review): this also catches missing third-party packages imported inside AppleCider;
    # in that case the retry simply re-raises the same error.
    sys.path.append(os.path.abspath('.'))
    from AppleCider.models.SpectraNetRedshift import build_spec_red_model
    from AppleCider.preprocess.data_loader_redshift import create_data_loaders
    from AppleCider.preprocess.utils_redshift import (
        set_seed as train_set_seed,
        get_device as train_get_device,
        train_one_epoch_regression,
        validate_regression,
        build_optimizer,
        build_scheduler,
        early_stopping,
        clean_memory,
    )

# Utility functions for this example (kept lightweight)

def get_device():
    """Return ``torch.device('cuda')`` when torch can see a GPU, otherwise ``torch.device('cpu')``."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)


def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy's global RNG and torch (plus every CUDA device if present).

    Args:
        seed: integer seed applied to all generators.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # CUDA generators only exist when a GPU is visible; nothing else to seed otherwise.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def count_parameters(model: nn.Module) -> int:
    """Return the total element count of the parameters that have ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def make_dummy_loader(batch_size=8, num_batches=2, length=4096, seed=0):
    """Yield ``num_batches`` synthetic ``(spectra, target)`` batches shaped like spectra.

    Spectra are standard-normal float32 tensors of shape ``[batch_size, 1, length]``.
    Each target is that spectrum's mean plus 0.1-scaled Gaussian noise, clipped at 0,
    returned as a float32 tensor of shape ``[batch_size]``. This is a one-shot
    generator (deterministic for a given ``seed``), not a reusable ``DataLoader``.
    """
    generator = np.random.RandomState(seed)
    for _batch in range(num_batches):
        # Draw order (spectra first, then noise) fixes the RNG stream for a given seed.
        spectra = generator.normal(0, 1, size=(batch_size, 1, length)).astype(np.float32)
        noise = 0.1 * generator.randn(batch_size, 1)
        target = np.clip(spectra.mean(axis=-1) + noise, 0, None).astype(np.float32)
        yield torch.from_numpy(spectra), torch.from_numpy(target.reshape(-1))


def quick_metrics(y_true: torch.Tensor, y_pred: torch.Tensor):
    """Compute MSE, MAE and mean bias between targets and predictions.

    Both tensors are flattened before comparison, so a model head that returns
    ``[B, 1]`` can be scored against ``[B]`` targets. Without this, the two shapes
    silently broadcast to a ``[B, B]`` difference matrix and every metric is wrong.

    Args:
        y_true: ground-truth values (any shape).
        y_pred: predictions holding the same number of elements as ``y_true``.

    Returns:
        dict with float entries ``"mse"``, ``"mae"`` and ``"bias"``, where
        bias = mean(y_pred - y_true) (positive means over-prediction).

    Raises:
        ValueError: if the tensors hold a different number of elements.
    """
    y_true = y_true.detach().cpu().float().reshape(-1)
    y_pred = y_pred.detach().cpu().float().reshape(-1)
    if y_true.numel() != y_pred.numel():
        raise ValueError(
            f"quick_metrics: y_true has {y_true.numel()} elements but y_pred has {y_pred.numel()}"
        )
    residual = y_pred - y_true
    return {
        "mse": torch.mean(residual ** 2).item(),
        "mae": torch.mean(residual.abs()).item(),
        "bias": torch.mean(residual).item(),
    }


def plot_one_spectrum(x: torch.Tensor, title: str = "Example spectrum"):
    """Plot one spectrum as flux against sample index and show the figure.

    The tensor is detached, moved to CPU and flattened, so ``[1, L]`` and ``[L]``
    inputs both draw a single line.
    """
    flux = x.detach().cpu().numpy().ravel()
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(flux, lw=1.0)
    ax.set_xlabel('Wavelength index')
    ax.set_ylabel('Flux (arb. units)')
    ax.set_title(title)
    fig.tight_layout()
    plt.show()


# Example: minimal forward pass demo for sanity check

def forward_pass_demo(seed=42):
    """Sanity-check the model end to end on one synthetic batch.

    Seeds all RNGs, then builds a fresh model from an empty config and puts it in eval mode.
    It runs a single 4-spectrum synthetic batch through the model without gradients.
    Finally it prints the device, parameter count, shapes and quick metrics, and plots
    the first spectrum.

    Returns:
        ``(model, (x, y, y_pred), metrics)``: the model; the batch and its predictions,
        all on the model's device; and the ``quick_metrics`` dict.
    """
    set_seed(seed)
    device = get_device()
    net = build_spec_red_model(config={}).to(device)
    net.eval()

    # A single tiny synthetic batch is enough to exercise the forward path.
    batch_iter = iter(make_dummy_loader(batch_size=4, num_batches=1, length=4096, seed=seed))
    spectra, targets = (t.to(device) for t in next(batch_iter))

    with torch.no_grad():
        preds = net(spectra)
    metrics = quick_metrics(targets, preds)
    rounded = {k: round(v, 6) for k, v in metrics.items()}

    print(f"Device: {device}")
    print(f"Model parameters: {count_parameters(net):,}")
    print("Synthetic batch: x ->", tuple(spectra.shape), ", y ->", tuple(targets.shape))
    print("Quick metrics on synthetic data:", rounded)

    # Show what a single input looks like
    plot_one_spectrum(spectra[0], title="Synthetic spectrum example")

    return net, (spectra, targets, preds), metrics

In [None]:
# Configuration (paths and hyperparameters)

# Base directory for your redshift dataset (adjust to your environment)
# Expected files prepared by preprocessing:
#   <redshift_dir>/processed/train.pt
#   <redshift_dir>/processed/val.pt
#   <redshift_dir>/processed/test.pt

data_dir_base = os.path.join(project_root, 'data')  # change if your data lives elsewhere
redshift_dir = os.path.join(data_dir_base, 'dataset_for_redshift_with_error')
processed_dir = os.path.join(redshift_dir, 'processed')
model_save_dir = os.path.join(redshift_dir, 'model')
os.makedirs(model_save_dir, exist_ok=True)  # checkpoints are written here by the training cell

final_config = {
    "batch_size": 256,
    "learning_rate": 1.6e-4,
    "weight_decay": 1e-5,

    # scheduler
    "warmup_epochs": 10,
    "T_0": 5,
    "T_mult": 1,
    "eta_min": 1e-5,
    "start_factor": 1e-6,
    "end_factor": 1.0,

    "num_workers": 4,
    "epochs": 100,
    "patience": 5,

    # Processed PT files for loaders
    "train_dir": os.path.join(processed_dir, 'train.pt'),
    "val_dir": os.path.join(processed_dir, 'val.pt'),
    "test_dir": os.path.join(processed_dir, 'test.pt'),

    "model_save_dir": model_save_dir,
    "seed": 42,
}

# Honor the Quick Demo toggle for the real training run as well. Without this, the
# full 100-epoch / batch-256 run starts even when quick_demo is on.
# num_workers=0 loads batches in the main process, avoiding worker start-up overhead.
if quick_demo:
    final_config.update({
        "batch_size": 32,
        "epochs": 2,
        "patience": 1,
        "num_workers": 0,
    })

print("Config summary:")
for k, v in final_config.items():
    print(f"  {k}: {v}")

## 1) Setup

The setup cell above imports dependencies, puts the repository root on `sys.path` so `AppleCider` can be imported, and defines a few small helpers for reproducibility, quick metrics, and plotting.

## 2) Peek at the data

We'll first run a tiny synthetic example with spectra shaped like `[B, 1, 4096]`. If you have processed tensors, you can wire your real loaders here.

In [None]:
# Create a small synthetic batch and visualize it (smaller batch in Quick Demo mode)

demo_seed = final_config.get("seed", 42)
set_seed(demo_seed)
demo_batch_size = 4 if quick_demo else 8
loader = make_dummy_loader(batch_size=demo_batch_size, num_batches=1, length=4096, seed=demo_seed)
x_demo, y_demo = next(iter(loader))
print("Synthetic demo shapes:", x_demo.shape, y_demo.shape)
plot_one_spectrum(x_demo[0], title="Synthetic spectrum example (peek)")

## 3) Build the model

We use `build_spec_red_model(config)` to construct the SpectraNet regression model. It takes spectra shaped `[B, 1, 4096]` and outputs a non-negative redshift via a softplus head.

In [None]:
# Build the model, report its size, and check one forward pass on the synthetic batch
device = get_device()
model = build_spec_red_model(config={}).to(device)
n_trainable = count_parameters(model)
print(type(model).__name__)
print(f"Trainable parameters: {n_trainable:,}")

# Assign to `_` so the full module repr is not echoed (every bare expression is displayed)
_ = model.eval()

x_demo = x_demo.to(device)
with torch.no_grad():
    y_hat_demo = model(x_demo)
print("Forward OK. Output shape:", tuple(y_hat_demo.shape))

## 4) Inference demo

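Let’s run a tiny forward pass on synthetic data to see predictions vs. targets and basic metrics. The model here is freshly initialized (untrained), so the predictions only confirm that the pipeline runs end to end. If you have a trained checkpoint, load it with `model.load_state_dict(...)` before inference.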

In [None]:
# Run the minimal forward pass demo (synthetic) and keep its (untrained, eval-mode) model
demo_model, _, _ = forward_pass_demo(seed=final_config.get("seed", 42))

# Optional: scatter of target vs. prediction for a somewhat larger synthetic batch.
# Reuses the demo model instead of constructing a second one.
device = get_device()
scatter_loader = make_dummy_loader(
    batch_size=16 if quick_demo else 64, num_batches=1, length=4096, seed=final_config.get("seed", 42)
)
x_scatter, y_scatter = (t.to(device) for t in next(iter(scatter_loader)))
with torch.no_grad():
    # Flatten so a [B, 1] head lines up with the [B] targets
    y_pred_scatter = demo_model(x_scatter).reshape(-1)
scatter_metrics = quick_metrics(y_scatter, y_pred_scatter)
print({k: round(v, 6) for k, v in scatter_metrics.items()})

# Explicit fig/ax; `_ =` keeps matplotlib return values from being echoed
# (the setup cell displays every bare expression).
fig, ax = plt.subplots(figsize=(4, 4))
_ = ax.scatter(y_scatter.cpu().numpy(), y_pred_scatter.cpu().numpy(), s=12, alpha=0.6)
_ = ax.set(xlabel="Target redshift (synthetic)", ylabel="Predicted redshift", title="Tiny synthetic scatter")
fig.tight_layout()
plt.show()

## Use your own data

To run with real spectra instead of the synthetic demo:
- Place your processed tensors in: `data/dataset_for_redshift_with_error/processed/{train,val,test}.pt`
- Each tensor should contain spectra shaped `[B, 1, 4096]` and targets shaped `[B]` (float redshift).
- Update `data_dir_base` or `redshift_dir` in the config cell if your data lives elsewhere.
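- The training and evaluation sections (5–6) already load these files through `create_data_loaders`. The synthetic loader is only used by the demo cells, so replace it there only if you want to peek at real spectra.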

## How SpectraNet works (conceptual)

- Input: 1D spectra of fixed length 4096 as `[B, 1, 4096]`.
- Backbone: multi-branch 1D conv blocks per stage with several kernel sizes to capture narrow/wide features (lines and continuum).
- Normalization: LayerNorm over channel dimension to stabilize across wavelengths.
- Downsampling: periodic max-pooling to compress sequence length.
- Head: global pooling + MLP regressor, with a softplus at the end to ensure non-negative redshift.

The demo cells use a tiny synthetic dataset; sections 5–6 train and evaluate on your processed data. With real data, make sure your wavelength grid and flux normalization match the format produced by preprocessing.

## Troubleshooting

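- ImportError for `AppleCider`: run the notebook either from a folder directly under the repository root (e.g. `notebooks/`, so `..` is the root) or from the repository root itself.
- CUDA out of memory: lower `batch_size` in `final_config`. Note that `quick_demo` does not switch to CPU; to force CPU, hide the GPU (typically by setting `CUDA_VISIBLE_DEVICES=""` before starting the kernel).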
- Different spectrum length: resample or pad/crop to 4096 and keep the channel dimension as 1.
- Bad metrics on synthetic data: this is expected; use real data and training for meaningful results.

## 5) Train and validate (full run)

Now we’ll use the real processed datasets and train SpectraNet end-to-end on GPU (if available). We log per-epoch metrics and keep the best model by validation loss.

In [None]:
# Build loaders, model, optimizer, scheduler, and run training
train_set_seed(final_config["seed"])
device = train_get_device()
print("Using device:", device)

# Data
train_loader, val_loader, test_loader = create_data_loaders({
    "train_dir": final_config["train_dir"],
    "val_dir": final_config["val_dir"],
    "test_dir": final_config["test_dir"],
    "batch_size": final_config["batch_size"],
    "num_workers": final_config["num_workers"],
})

# Model and optimization
model = build_spec_red_model(config={}).to(device)
optimizer = build_optimizer(model, final_config)
scheduler = build_scheduler(optimizer, final_config)
# Gradient scaling is only enabled when CUDA is available.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
loss_fn = torch.nn.MSELoss()

best_val = float('inf')
best_epoch = None  # epoch whose weights were written to best_path during THIS run
epochs_no_improve = 0
best_path = os.path.join(final_config["model_save_dir"], "best_model.pth")

for epoch in range(1, final_config["epochs"] + 1):
    print(f"\nEpoch {epoch}/{final_config['epochs']}")
    train_loss = train_one_epoch_regression(model, train_loader, optimizer, device, scaler, loss_fn)
    val_stats = validate_regression(model, val_loader, device, loss_fn, plot=False)
    val_loss = val_stats['loss']
    print({"train_loss": round(train_loss, 6), **{k: round(v, 6) for k, v in val_stats.items()}})

    # The scheduler is stepped once per epoch
    scheduler.step()

    # Keep the best weights by validation loss; a NaN loss never compares as "better".
    if val_loss < best_val:
        best_val = val_loss
        best_epoch = epoch
        epochs_no_improve = 0
        torch.save(model.state_dict(), best_path)
        print(f"Saved best model to {best_path}")
    else:
        epochs_no_improve += 1
        if early_stopping(epochs_no_improve, final_config['patience']):
            print("Early stopping triggered.")
            break

    clean_memory()

# Make a silent no-checkpoint run visible: otherwise the evaluation step would
# fail or load a stale file left by an earlier run.
if best_epoch is None:
    print(
        f"WARNING: validation loss never improved (non-finite?); {best_path} was not written "
        "in this run and may be missing or stale from an earlier run."
    )
else:
    print(f"Best validation loss {best_val:.6f} at epoch {best_epoch}; weights saved to {best_path}")

## 6) Evaluate on test set

Load the best checkpoint and compute full regression metrics. Also plot predictions vs. ground truth with the ±0.15(1+z) bands.

In [None]:
# Load best model and evaluate on test set with plot
best_path = os.path.join(final_config["model_save_dir"], "best_model.pth")
if not os.path.isfile(best_path):
    raise FileNotFoundError(
        f"No checkpoint at {best_path}. Run the training cell first; it only saves "
        "when the validation loss improves."
    )

best_model = build_spec_red_model(config={}).to(device)
# torch.load unpickles the file: only load checkpoints produced by this notebook / trusted sources.
best_state = torch.load(best_path, map_location=device)
# Assign to `_` so neither the load report nor the full module repr is echoed
# (the setup cell displays every bare expression).
_ = best_model.load_state_dict(best_state)
_ = best_model.eval()

test_stats = validate_regression(best_model, test_loader, device, torch.nn.MSELoss(), plot=True)
print({k: round(v, 6) for k, v in test_stats.items()})