# 🌞 Solar Flare Forecasting API - Google Colab

This notebook sets up a complete solar flare forecasting system using NASA IMPACT's pretrained Surya model.

**Features:**
- 🚀 Automated data and model download
- 🔮 24-hour solar flare probability forecasting
- 🌐 REST API with ngrok tunnel for remote access
- 🔐 API key authentication

**Usage:**
1. Run all cells sequentially
2. Copy the ngrok URL when displayed
3. Use the local client script to make API calls

---

## 📦 Step 1: Environment Setup & Dependencies

In [None]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

# Install required packages
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers huggingface-hub einops timm
!pip install -q flask flask-cors pyngrok
!pip install -q netCDF4 xarray h5py scipy numpy pandas matplotlib
!pip install -q PyYAML tqdm peft

print("✅ All dependencies installed successfully!")

In [None]:
# Check GPU availability
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üñ•Ô∏è  Using device: {device}")

if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("‚ö†Ô∏è  Warning: No GPU detected. Inference will be slower on CPU.")

## 🔄 Step 2: Clone Repository & Setup

In [None]:
import os
import sys

# Clone the Surya repository into an explicit target path
if not os.path.exists('/content/Surya'):
    !git clone https://github.com/NASA-IMPACT/Surya.git /content/Surya
    print("✅ Repository cloned successfully")
else:
    print("ℹ️  Repository already exists")

# Change to the solar flare forecasting directory
os.chdir('/content/Surya/downstream_examples/solar_flare_forcasting')
print(f"üìÇ Working directory: {os.getcwd()}")

# Add to Python path
sys.path.insert(0, '/content/Surya/downstream_examples/solar_flare_forcasting')
sys.path.insert(0, '/content/Surya')

print("✅ Python path configured")

## 💾 Step 3: Download Data & Pretrained Models

This will download:
- Surya foundation model weights
- Solar flare task-specific weights
- SDO data and scaling factors
- Benchmark dataset

In [None]:
# Login to Hugging Face (optional but recommended)
from huggingface_hub import login

# If you have a Hugging Face token, enter it here
# login(token="your_token_here")

# OR run this cell and follow the prompt:
# login()

print("ℹ️  You can skip the Hugging Face login for public models, or log in for better rate limits")

In [None]:
# Execute the download script
import subprocess

print("üì• Starting download... This may take 5-10 minutes")
print("" + "="*60)

# Run the download script
result = subprocess.run(['bash', 'download_data.sh'], capture_output=True, text=True)

if result.returncode == 0:
    print("‚úÖ Data and models downloaded successfully!")
else:
    print("‚ö†Ô∏è  Download completed with warnings (this is usually OK)")
    
print("\nüìä Downloaded assets:")
!ls -lh assets/
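The cell above reports a nonzero exit code from `download_data.sh` as a mere warning and discards the script's stderr unless it is printed. A minimal sketch of stricter handling follows, assuming only that the script signals failure through its exit code; `run_checked` is a hypothetical helper, not part of the Surya repository.

```python
import subprocess
import sys


def run_checked(cmd, tail_lines=20):
    """Run a command; return (ok, tail) where tail is the last lines of stderr.

    A nonzero exit code counts as failure, and the end of stderr is kept
    so the cause can be inspected instead of being silently discarded.
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    tail = "\n".join(result.stderr.strip().splitlines()[-tail_lines:])
    return result.returncode == 0, tail


# In the notebook this would be: ok, tail = run_checked(["bash", "download_data.sh"])
ok, tail = run_checked([sys.executable, "-c", "print('hello')"])
print("ok" if ok else f"failed:\n{tail}")
```

Keeping only the tail of stderr keeps the output readable when a download script is verbose.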

## 🤖 Step 4: Load Pretrained Model & Setup Inference

In [None]:
import torch
import yaml
import numpy as np
import torch.nn.functional as F
from pathlib import Path

# Import necessary modules from the Surya repository
from surya.utils.data import build_scalers
from surya.utils.distributed import set_global_seed
from dataset import SolarFlareDataset
from finetune import get_model, apply_peft_lora, custom_collate_fn

# Set random seed for reproducibility
set_global_seed(42)

# Load configuration
with open('config_infer.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Load scalers
with open(config["data"]["scalers_path"], "r") as f:
    config["data"]["scalers"] = yaml.safe_load(f)
scalers = build_scalers(info=config["data"]["scalers"])

# Set dtype
if config["dtype"] == "float16":
    config["dtype"] = torch.float16
elif config["dtype"] == "bfloat16":
    config["dtype"] = torch.bfloat16
elif config["dtype"] == "float32":
    config["dtype"] = torch.float32
else:
    raise NotImplementedError("Please choose from [float16,bfloat16,float32]")

print("üìã Configuration loaded")
print(f"   Model type: {config['model']['model_type']}")
print(f"   Task: Solar Flare Forecasting (Binary Classification)")
print(f"   Using LoRA: {config['model']['use_lora']}")

# Initialize model using the repository's helper function
print("\nüîß Initializing model...")
model = get_model(config, wandb_logger=None)

# Apply LoRA if configured
if config["model"]["use_lora"]:
    print("   Applying PEFT LoRA...")
    model = apply_peft_lora(model, config)

# Load checkpoint weights
checkpoint_path = './assets/solar_flare_weights.pth'
if os.path.exists(checkpoint_path):
    print(f"\nüì• Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    
    # Handle different checkpoint formats
    if 'model_state_dict' in checkpoint:
        model_state = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        model_state = checkpoint['state_dict']
    else:
        model_state = checkpoint
    
    # Remove the 'module.' prefix added by DistributedDataParallel (only at the start of each key)
    if any(key.startswith('module.') for key in model_state.keys()):
        model_state = {
            (key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in model_state.items()
        }
    
    # Load state dict
    try:
        model.load_state_dict(model_state, strict=True)
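        print("✅ Loaded pretrained weights successfully!")
    except Exception as e:
        print(f"⚠️  Failed to load with strict=True: {e}")
        print("   Trying with strict=False...")
        # strict=False skips missing/unexpected keys but still raises on shape mismatches
        incompatible = model.load_state_dict(model_state, strict=False)
        print(f"✅ Loaded weights ({len(incompatible.missing_keys)} missing and "
              f"{len(incompatible.unexpected_keys)} unexpected keys ignored)")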
else:
    print(f"⚠️  Warning: Checkpoint not found at {checkpoint_path}")
    print("   Model will use only foundation weights")

# Move model to device and set to evaluation mode
model = model.to(device)
model.eval()

print(f"\n✅ Model loaded and ready for inference!")
print(f"   Device: {device}")
print(f"   Data type: {config['dtype']}")
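When the strict load fails, it helps to see which keys differ before retrying. In PyTorch, `strict=False` tolerates missing and unexpected keys but still raises on tensors whose shapes disagree, so shape mismatches are worth catching separately. The sketch below works on plain name-to-shape dictionaries so it runs without PyTorch; `compare_state_dicts` is a hypothetical helper.

```python
def compare_state_dicts(expected_shapes, checkpoint_shapes):
    """Compare parameter names and shapes before calling load_state_dict.

    Both arguments map parameter names to shape tuples. Returns sorted key
    lists: 'missing' (in the model only), 'unexpected' (in the checkpoint
    only) and 'shape_mismatch' (in both, but with different shapes).
    """
    expected = set(expected_shapes)
    provided = set(checkpoint_shapes)
    shared = expected & provided
    return {
        "missing": sorted(expected - provided),
        "unexpected": sorted(provided - expected),
        "shape_mismatch": sorted(
            k for k in shared
            if tuple(expected_shapes[k]) != tuple(checkpoint_shapes[k])
        ),
    }
```

In the notebook the inputs could be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}` and the same expression over `model_state`. Keys listed under `shape_mismatch` would make even the `strict=False` fallback fail, so they would need to be dropped from `model_state` first.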

In [None]:
# Define 24-hour forecasting function
import datetime
import json
from torch.utils.data import DataLoader, Subset

def forecast_24_hours(model, config, scalers, device):
    """
    Generate 24-hour solar flare probability forecast
    
    Returns:
        dict: Forecast results with timestamps and probabilities
    """
    try:
        # Create dataset
        dataset = SolarFlareDataset(
            sdo_data_root_path=config["data"]["sdo_data_root_path"],
            index_path=config["data"]["valid_data_path"],
            time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
            time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
            n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
            rollout_steps=config["rollout_steps"],
            channels=config["data"]["channels"],
            drop_hmi_probability=config["drop_hmi_probability"],
            num_mask_aia_channels=config["num_mask_aia_channels"],
            use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
            scalers=scalers,
            phase="valid",
            flare_index_path=config["data"]["flare_data_path"],
            pooling=config["data"]["pooling"],
            random_vert_flip=False,
        )
        
        if len(dataset) == 0:
            return {
                'status': 'error',
                'message': 'No data available for forecasting'
            }
        
        # Get samples for forecasting (up to 24 samples or all available)
        num_samples = min(24, len(dataset))
        
        # Use the last samples as they are most recent
        sample_indices = list(range(max(0, len(dataset) - num_samples), len(dataset)))
        
        dataloader = DataLoader(
            dataset=Subset(dataset, sample_indices),
            batch_size=1,
            num_workers=0,  # Use 0 for simplicity in Colab
            pin_memory=True,
            shuffle=False,
            collate_fn=custom_collate_fn,
        )
        
        # Run inference
        base_time = datetime.datetime.now(datetime.timezone.utc)
        forecast_data = []
        
        model.eval()
        with torch.no_grad():
            for hour, (batch, metadata) in enumerate(dataloader):
                if hour >= 24:  # Limit to 24 hours
                    break
                    
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}
                
                # Get ground truth
                ground_truth = batch["label"].item()
                timestamps_input = metadata["timestamps_input"]
                timestamps_targets = metadata["timestamps_targets"]
                
                # Run inference with mixed precision
                device_type = "cuda" if device.type == "cuda" else "cpu"
                with torch.amp.autocast(device_type=device_type, dtype=config["dtype"]):
                    logits = model(batch)
                    # Convert to probability (sigmoid for binary classification)
                    flare_probability = float(F.sigmoid(logits).item())
                
                # Convert numpy datetime64 to readable string format
                timestamps_target_str = np.datetime_as_string(timestamps_targets, unit='m')[0][0]
                
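                # Nominal forecast slot: current UTC time plus the loop index. The inputs
                # are validation-set samples; their actual target time is data_timestamp.
                forecast_time = base_time + datetime.timedelta(hours=hour)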
                
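                # Map the probability to a coarse label. These thresholds are heuristic:
                # the binary classifier's probability is not a flare-magnitude estimate.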
                if flare_probability > 0.7:
                    flare_class = 'M'
                elif flare_probability > 0.4:
                    flare_class = 'C'
                else:
                    flare_class = 'B'
                
                forecast_data.append({
                    'hour': hour,
                    'timestamp': forecast_time.isoformat(),
                    'data_timestamp': timestamps_target_str,
                    'no_flare_probability': float(1.0 - flare_probability),
                    'flare_probability': flare_probability,
                    'flare_class': flare_class,
                    'ground_truth': int(ground_truth)
                })
        
        result = {
            'status': 'success',
            'forecast_generated_at': base_time.isoformat(),
            'forecast_horizon': f'{len(forecast_data)} hours',
            'forecasts': forecast_data,
            'model_info': {
                'name': 'Surya Solar Flare Forecaster',
                'model_type': config['model']['model_type'],
                'version': '1.0',
                'device': str(device),
                'use_lora': config['model']['use_lora']
            }
        }
        
        return result
            
    except Exception as e:
        import traceback
        return {
            'status': 'error',
            'message': str(e),
            'traceback': traceback.format_exc()
        }

# Test the forecasting function
print("üß™ Testing forecast function...")
test_forecast = forecast_24_hours(model, config, scalers, device)

if test_forecast['status'] == 'success':
    print("✅ Forecast function working!")
    print(f"   Generated {len(test_forecast['forecasts'])} hourly predictions")
    print(f"   First hour flare probability: {test_forecast['forecasts'][0]['flare_probability']:.4f}")
    print(f"   Average flare probability: {np.mean([f['flare_probability'] for f in test_forecast['forecasts']]):.4f}")
else:
    print(f"⚠️  Forecast test failed: {test_forecast['message']}")
    if 'traceback' in test_forecast:
        print(f"\nTraceback:\n{test_forecast['traceback']}")
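The forecast cell turns each logit into a probability with a sigmoid and then buckets it with fixed cut-offs. For post-processing raw logits outside PyTorch (for example in a client or a unit test), a pure-Python sketch of both steps is below. The sigmoid uses the numerically stable two-branch form, because the naive `1 / (1 + math.exp(-x))` raises `OverflowError` for large negative logits; the 0.7 and 0.4 cut-offs are copied from `forecast_24_hours`, and both function names are hypothetical.

```python
import math


def stable_sigmoid(logit):
    """Map a raw logit to a probability without overflowing math.exp()."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)  # logit < 0, so z is in (0, 1) and cannot overflow
    return z / (1.0 + z)


def probability_bucket(p, thresholds=((0.7, "M"), (0.4, "C")), default="B"):
    """Label a probability with the same cut-offs used in forecast_24_hours.

    thresholds must be ordered from highest to lowest cut-off; a probability
    must strictly exceed a cut-off to receive its label.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"probability out of range: {p}")
    for cutoff, label in thresholds:
        if p > cutoff:
            return label
    return default
```

As in the notebook, the M/C/B labels are thresholds on a binary flare probability, not a prediction of the GOES flare class.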

## 🌐 Step 5: Setup Flask API Server

In [None]:
from flask import Flask, request, jsonify
from flask_cors import CORS
import secrets

# Generate API key
API_KEY = secrets.token_urlsafe(32)
print(f"üîë API Key: {API_KEY}")
print("‚ö†Ô∏è  Save this key! You'll need it for the local client.\n")

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

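# Authentication middleware
import functools

def require_api_key(f):
    @functools.wraps(f)  # keep the view's name, which Flask uses as the endpoint name
    def decorated_function(*args, **kwargs):
        api_key = request.headers.get('X-API-Key', '')
        # Constant-time comparison avoids leaking the key through timing differences
        if not secrets.compare_digest(api_key.encode(), API_KEY.encode()):
            return jsonify({'error': 'Invalid or missing API key'}), 401
        return f(*args, **kwargs)
    return decorated_function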

# Health check endpoint (no auth required)
@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({
        'status': 'healthy',
        'service': 'Solar Flare Forecasting API',
        'version': '1.0',
        'device': str(device)
    }), 200

# Status endpoint (with auth)
@app.route('/status', methods=['GET'])
@require_api_key
def get_status():
    return jsonify({
        'model_loaded': True,
        'device': str(device),
        'data_available': os.path.exists('./assets'),
        'ready_for_inference': True,
        'model_type': config['model']['model_type'],
        'use_lora': config['model']['use_lora']
    }), 200

# Forecast endpoint (with auth)
@app.route('/forecast', methods=['POST'])
@require_api_key
def generate_forecast_endpoint():
    try:
        print("üìä Generating solar flare forecast...")
        result = forecast_24_hours(model, config, scalers, device)
        
        if result['status'] == 'success':
            print(f"‚úÖ Forecast generated: {len(result['forecasts'])} predictions")
            return jsonify(result), 200
        else:
            print(f"‚ö†Ô∏è  Forecast generation failed: {result['message']}")
            return jsonify(result), 500
            
    except Exception as e:
        import traceback
        print(f"‚ùå Error: {str(e)}")
        return jsonify({
            'status': 'error',
            'message': str(e),
            'traceback': traceback.format_exc()
        }), 500

print("✅ Flask API configured with 3 endpoints:")
print("   GET  /health   - Health check (no auth)")
print("   GET  /status   - System status (auth required)")
print("   POST /forecast - Generate forecast (auth required)")
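Flask's development server handles requests on multiple threads by default, so two overlapping `POST /forecast` calls would run inference on the same model at the same time. One option, assumed here rather than taken from the original notebook, is to guard the view with a non-blocking lock and answer HTTP 429 while a forecast is already in progress; the decorator name `single_flight` is hypothetical.

```python
import functools
import threading


def single_flight(func):
    """Allow one call at a time; reject overlapping calls instead of queueing them."""
    lock = threading.Lock()

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # acquire(blocking=False) returns False at once if another call holds the lock
        if not lock.acquire(blocking=False):
            return {"status": "error", "message": "A forecast is already running"}, 429
        try:
            return func(*args, **kwargs)
        finally:
            lock.release()

    return wrapper
```

Stacked under `@require_api_key` on the `/forecast` view, a rejected call would return HTTP 429. Recent Flask versions serialize a returned dict to JSON, so the decorator does not need `jsonify`. Rejecting rather than queueing keeps a slow forecast from piling up waiting requests.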

## 🌍 Step 6: Start Ngrok Tunnel & API Server

**Important:** After running this cell, copy the ngrok URL displayed below. You'll need it for the local client!

In [None]:
from pyngrok import ngrok
from threading import Thread

# Set your ngrok auth token (required: ngrok no longer allows tunnels without a verified account)
# Get a free token from: https://dashboard.ngrok.com/get-started/your-authtoken
# ngrok.set_auth_token("your_ngrok_token_here")

# Start ngrok tunnel
# connect() returns an NgrokTunnel object; keep only its public URL string
public_url = ngrok.connect(5000).public_url
print("\n" + "="*70)
print("üåç NGROK TUNNEL ACTIVE")
print("="*70)
print(f"\nüì° Public URL: {public_url}")
print(f"üîë API Key: {API_KEY}")
print("\n" + "="*70)
print("\n‚úÖ Copy the URL and API Key above to use with the local client!")
print("\n‚ö†Ô∏è  Keep this cell running to maintain the tunnel\n")

# Start Flask server in background thread
def run_flask():
    app.run(host='0.0.0.0', port=5000, use_reloader=False)

flask_thread = Thread(target=run_flask, daemon=True)
flask_thread.start()

print("üöÄ API Server is now running!")
print("\nüìù Example API calls:")
print(f"\n   Health Check:")
print(f"   curl {public_url}/health")
print(f"\n   Get Forecast:")
print(f"   curl -X POST {public_url}/forecast \\")
print(f"        -H 'X-API-Key: {API_KEY}'")

# No keep-alive loop is needed: the Flask thread and the ngrok process live as long
# as the Colab runtime, so this cell can finish and Step 7 can run afterwards.
print("\n⏳ Server running in the background. Run ngrok.kill() or restart the runtime to stop it.")

## 🧪 Step 7: Test API (Optional)

You can test the API directly from Colab before using the local client.

In [None]:
import requests

# Test health endpoint
print("üß™ Testing API endpoints...\n")

ngrok_url = str(public_url)  # Use the URL from previous cell

# Test 1: Health check
print("1. Health Check:")
response = requests.get(f"{ngrok_url}/health")
print(f"   Status: {response.status_code}")
print(f"   Response: {response.json()}\n")

# Test 2: Status (with auth)
print("2. Status Check:")
headers = {'X-API-Key': API_KEY}
response = requests.get(f"{ngrok_url}/status", headers=headers)
print(f"   Status: {response.status_code}")
print(f"   Response: {response.json()}\n")

# Test 3: Generate forecast
print("3. Generate Forecast:")
response = requests.post(f"{ngrok_url}/forecast", headers=headers)
print(f"   Status: {response.status_code}")

if response.status_code == 200:
    forecast = response.json()
    print(f"   ✅ Forecast generated successfully!")
    print(f"   Generated at: {forecast['forecast_generated_at']}")
    print(f"   Number of predictions: {len(forecast['forecasts'])}")
    print(f"\n   Sample prediction (Hour 0):")
    print(f"   {forecast['forecasts'][0]}")
else:
    print(f"   ❌ Error: {response.json()}")
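Once the JSON arrives, a client usually needs only a few headline numbers. The helper below is a hypothetical sketch that assumes the response layout produced by `forecast_24_hours`: a `status` field and a `forecasts` list whose entries carry `hour` and `flare_probability`.

```python
from statistics import mean


def summarize_forecast(payload):
    """Reduce a /forecast response to a few headline numbers.

    Raises ValueError if the response reports an error or has no predictions.
    """
    if payload.get("status") != "success":
        raise ValueError(payload.get("message", "forecast failed"))
    rows = payload["forecasts"]
    if not rows:
        raise ValueError("forecast contains no predictions")
    peak = max(rows, key=lambda r: r["flare_probability"])
    return {
        "n_predictions": len(rows),
        "mean_probability": mean(r["flare_probability"] for r in rows),
        "peak_hour": peak["hour"],
        "peak_probability": peak["flare_probability"],
    }
```

In the cell above this could be applied as `summarize_forecast(response.json())`.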

---

## 📚 Next Steps

1. **Save the ngrok URL and API Key** from Step 6
2. **Download the local client script** (`local_api_client.py`)
3. **Update the client configuration** with your URL and API key
4. **Run forecasts from your local machine!**
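The `local_api_client.py` script referenced above is not part of this notebook. As a hypothetical stand-in, the following standard-library sketch sends the same `POST /forecast` request with the `X-API-Key` header; the function names and the 300-second timeout are assumptions.

```python
import json
import urllib.request


def build_forecast_request(base_url, api_key):
    """Build (but do not send) a POST request for the /forecast endpoint."""
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/forecast",
        data=b"",  # empty body: the endpoint does not read any request data
        headers={"X-API-Key": api_key},
        method="POST",
    )


def fetch_forecast(base_url, api_key, timeout=300):
    """Send the request and decode the JSON body (requires network access)."""
    req = build_forecast_request(base_url, api_key)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

`urlopen` raises `urllib.error.HTTPError` for non-2xx responses, such as the 401 returned for a wrong key, so a caller may want to catch it and read the error body.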

### Keeping the API Running
- The API will stay active as long as this Colab notebook is running
- Free Colab runtimes have a maximum lifetime of up to about 12 hours and also disconnect when left idle for too long
- Consider Colab Pro for longer sessions

### Troubleshooting
- If the tunnel stops working, run `ngrok.kill()` and then re-run Step 6
- The public URL may change when the tunnel restarts, and the API key changes whenever Step 5 is re-run; update your local client accordingly

---

**Built with ❤️ using NASA IMPACT's Surya Foundation Model**