# Exploratory Data Analysis: Solar Wind & Kp Index
This notebook performs EDA on the Aurora project data pipeline.

**Goal:** Understand the distribution, missing values, and correlations of Solar Wind (OMNI) features before training the model.

In [None]:
# Standard Data Science Packages
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import missingno as msno
from statsmodels.stats.outliers_influence import variance_inflation_factor

# Project Specific Imports
import logging
from pathlib import Path
from src.data import (
    fetch_omni_data,
    fetch_kp_range,
    clean_solarwind,
    add_time_features,
    add_moving_averages
)

# Configuration
%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid') # Or 'ggplot'
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# 1. Load & Merge Data

In [None]:
# Parameters (matching default args in build_dataset.py)
START_YEAR = 2010
END_YEAR = 2020

logging.info(f"Fetching historical data from {START_YEAR} to {END_YEAR}...")

# Fetch errors are logged and then re-raised. Swallowing them only postponed the
# failure to a confusing NameError on `sw_df` / `kp_df` at the merge step below.
# To work offline, replace a fetch call with a cached copy of your own, e.g.
#     sw_df = pd.read_pickle("data/interim/omni_cached.pkl")

# 1. Fetch Solar Wind Data (OMNI)
try:
    sw_df = fetch_omni_data(start_year=START_YEAR, end_year=END_YEAR)
except Exception as e:
    logging.error(f"Error fetching OMNI data: {e}")
    raise
logging.info(f"Solar Wind Data shape: {sw_df.shape}")

# 2. Fetch Kp Index
try:
    kp_df = fetch_kp_range(start_year=START_YEAR, end_year=END_YEAR)
except Exception as e:
    logging.error(f"Error fetching Kp data: {e}")
    raise
logging.info(f"Kp Data shape: {kp_df.shape}")

# 3. Merge on the index. Both indices are assumed to be timestamps — TODO confirm
# against src.data. The inner join keeps only timestamps present in both sources.
df = pd.merge(sw_df, kp_df, left_index=True, right_index=True, how='inner')
logging.info(f"Merged Data shape: {df.shape}")
if df.empty:
    # Fail here rather than in every later cell. This presumably means the two
    # indices never line up (different dtype, time zone or sampling) — check.
    raise ValueError(
        "Merged OMNI/Kp frame is empty: the two indices share no timestamps. "
        "Check index dtype, time zone and sampling resolution."
    )

df.head()

# 2. Data Overview & Missing Values
Look at the raw data quality before cleaning.

In [None]:
print("Data Info:")
# DataFrame.info() prints its report itself and returns None, so wrapping it in
# print() only appended a stray "None" to the output.
df.info()

print("\nSummary Statistics:")
display(df.describe())

In [None]:
# Using missingno to visualize gaps in the time series.
# msno.matrix creates its own figure, so the size is passed to it directly; a
# preceding plt.figure() would only leave an empty figure behind. The returned
# axes is the matrix panel. Titling it explicitly avoids the title landing on
# whichever axes happens to be "current" afterwards (e.g. the sparkline).
# `freq` only controls the date tick labels on the y-axis; it does not resample
# the data. It also needs a DatetimeIndex/PeriodIndex, so it is dropped otherwise.
# NOTE(review): pandas >= 2.2 may warn that 'M' is deprecated in favour of 'ME'.
matrix_freq = 'M' if isinstance(df.index, (pd.DatetimeIndex, pd.PeriodIndex)) else None
matrix_ax = msno.matrix(df, freq=matrix_freq, figsize=(15, 6))
matrix_ax.set_title("Missing Value Matrix (Time Series Gaps)", fontsize=16)
plt.show()

# Percentage of missing values per column (only columns with any gaps, worst first)
missing_percent = df.isnull().mean() * 100
missing_percent = missing_percent[missing_percent > 0].sort_values(ascending=False)

if missing_percent.empty:
    print("No missing values found (or already handled in fetch step).")
else:
    fig, bar_ax = plt.subplots(figsize=(10, 5))
    # Explicit per-bar colours give the same viridis gradient without relying on
    # seaborn's `palette` without `hue`, which seaborn >= 0.13 deprecates.
    bar_ax.bar(
        missing_percent.index.astype(str),
        missing_percent.values,
        color=sns.color_palette('viridis', len(missing_percent)),
    )
    bar_ax.set_title("Percentage of Missing Data by Feature")
    bar_ax.set_ylabel("Percent Missing (%)")
    bar_ax.tick_params(axis='x', labelrotation=45)
    plt.show()

# 3. Distributions & Outliers
Checking the distribution of key physical parameters (e.g., `bz`, `speed`, `density`).

In [None]:
# Histograms of every numeric column. Kp is included too: it is ordinal/discrete-like,
# so expect a spiky histogram for it rather than a smooth distribution.
num_cols = df.select_dtypes(include=[np.number]).columns.tolist()

if num_cols:
    # Derive the grid from the column count. A fixed layout=(5, 4) makes pandas
    # raise a ValueError as soon as there are more than 20 numeric columns.
    HIST_GRID_COLS = 4
    n_rows = -(-len(num_cols) // HIST_GRID_COLS)  # ceiling division
    df[num_cols].hist(
        bins=50,
        figsize=(20, 3 * n_rows),  # 3 in per row: 5 rows -> the original 20x15
        layout=(n_rows, HIST_GRID_COLS),
        color='steelblue',
        edgecolor='black',
    )
    plt.suptitle("Feature Distributions", fontsize=20)
    plt.show()
else:
    print("No numeric columns to plot.")

# 4. Feature Engineering
We apply the pipeline steps `clean_solarwind`, `add_time_features`, and `add_moving_averages` to see how features change.

In [None]:
# Run the three pipeline stages in order. Each intermediate frame keeps its own
# name so any single stage's effect can be inspected afterwards.
df_clean = clean_solarwind(df)               # stage 1: cleaning
df_features = add_time_features(df_clean)    # stage 2: (cyclical) time features
df_final = add_moving_averages(df_features)  # stage 3: presumably rolling means/stds — see df_final.columns

print(f"Shape after feature engineering: {df_final.shape}")
df_final.head()

# 5. Correlation Analysis
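We look at how features correlate with the target (`kp`). We then check the most correlated features for multicollinearity using the Variance Inflation Factor (VIF).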

In [None]:
def analyze_correlations(df, target_col='kp'):
    # Select numerical columns
    numeric_df = df.select_dtypes(include=[np.number])
    
    # 1. Correlation with Target
    corr_with_target = numeric_df.corrwith(numeric_df[target_col]).sort_values(ascending=False)
    
    plt.figure(figsize=(10, 8))
    sns.barplot(x=corr_with_target.values, y=corr_with_target.index, palette='coolwarm')
    plt.title(f"Feature Correlation with Target ({target_col})")
    plt.show()

    # 2. Full Heatmap
    plt.figure(figsize=(20, 16))
    sns.heatmap(numeric_df.corr(), annot=False, cmap='coolwarm', center=0)
    plt.title("Global Correlation Matrix")
    plt.show()

    # 3. VIF (Variance Inflation Factor) check for Multicollinearity
    # (Handling infinite values or NaNs first if present)
    clean_num = numeric_df.dropna().replace([np.inf, -np.inf], np.nan).dropna()
    
    # Selecting a subset if too many columns (e.g., top 15 correlated)
    top_cols = corr_with_target.head(15).index.tolist()
    X = clean_num[top_cols].drop(columns=[target_col], errors='ignore')
    
    vif_data = pd.DataFrame()
    vif_data['Feature'] = X.columns
    vif_data['VIF'] = [variance_inflation_factor(X.values, i) for i in range(len(X.columns))]
    
    print("\nVariance Inflation Factor (Top Correlated Features):")
    display(vif_data.sort_values(by='VIF', ascending=False))

# Choose the target column: an exact 'kp' first, then a case variant such as 'Kp'.
# Fall back to the last column only if neither exists, and warn, because
# correlating against an arbitrary column would silently give meaningless results.
kp_matches = [col for col in df_final.columns if str(col).lower() == 'kp']
if 'kp' in df_final.columns:
    target_variable = 'kp'
elif kp_matches:
    target_variable = kp_matches[0]
else:
    target_variable = df_final.columns[-1]
    logging.warning(f"No 'kp' column found; falling back to last column {target_variable!r} as the target.")
analyze_correlations(df_final, target_col=target_variable)

# 6. Time Series Visualization
For aurora forecasting we also need to inspect the raw time series, especially around geomagnetic storm events (high Kp).

In [None]:
def plot_storm_event(df, start_date, end_date):
    """Plots Solar Wind parameters and Kp index for a specific date range."""
    subset = df[start_date:end_date]
    
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(14, 10), sharex=True)
    
    # Plot Kp
    if 'kp' in subset.columns:
        sns.lineplot(data=subset, x=subset.index, y='kp', ax=axes[0], color='red', linewidth=2)
        axes[0].set_ylabel("Kp Index")
        axes[0].set_title(f"Geomagnetic Storm Event: {start_date} to {end_date}")
        axes[0].axhline(y=5, color='black', linestyle='--', label='Storm Threshold')
        axes[0].legend()

    # Plot Bz (Interplanetary Magnetic Field) - Critical for storms
    # Adjust column name 'bz' based on your actual data schema
    bz_col = [c for c in subset.columns if 'bz' in c.lower()]
    if bz_col:
        sns.lineplot(data=subset, x=subset.index, y=bz_col[0], ax=axes[1], color='blue')
        axes[1].set_ylabel("Bz (nT)")
        axes[1].axhline(y=0, color='gray', linestyle='--', linewidth=0.5)

    # Plot Speed/Density
    # Adjust column names 'speed', 'density' based on schema
    speed_col = [c for c in subset.columns if 'speed' in c.lower()]
    if speed_col:
        sns.lineplot(data=subset, x=subset.index, y=speed_col[0], ax=axes[2], color='green')
        axes[2].set_ylabel("Solar Wind Speed (km/s)")

    plt.tight_layout()
    plt.show()

# Example: the St. Patrick's Day storm of March 2015
plot_storm_event(df_final, start_date='2015-03-15', end_date='2015-03-20')