# Solar Flare Analysis: ML-Based Flare Decomposition

This notebook demonstrates how to train and use a neural network model to separate overlapping solar flares.

## Setup and Imports

In [None]:
%pip install tensorflow

import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from keras import layers, models
from sklearn.model_selection import train_test_split

# Make the parent of the notebook's working directory importable so that the
# `config` and `src` packages imported below resolve.
project_root = os.path.abspath(os.pardir)
if sys.path.count(project_root) == 0:
    sys.path.append(project_root)

# Import project modules
from config import settings
from src.data_processing.data_loader import load_goes_data, preprocess_xrs_data
from src.flare_detection.traditional_detection import (
    detect_flare_peaks, define_flare_bounds, detect_overlapping_flares
)
from src.ml_models.flare_decomposition import (
    FlareDecompositionModel, reconstruct_flares
)
from src.visualization.plotting import plot_flare_decomposition

## Understanding the ML Model Architecture

Our flare decomposition model is designed to take a time series containing overlapping flares as input and separate it into individual flare components. The model uses an encoder-decoder architecture with LSTM layers to handle the temporal nature of the data.

In [None]:
# Instantiate the decomposition model with the hyperparameters from the project settings
model = FlareDecompositionModel(
    sequence_length=settings.ML_PARAMS['sequence_length'],
    n_features=settings.ML_PARAMS['n_features'],
    max_flares=settings.ML_PARAMS['max_flares'],
    dropout_rate=settings.ML_PARAMS['dropout_rate']
)

# Build the network; the built model is then available as `model.model`
model.build_model()

# Print the layer-by-layer summary
model.model.summary()

# Render the architecture diagram to a PNG and show it inline.
# This step is best-effort: any failure is reported and the notebook continues.
from keras.utils import plot_model
try:
    plot_model(model.model, to_file='model_architecture.png', show_shapes=True, show_dtype=True)
    from IPython.display import Image, display
    # A bare expression inside a `try` body is not the cell's last expression,
    # so Jupyter never auto-renders it; display() has to be called explicitly.
    display(Image('model_architecture.png'))
except Exception as e:
    print(f"Couldn't generate model visualization: {e}")
    print("You may need to install graphviz and pydot packages.")

## Generating Synthetic Training Data

Since labeled data for overlapping solar flares is rare, we'll generate synthetic data for training:

In [None]:
# Draw separate synthetic training and validation sets from the model's generator
X_train, y_train = model.generate_synthetic_data(n_samples=1000, noise_level=0.05)
X_val, y_val = model.generate_synthetic_data(n_samples=200, noise_level=0.05)

print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")

# Preview the first few training samples: one row per sample, with the
# combined signal on the left and its individual components on the right.
n_examples = 3
plt.figure(figsize=(15, 12))

for sample_idx in range(n_examples):
    # Left column: channel 0 of the input sequence
    ax_combined = plt.subplot(n_examples, 2, 2 * sample_idx + 1)
    ax_combined.plot(X_train[sample_idx, :, 0])
    ax_combined.set_title(f'Example {sample_idx+1}: Combined Signal')
    ax_combined.set_xlabel('Time')
    ax_combined.set_ylabel('Amplitude')
    ax_combined.grid(True, linestyle='--', alpha=0.7)

    # Right column: one curve per slice along the last axis of the target
    ax_parts = plt.subplot(n_examples, 2, 2 * sample_idx + 2)
    n_components = y_train.shape[2]
    for component_idx in range(n_components):
        ax_parts.plot(y_train[sample_idx, :, component_idx], label=f'Flare {component_idx+1}')
    ax_parts.set_title(f'Example {sample_idx+1}: Individual Flares')
    ax_parts.set_xlabel('Time')
    ax_parts.set_ylabel('Amplitude')
    ax_parts.grid(True, linestyle='--', alpha=0.7)
    ax_parts.legend()

plt.tight_layout()
plt.show()

## Training the Model

Now let's train our model on the synthetic data:

In [None]:
# Fit the model on the synthetic training set, validating against the
# held-out synthetic set. Epoch count and batch size come from the settings.
# `save_path` is handed to model.train as-is (presumably where the trained
# model gets written; confirm in FlareDecompositionModel.train).
checkpoint_path = os.path.join(settings.MODEL_DIR, 'flare_decomposition_model')
history = model.train(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    epochs=settings.ML_PARAMS['epochs'],
    batch_size=settings.ML_PARAMS['batch_size'],
    save_path=checkpoint_path,
)

# Show the training curves produced by the model's own plotting helper
history_fig = model.plot_training_history()
plt.show()

## Evaluating the Model

Let's evaluate our model on a separate test set:

In [None]:
# Draw a fresh synthetic test set; note the noise level (0.08) is higher than
# the 0.05 used for the training and validation sets.
X_test, y_test = model.generate_synthetic_data(n_samples=100, noise_level=0.08)

# Score the model on the test set; the first two entries of the result are
# reported as loss and MAE respectively.
eval_results = model.evaluate(X_test, y_test)

test_loss = eval_results[0]
print(f"Test loss: {test_loss:.4f}")

test_mae = eval_results[1]
print(f"Test MAE: {test_mae:.4f}")

## Visualizing Model Predictions

Let's see how well our model can decompose overlapping flares on the test data:

In [None]:
def _decorate_panel(title, with_legend=False):
    """Apply the shared title, axis labels, grid and optional legend to the current axes."""
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.grid(True, linestyle='--', alpha=0.7)
    if with_legend:
        plt.legend()


# Run the trained network on the whole test set
predictions = model.model.predict(X_test)

# One row per test sample: input signal | true components | predicted components
n_examples = 3
plt.figure(figsize=(15, 15))

for row in range(n_examples):
    first_slot = 3 * row + 1

    # Column 1: channel 0 of the input sequence
    plt.subplot(n_examples, 3, first_slot)
    plt.plot(X_test[row, :, 0])
    _decorate_panel(f'Example {row+1}: Original Signal')

    # Column 2: ground-truth components, one curve per slice of the last axis
    plt.subplot(n_examples, 3, first_slot + 1)
    for component in range(y_test.shape[2]):
        plt.plot(y_test[row, :, component], label=f'True Flare {component+1}')
    _decorate_panel(f'Example {row+1}: True Components', with_legend=True)

    # Column 3: predicted components, dashed to distinguish them from the truth
    plt.subplot(n_examples, 3, first_slot + 2)
    for component in range(predictions.shape[2]):
        plt.plot(predictions[row, :, component], linestyle='--', label=f'Predicted Flare {component+1}')
    _decorate_panel(f'Example {row+1}: Predicted Components', with_legend=True)

plt.tight_layout()
plt.show()

## Working with Real Data

Now let's apply our model to real GOES XRS data to separate overlapping flares:

In [None]:
# Locate and load sample data.
# os.listdir returns entries in arbitrary order, so the list is sorted to make
# the chosen file reproducible across runs and machines. A missing data
# directory is treated like an empty one so the hint at the bottom is printed
# instead of a FileNotFoundError.
data_dir = settings.DATA_DIR
if os.path.isdir(data_dir):
    sample_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.nc'))
else:
    sample_files = []

# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use
# whichever one the installed NumPy provides.
integrate = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz

# Length of the window passed to the model
sequence_length = settings.ML_PARAMS['sequence_length']

if sample_files:
    data_file = os.path.join(data_dir, sample_files[0])
    print(f"Using {data_file} for demonstration")
    
    # Load and preprocess data
    data = load_goes_data(data_file)
    channel = 'B'
    flux_col = f'xrs{channel.lower()}'
    df = preprocess_xrs_data(data, channel=channel, remove_bad_data=True, interpolate_gaps=True)
    
    # Detect flares using traditional method
    peaks = detect_flare_peaks(
        df, flux_col,
        threshold_factor=settings.DETECTION_PARAMS['threshold_factor'],
        window_size=settings.DETECTION_PARAMS['window_size']
    )
    
    flares = define_flare_bounds(
        df, flux_col, peaks['peak_index'].values,
        start_threshold=settings.DETECTION_PARAMS['start_threshold'],
        end_threshold=settings.DETECTION_PARAMS['end_threshold'],
        min_duration=settings.DETECTION_PARAMS['min_duration'],
        max_duration=settings.DETECTION_PARAMS['max_duration']
    )
    
    # Detect overlapping flares
    overlapping = detect_overlapping_flares(flares, min_overlap='2min')
    
    print(f"Detected {len(overlapping)} potentially overlapping flare pairs")
    
    # Process overlapping flares
    if overlapping:
        print("\nOverlapping flare pairs:")
        
        # Context added on each side of a flare pair (loop-invariant)
        padding = sequence_length // 4
        
        for i, j, duration in overlapping:
            print(f"  Flares {i+1} and {j+1} overlap by {duration}")
            
            # Span covering both flares. Selecting a DataFrame row yields a
            # Series with one common dtype, which can turn integer columns
            # into floats; .iloc slicing needs real ints, so cast explicitly.
            first_flare = flares.iloc[i]
            second_flare = flares.iloc[j]
            start_idx = int(min(first_flare['start_index'], second_flare['start_index']))
            end_idx = int(max(first_flare['end_index'], second_flare['end_index']))
            
            # Ensure we have enough context around the flares
            start_idx = max(0, start_idx - padding)
            end_idx = min(len(df) - 1, end_idx + padding)
            
            # Extract the time series segment.
            # NOTE(review): the slice end is exclusive, so the sample at
            # end_idx itself is not part of the segment. Confirm this is intended.
            segment = df.iloc[start_idx:end_idx][flux_col].values
            
            # Ensure the segment has the right length: zero-pad at the end
            # when too short, keep only the leading samples when too long.
            if len(segment) < sequence_length:
                segment = np.pad(segment, 
                                (0, sequence_length - len(segment)), 
                                'constant')
            elif len(segment) > sequence_length:
                segment = segment[:sequence_length]
            
            # Reshape for model input: (1, sequence_length, 1)
            segment = segment.reshape(1, -1, 1)
            
            # Decompose the flares
            original, individual_flares, combined = reconstruct_flares(
                model, segment, window_size=sequence_length, plot=False
            )
            
            # Plot the decomposition
            timestamps = df.index[start_idx:start_idx + segment.shape[1]]
            original_series = original.flatten()
            plotted_flares = individual_flares
            if len(timestamps) < len(original_series):
                # The zero-padded window runs past the end of the data, so
                # fewer timestamps than samples exist. Plot only the samples
                # that have a real timestamp to keep x and y the same length.
                original_series = original_series[:len(timestamps)]
                plotted_flares = individual_flares[:len(timestamps)]
            fig = plot_flare_decomposition(original_series, plotted_flares, timestamps)
            plt.tight_layout()
            plt.show()
            
            # Calculate energy for each separated flare.
            # Components whose peak is at most 5% of the original peak are skipped.
            # NOTE(review): the integral is taken over the sample index (unit
            # spacing), not over time. Multiply by the sampling cadence if a
            # physical fluence is needed.
            print("\nEnergy estimates for separated flares:")
            significance_floor = 0.05 * np.max(original)
            for k in range(individual_flares.shape[1]):
                if np.max(individual_flares[:, k]) > significance_floor:
                    energy = integrate(individual_flares[:, k])
                    print(f"  Flare component {k+1}: {energy:.4e}")
    else:
        print("No overlapping flares detected in the data.")
else:
    print("No .nc files found. Please place GOES XRS data in the 'data' directory.")

## Saving and Loading Models

For future use, you might want to save the trained model:

In [None]:
# Save the model. The target directory is created first so that saving also
# works when the directory does not exist yet.
os.makedirs(settings.MODEL_DIR, exist_ok=True)
model_path = os.path.join(settings.MODEL_DIR, 'flare_decomposition_notebook')
model.save_model(model_path)
print(f"Model saved to {model_path}")

# To load the model later: rebuild a model with the same hyperparameters as
# the one that was trained (including dropout_rate), then load into it.
new_model = FlareDecompositionModel(
    sequence_length=settings.ML_PARAMS['sequence_length'],
    n_features=settings.ML_PARAMS['n_features'],
    max_flares=settings.ML_PARAMS['max_flares'],
    dropout_rate=settings.ML_PARAMS['dropout_rate']
)
new_model.build_model()

# Loading is best-effort in this demo: report a failure instead of aborting
try:
    new_model.load_model(model_path)
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")

## Summary

In this notebook, we've demonstrated:

1. Building a neural network model for separating overlapping solar flares
2. Generating synthetic training data for that model
3. Training and evaluating the flare decomposition model
4. Applying the trained model to real GOES XRS data to separate overlapping flares
5. Saving and loading the trained model for future use

In the next notebook, we'll analyze the power-law properties of flare energy distributions.