# Solar Flare Analysis: Validation Against Known Catalogs

This notebook demonstrates how to validate our flare detection methods against known solar flare catalogs from official sources like NOAA SWPC.

## Setup and Imports

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta

# Make the project root importable. It is inserted at the FRONT of sys.path so
# the project's own `config` and `src` packages take precedence over any
# identically named packages installed in the environment (append would let
# an installed `config` package shadow the project's settings module).
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import project modules
from config import settings
from src.data_processing.data_loader import load_goes_data, preprocess_xrs_data, remove_background
from src.flare_detection.traditional_detection import (
    detect_flare_peaks, define_flare_bounds, detect_overlapping_flares
)
from src.ml_models.flare_decomposition import FlareDecompositionModel, reconstruct_flares
from src.validation.catalog_validation import (
    download_noaa_flare_catalog, compare_detected_flares, 
    calculate_detection_quality, get_flare_class_distribution
)

## Downloading Flare Catalogs

First, let's download the NOAA SWPC flare catalog for a specific date range. This catalog contains officially reported solar flare events.

In [None]:
# Define date range for validation (inclusive, 'YYYY-MM-DD')
start_date = '2022-06-01'
end_date = '2022-06-30'

# Create a directory for catalogs if it doesn't exist
catalog_dir = os.path.join('..', 'data', 'catalogs')
os.makedirs(catalog_dir, exist_ok=True)

# Local cache of the NOAA SWPC flare catalog for this date range
catalog_file = os.path.join(catalog_dir, f'noaa_flares_{start_date}_to_{end_date}.csv')

try:
    if os.path.exists(catalog_file):
        print(f"Loading existing catalog from {catalog_file}")
        catalog_flares = pd.read_csv(catalog_file, parse_dates=['start_time', 'end_time'])
        # The CSV stores timestamps as text. The comparison cell needs
        # peak_time as a datetime too, so parse it whenever it is present.
        if 'peak_time' in catalog_flares.columns:
            catalog_flares['peak_time'] = pd.to_datetime(catalog_flares['peak_time'])
    else:
        print(f"Downloading NOAA SWPC flare catalog from {start_date} to {end_date}")
        catalog_flares = download_noaa_flare_catalog(start_date, end_date, output_file=catalog_file)

    # Display the catalog (it may come from the cache or from a fresh download)
    print(f"Catalog contains {len(catalog_flares)} flare events")
    display(catalog_flares.head())
except Exception as e:
    # Best effort: later cells check for an empty catalog and skip themselves
    print(f"Error loading or downloading catalog: {e}")
    catalog_flares = pd.DataFrame()

## Analyzing Catalog Flares

Let's analyze the distribution of flare classes in the catalog:

In [None]:
if not catalog_flares.empty:
    # Number of catalog flares per class, as computed by the project helper
    class_distribution = get_flare_class_distribution(catalog_flares)

    # Bar chart of the class distribution
    plt.figure(figsize=(10, 6))
    class_distribution.plot(kind='bar', color='skyblue')
    plt.xlabel('Flare Class')
    plt.ylabel('Number of Flares')
    plt.title('Distribution of Flare Classes in NOAA Catalog')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()

    # Summary statistics
    print("\nSummary Statistics for Flares in Catalog:")
    print(f"Total flares: {len(catalog_flares)}")
    print(f"Date range: {catalog_flares['start_time'].min().date()} to {catalog_flares['start_time'].max().date()}")

    # Duration of each catalog flare in minutes (kept as a new column)
    catalog_flares['duration_minutes'] = (catalog_flares['end_time'] - catalog_flares['start_time']).dt.total_seconds() / 60
    print(f"Average flare duration: {catalog_flares['duration_minutes'].mean():.2f} minutes")
    print(f"Median flare duration: {catalog_flares['duration_minutes'].median():.2f} minutes")
    print(f"Maximum flare duration: {catalog_flares['duration_minutes'].max():.2f} minutes")

    # Box plot of durations per class. seaborn orders categories by first
    # appearance unless given `order`. Sorting the labels alphabetically
    # yields the GOES intensity order, because A < B < C < M < X.
    if 'flare_class' in catalog_flares.columns:
        class_order = sorted(catalog_flares['flare_class'].dropna().astype(str).unique())
        plt.figure(figsize=(12, 6))
        sns.boxplot(x='flare_class', y='duration_minutes', data=catalog_flares, order=class_order)
        plt.xlabel('Flare Class')
        plt.ylabel('Duration (minutes)')
        plt.title('Flare Durations by Class')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.show()

## Loading GOES XRS Data

Now, let's load the GOES XRS data for the same time period and run our traditional peak-based flare detection on it:

In [None]:
import re

# Locate GOES XRS netCDF files for the validation period
data_dir = settings.DATA_DIR

# Validation window as dates. Parsed once, outside the file loop, so a
# malformed start/end date raises here instead of being silently swallowed
# for every file.
start_date_obj = datetime.strptime(start_date, '%Y-%m-%d').date()
end_date_obj = datetime.strptime(end_date, '%Y-%m-%d').date()

# GOES file names often embed the observation date as YYYYMMDD
date_pattern = re.compile(r'\d{8}')

# Sorted so that the "first 3 files" processed below are reproducible;
# os.listdir returns entries in arbitrary order.
nc_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.nc'))

data_files = []
for file in nc_files:
    date_match = date_pattern.search(file)
    if date_match is None:
        # No 8-digit token in the name: the file is not selected
        continue
    try:
        file_date = datetime.strptime(date_match.group(), '%Y%m%d').date()
    except ValueError:
        # Eight digits that are not a valid date: include the file anyway
        data_files.append(file)
        continue
    if start_date_obj <= file_date <= end_date_obj:
        data_files.append(file)

print(f"Found {len(data_files)} data files matching date range")

# If no specific files found, use whatever is available
if not data_files:
    data_files = list(nc_files)
    print(f"Using {len(data_files)} available data files")

if data_files:
    all_flares = []

    # Detection runs on the XRS 'B' channel, stored in column 'xrsb'
    channel = 'B'
    flux_col = f'xrs{channel.lower()}'

    for data_file in data_files[:3]:  # Limit to first 3 files for demo
        file_path = os.path.join(data_dir, data_file)
        print(f"\nProcessing {data_file}...")

        # Load data
        data = load_goes_data(file_path)
        if data is None:
            print(f"Failed to load {data_file}. Skipping.")
            continue

        # Preprocess the selected channel
        df = preprocess_xrs_data(data, channel=channel, remove_bad_data=True, interpolate_gaps=True)

        # NOTE(review): df_bg is computed but never used. Peak detection and
        # bound definition below run on `df`. Confirm whether they should
        # receive df_bg instead.
        df_bg = remove_background(
            df,
            window_size=settings.BACKGROUND_PARAMS['window_size'],
            quantile=settings.BACKGROUND_PARAMS['quantile']
        )

        # Detect flare peaks, then their start/end bounds
        peaks = detect_flare_peaks(
            df, flux_col,
            threshold_factor=settings.DETECTION_PARAMS['threshold_factor'],
            window_size=settings.DETECTION_PARAMS['window_size']
        )

        flares = define_flare_bounds(
            df, flux_col, peaks['peak_index'].values,
            start_threshold=settings.DETECTION_PARAMS['start_threshold'],
            end_threshold=settings.DETECTION_PARAMS['end_threshold'],
            min_duration=settings.DETECTION_PARAMS['min_duration'],
            max_duration=settings.DETECTION_PARAMS['max_duration']
        )

        # Record which file each flare came from: its start/end indices are
        # positions within this file's DataFrame only.
        flares['source_file'] = data_file
        all_flares.append(flares)

        print(f"Detected {len(flares)} flares in {data_file}")

    # Combine all detected flares
    if all_flares:
        detected_flares = pd.concat(all_flares, ignore_index=True)
        print(f"\nTotal detected flares: {len(detected_flares)}")
        display(detected_flares.head())
    else:
        print("No flares detected")
        detected_flares = pd.DataFrame()
else:
    print("No data files found")
    detected_flares = pd.DataFrame()

## Comparing Detected Flares with Catalog

Now let's compare our detected flares with the catalog flares:

In [None]:
if not detected_flares.empty and not catalog_flares.empty:
    # Columns both tables need for time/flux matching
    required_columns = ['start_time', 'peak_time', 'end_time', 'peak_flux']
    missing_detected = [col for col in required_columns if col not in detected_flares.columns]
    missing_catalog = [col for col in required_columns if col not in catalog_flares.columns]

    if not missing_detected and not missing_catalog:
        # Keep only catalog events that overlap the time span of our detections
        min_time = detected_flares['start_time'].min()
        max_time = detected_flares['end_time'].max()

        # reset_index makes row labels equal row positions. The 'catalog_idx'
        # values returned by compare_detected_flares then refer to the same
        # rows, whether that function reports labels or positions.
        filtered_catalog = catalog_flares[
            (catalog_flares['end_time'] >= min_time) &
            (catalog_flares['start_time'] <= max_time)
        ].reset_index(drop=True)

        print(f"Comparing {len(detected_flares)} detected flares with {len(filtered_catalog)} catalog flares")

        # Compare flares
        comparison = compare_detected_flares(
            detected_flares,
            filtered_catalog,
            time_tolerance='10min'  # Allow 10 minutes time difference
        )

        # Display overall results
        metrics = comparison['metrics']
        print("\nValidation Metrics:")
        print(f"True Positives: {metrics['true_positives']}")
        print(f"False Positives: {metrics['false_positives']}")
        print(f"False Negatives: {metrics['false_negatives']}")
        print(f"Precision: {metrics['precision']:.3f}")
        print(f"Recall: {metrics['recall']:.3f}")
        print(f"F1 Score: {metrics['f1_score']:.3f}")

        # Additional quality metrics
        quality_metrics = calculate_detection_quality(comparison)
        if 'mean_time_diff' in quality_metrics:
            print(f"\nAverage time difference: {quality_metrics['mean_time_diff']:.2f} minutes")
        if 'mean_flux_ratio' in quality_metrics:
            print(f"Average flux ratio (detected/catalog): {quality_metrics['mean_flux_ratio']:.2f}")

        # Display matched flares
        matched = comparison['matched_flares']
        print("\nSample of matched flares:")
        display(matched.head())

        # Precision/recall by flare class
        if len(matched) > 0 and 'flare_class' in filtered_catalog.columns:
            flare_classes = ['A', 'B', 'C', 'M', 'X']
            # Lower peak-flux bounds (W/m², 1-8 Å XRS-B channel) of the
            # B, C, M and X classes. Anything below 1e-7 is class A.
            class_thresholds = [1e-7, 1e-6, 1e-5, 1e-4]

            # Detected flares carry no class label, so false positives are
            # binned by their own peak flux. Per-class precision is therefore
            # approximate: TPs are binned by catalog class, FPs by detected flux.
            fp_by_class = None
            if 'unmatched_detected' in comparison and 'peak_flux' in comparison['unmatched_detected'].columns:
                fp_flux = comparison['unmatched_detected']['peak_flux'].dropna().to_numpy(dtype=float)
                fp_letters = np.asarray(flare_classes)[np.searchsorted(class_thresholds, fp_flux, side='right')]
                fp_by_class = pd.Series(fp_letters, dtype=object).value_counts()

            # First character of the catalog class: 'M' for both 'M' and 'M1.2'
            catalog_letters = filtered_catalog['flare_class'].astype(str).str.strip().str.upper().str[0]

            class_metrics = []
            for flare_class in flare_classes:
                class_catalog = filtered_catalog[catalog_letters == flare_class]

                # Count each catalog flare once, even if several detections matched it
                tp = matched.loc[matched['catalog_idx'].isin(class_catalog.index), 'catalog_idx'].nunique()
                fn = len(class_catalog) - tp

                # Undefined ratios are NaN (gaps in the plot), not a misleading 0
                recall = tp / (tp + fn) if (tp + fn) > 0 else np.nan
                if fp_by_class is None:
                    fp = np.nan
                    precision = np.nan
                else:
                    fp = int(fp_by_class.get(flare_class, 0))
                    precision = tp / (tp + fp) if (tp + fp) > 0 else np.nan

                class_metrics.append({
                    'class': flare_class,
                    'precision': precision,
                    'recall': recall,
                    'count': len(class_catalog),
                    'true_positives': tp,
                    'false_positives': fp,
                    'false_negatives': fn,
                })

            class_metrics_df = pd.DataFrame(class_metrics)

            # Counts as bars on a secondary axis; precision/recall as lines
            fig, ax1 = plt.subplots(figsize=(10, 6))
            ax2 = ax1.twinx()
            ax2.bar(class_metrics_df['class'], class_metrics_df['count'],
                    alpha=0.3, color='gray', label='Flare Count')
            ax2.set_ylabel('Number of Flares')

            ax1.plot(class_metrics_df['class'], class_metrics_df['precision'],
                     'o-', color='blue', label='Precision')
            ax1.plot(class_metrics_df['class'], class_metrics_df['recall'],
                     'o-', color='red', label='Recall')

            ax1.set_xlabel('Flare Class')
            ax1.set_ylabel('Score')
            ax1.set_ylim(0, 1.05)
            ax1.grid(True, linestyle='--', alpha=0.7)

            # Combine legends from both axes
            lines1, labels1 = ax1.get_legend_handles_labels()
            lines2, labels2 = ax2.get_legend_handles_labels()
            ax1.legend(lines1 + lines2, labels1 + labels2, loc='lower right')

            ax1.set_title('Detection Performance by Flare Class')
            fig.tight_layout()
            plt.show()

            display(class_metrics_df)
    else:
        print(f"Cannot compare flares. Missing columns: detected={missing_detected}, catalog={missing_catalog}")
else:
    print("Cannot compare flares: either no detected flares or no catalog flares available")

## Analyzing Detection Performance

Let's analyze the detection performance in more detail, especially for flares that were missed or falsely detected:

In [None]:
if 'comparison' in locals() and all(key in comparison for key in ['unmatched_detected', 'unmatched_catalog']):
    # --- False positives: detected but not in the catalog ---
    # Work on a copy so adding 'duration' does not modify the comparison
    # results (and pandas does not warn about assigning to a possible slice).
    false_positives = comparison['unmatched_detected'].copy()
    if len(false_positives) > 0:
        print(f"\nAnalysis of {len(false_positives)} False Positives (detected but not in catalog):")

        has_flux = 'peak_flux' in false_positives.columns
        if has_flux:
            print(f"Mean peak flux: {false_positives['peak_flux'].mean():.2e} W/m²")
            print(f"Median peak flux: {false_positives['peak_flux'].median():.2e} W/m²")
            print(f"Max peak flux: {false_positives['peak_flux'].max():.2e} W/m²")

        if 'start_time' in false_positives.columns and 'end_time' in false_positives.columns:
            false_positives['duration'] = (false_positives['end_time'] - false_positives['start_time']).dt.total_seconds() / 60
            print(f"Mean duration: {false_positives['duration'].mean():.2f} minutes")
            print(f"Median duration: {false_positives['duration'].median():.2f} minutes")

        # Distribution of false positives by peak flux
        if has_flux:
            # Only positive fluxes can be shown on a log axis
            positive_flux = false_positives['peak_flux'].dropna()
            positive_flux = positive_flux[positive_flux > 0]
            if not positive_flux.empty:
                low, high = positive_flux.min(), positive_flux.max()
                if low == high:
                    # Degenerate range: widen by half a decade on each side
                    low, high = low / np.sqrt(10), high * np.sqrt(10)
                # Log-spaced edges give equal-width bins on the log-scaled
                # x-axis (linear bins would collapse into the leftmost decade).
                bins = np.logspace(np.log10(low), np.log10(high), 21)

                plt.figure(figsize=(10, 6))
                plt.hist(positive_flux, bins=bins, alpha=0.7, log=True)
                plt.xlabel('Peak Flux (W/m²)')
                plt.ylabel('Number of False Positives')
                plt.title('Distribution of False Positive Flares by Peak Flux')
                plt.grid(True, linestyle='--', alpha=0.7)
                plt.xscale('log')
                plt.tight_layout()
                plt.show()

    # --- False negatives: in the catalog but not detected ---
    false_negatives = comparison['unmatched_catalog'].copy()
    if len(false_negatives) > 0:
        print(f"\nAnalysis of {len(false_negatives)} False Negatives (in catalog but not detected):")

        has_class = 'flare_class' in false_negatives.columns

        # Distribution by flare class (alphabetical order == A, B, C, M, X)
        if has_class:
            missed_class_dist = false_negatives['flare_class'].value_counts().sort_index()
            print("\nMissed flares by class:")
            for cls, count in missed_class_dist.items():
                print(f"  {cls}: {count}")

            plt.figure(figsize=(8, 5))
            missed_class_dist.plot(kind='bar', color='tomato')
            plt.xlabel('Flare Class')
            plt.ylabel('Number of Missed Flares')
            plt.title('Distribution of Missed Flares by Class')
            plt.grid(axis='y', linestyle='--', alpha=0.7)
            plt.tight_layout()
            plt.show()

        # Duration of missed flares
        if 'start_time' in false_negatives.columns and 'end_time' in false_negatives.columns:
            false_negatives['duration'] = (false_negatives['end_time'] - false_negatives['start_time']).dt.total_seconds() / 60
            print(f"\nMean duration of missed flares: {false_negatives['duration'].mean():.2f} minutes")
            print(f"Median duration of missed flares: {false_negatives['duration'].median():.2f} minutes")

            plt.figure(figsize=(10, 6))
            sns.histplot(data=false_negatives, x='duration', hue='flare_class' if has_class else None,
                         kde=True, bins=20)
            plt.xlabel('Duration (minutes)')
            plt.ylabel('Number of Missed Flares')
            plt.title('Duration Distribution of Missed Flares')
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.tight_layout()
            plt.show()

## Validating ML-based Flare Separation

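Finally, let's test our ML-based flare separation on overlapping flares. The catalog does not list separated flare components, so this is a consistency check (comparing integrated energies before and after separation), not a validation against the catalog: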

In [None]:
def _load_or_train_decomposition_model():
    """Return a FlareDecompositionModel, pre-trained when possible.

    Builds the model from settings.ML_PARAMS and loads weights from
    settings.MODEL_DIR/flare_decomposition_model. If loading raises, a small
    model is trained for 10 epochs on synthetic data and saved to that path.
    """
    model = FlareDecompositionModel(
        sequence_length=settings.ML_PARAMS['sequence_length'],
        n_features=settings.ML_PARAMS['n_features'],
        max_flares=settings.ML_PARAMS['max_flares']
    )
    model.build_model()

    model_path = os.path.join(settings.MODEL_DIR, 'flare_decomposition_model')
    try:
        model.load_model(model_path)
        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Training a simple model with synthetic data...")

        X_train, y_train = model.generate_synthetic_data(n_samples=500, noise_level=0.05)
        X_val, y_val = model.generate_synthetic_data(n_samples=100, noise_level=0.05)

        model.train(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=10,  # Using fewer epochs for demonstration
            batch_size=settings.ML_PARAMS['batch_size'],
            save_path=model_path
        )
    return model


def _reload_flux_frame(source_file, channel):
    """Reload one file from settings.DATA_DIR with the loading cell's preprocessing.

    The start_index/end_index of flares detected in that file index the
    returned DataFrame. Returns None if load_goes_data returns None.
    """
    data = load_goes_data(os.path.join(settings.DATA_DIR, source_file))
    if data is None:
        return None
    return preprocess_xrs_data(data, channel=channel, remove_bad_data=True, interpolate_gaps=True)


def _model_segment(flux_frame, flux_column, first_idx, last_idx, seq_len):
    """Cut positions [first_idx, last_idx] plus seq_len // 4 context and fit to seq_len.

    The window is zero-padded at the end when short and truncated at the end
    when long. Returns an array of shape (1, seq_len, 1).
    """
    padding = seq_len // 4
    lo = max(0, first_idx - padding)
    hi = min(len(flux_frame) - 1, last_idx + padding)
    # hi is an inclusive position, so slice one past it
    segment = flux_frame.iloc[lo:hi + 1][flux_column].to_numpy()
    if len(segment) < seq_len:
        segment = np.pad(segment, (0, seq_len - len(segment)), 'constant')
    else:
        segment = segment[:seq_len]
    return segment.reshape(1, -1, 1)


def _minutes_per_sample(*flares):
    """Infer the sampling cadence (minutes per sample) from detected flares.

    Uses the first flare whose end_index > start_index, assuming uniform
    sampling between its start_time and end_time. Returns None if none qualifies.
    """
    for flare in flares:
        n_steps = flare['end_index'] - flare['start_index']
        if n_steps > 0:
            return (flare['end_time'] - flare['start_time']).total_seconds() / 60 / n_steps
    return None


if 'detected_flares' in locals() and not detected_flares.empty:
    # Pairs of detected flares overlapping by at least 2 minutes. Each entry
    # is unpacked as (i, j, overlap), with i and j used as positional
    # (iloc) row indices into detected_flares.
    overlapping = detect_overlapping_flares(detected_flares, min_overlap='2min')
    print(f"Detected {len(overlapping)} potentially overlapping flare pairs")

    if overlapping:
        # start_index/end_index are positions within the file each flare was
        # detected in. A pair can only be cut from one time series if both
        # flares come from the same file.
        same_file_pairs = [
            (i, j, overlap) for i, j, overlap in overlapping
            if detected_flares.iloc[i]['source_file'] == detected_flares.iloc[j]['source_file']
        ]

        if not same_file_pairs:
            print("No overlapping flare pair comes from a single data file; skipping separation")
        else:
            print("\nLoading ML model for flare separation...")
            model = _load_or_train_decomposition_model()

            print("\nProcessing overlapping flare pairs:")

            # Choose one overlapping pair for demonstration
            i, j, duration = same_file_pairs[0]
            flare_i = detected_flares.iloc[i]
            flare_j = detected_flares.iloc[j]
            print(f"Analyzing overlap between flares {i+1} and {j+1} (overlap: {duration})")

            # `df` from the loading cell holds only the last file processed,
            # so reload the file these two flares were actually detected in.
            pair_channel = 'B'
            pair_flux_col = f'xrs{pair_channel.lower()}'
            source_file = flare_i['source_file']
            pair_frame = _reload_flux_frame(source_file, pair_channel)

            if pair_frame is None:
                print(f"Failed to reload {source_file}; cannot separate flares")
            else:
                seq_len = settings.ML_PARAMS['sequence_length']
                segment = _model_segment(
                    pair_frame, pair_flux_col,
                    int(min(flare_i['start_index'], flare_j['start_index'])),
                    int(max(flare_i['end_index'], flare_j['end_index'])),
                    seq_len,
                )

                # Decompose the flares
                original, individual_flares, combined = reconstruct_flares(
                    model, segment, window_size=seq_len, plot=True
                )
                plt.tight_layout()
                plt.show()

                # Integrate over time in minutes, so the result has the same
                # units (flux x minutes) as the triangle estimate below. The
                # default dx=1 would integrate over samples instead.
                minutes_per_sample = _minutes_per_sample(flare_i, flare_j)
                # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid
                integrate = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz

                print("\nEnergy estimates for separated flares:")
                if minutes_per_sample is None:
                    print("  (sampling cadence unknown; energies are in flux x samples)")
                energies = []
                for k in range(individual_flares.shape[1]):
                    component = individual_flares[:, k]
                    # Ignore components whose amplitude is negligible
                    if np.max(component) > 0.05 * np.max(original):
                        energy = integrate(component, dx=minutes_per_sample or 1.0)
                        energies.append(energy)
                        print(f"  Flare component {k+1}: {energy:.4e}")

                # Compare with original flares
                print("\nComparing with original flares:")
                for flare_idx, flare in ((i, flare_i), (j, flare_j)):
                    print(f"  Original flare {flare_idx+1}: start={flare['start_time']}, "
                          f"end={flare['end_time']}, peak_flux={flare['peak_flux']:.2e}")

                # Pseudo-validation: the separated energies should roughly add up
                # to the original ones. Only meaningful when both are in flux x minutes.
                if energies and minutes_per_sample is not None:
                    try:
                        # Crude estimate: triangle of height peak_flux over the flare duration
                        orig_total = sum(
                            0.5 * flare['peak_flux'] * ((flare['end_time'] - flare['start_time']).total_seconds() / 60)
                            for flare in (flare_i, flare_j)
                        )
                        separated_total = sum(energies)

                        print(f"\nSum of original energies (crude estimate): {orig_total:.4e}")
                        print(f"Sum of separated energies: {separated_total:.4e}")

                        if orig_total > 0:
                            ratio = separated_total / orig_total
                            print(f"Ratio (separated/original): {ratio:.2f}")
                            if 0.5 <= ratio <= 2.0:
                                print("Energy conservation is reasonable")
                            else:
                                print("Energy conservation may be an issue")
                        else:
                            print("Original energy estimate is not positive; ratio undefined")
                    except Exception as e:
                        # Best-effort diagnostic: report and continue
                        print(f"Error in energy comparison: {e}")
    else:
        print("No overlapping flares to analyze")
else:
    print("No detected flares available for overlapping analysis")

## Summary

In this notebook, we've explored how to validate our flare detection method against known flare catalogs and how to sanity-check our flare separation method. We've:

1. Downloaded and analyzed the NOAA SWPC flare catalog
2. Compared our detected flares with the catalog entries
3. Calculated detection performance metrics (precision, recall, F1-score)
4. Analyzed detection performance by flare class
5. Investigated false positives and false negatives
6. Sanity-checked our ML-based flare separation on overlapping flares by comparing integrated energies before and after separation

This validation process helps us understand the strengths and limitations of our methods and suggests areas for improvement. The results can be used to refine detection parameters, improve the ML model, or develop hybrid approaches that combine the best aspects of multiple methods.