# In [None]:
"""
Analysis and plotting notebook
Load trained models and create visualizations
"""

import os
import torch
import pickle
import warnings
import numpy as np
import pandas as pd
import proplot as pplt
import scipy.stats as stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from model_classes import BASELINE, MLP, MLPMODEL

pplt.rc['figure.dpi'] = 100
warnings.filterwarnings('ignore')

def inverse_log_normalize(normdata, c=1.0):
    """Undo a log normalization by exponentiating and removing the offset.

    Args:
        normdata: Value(s) in log space; anything ``np.exp`` accepts.
        c: Offset subtracted after exponentiation. Defaults to 1.0.

    Returns:
        ``exp(normdata) - c``, computed elementwise.
    """
    exponentiated = np.exp(normdata)
    return exponentiated - c

def load_data(path='data/data_splits.pkl'):
    """Load the prepared data splits from a pickle file.

    Args:
        path: Location of the pickled splits. Defaults to
            'data/data_splits.pkl', the path this notebook has always used.

    Returns:
        The object stored in the pickle file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(path, 'rb') as f:
        return pickle.load(f)

def load_results(baseline_path='results/baseline_results.pkl',
                 mlp_path='results/mlp_results.pkl'):
    """Load the pickled training results for the baseline and MLP models.

    Args:
        baseline_path: Pickle file holding the baseline results. Defaults to
            'results/baseline_results.pkl'.
        mlp_path: Pickle file holding the MLP results. Defaults to
            'results/mlp_results.pkl'.

    Returns:
        A ``(baselineresults, mlpresults)`` tuple with the two unpickled
        objects, in that order.

    Raises:
        FileNotFoundError: If either file does not exist.
    """
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(baseline_path, 'rb') as f:
        baselineresults = pickle.load(f)

    with open(mlp_path, 'rb') as f:
        mlpresults = pickle.load(f)

    return baselineresults, mlpresults

# Read the prepared splits and the saved training results from disk
data_splits = load_data()
baselineresults, mlpresults = load_results()

# Pull out the held-out test inputs and targets
xtest, ytest = data_splits['xtest'], data_splits['ytest']

# Plot training losses
# Two panels: the left one shows models whose name contains 'mse', the right
# one models whose name contains 'mae'. Each panel overlays the training
# (dashed) and validation (solid) loss curves of up to four model groups.
fig, axs = pplt.subplots(nrows=1, ncols=2, refwidth=4, refheight=2, sharex=True, sharey=False)
axs.format(xlabel='Epoch', xlim=(0, 35))
axs[0].format(title='MSE Loss Comparison', ylabel='MSE (mm/day)$^2$', ylim=(260, 350), yticks=10, yminorticks='none')
axs[1].format(title='MAE Loss Comparison', ylabel='MAE (mm/day)', ylim=(4.7, 5.4))

colors = ['blue5', 'blue9', 'red6', 'red9']

# Organize models by activation and log transform status
model_groups = {
    'linear_regular': [],
    'linear_log': [],
    'relu_regular': [],
    'relu_log': []
}

# Categorize models: the activation is read from the (case-insensitive) model
# name, the log-transform status from the result's 'logtransform' flag.
# Models whose name mentions neither 'linear' nor 'relu' are skipped.
for modelname, result in mlpresults.items():
    lowered = modelname.lower()
    if 'linear' in lowered:
        activation = 'linear'
    elif 'relu' in lowered:
        activation = 'relu'
    else:
        continue

    transform = 'log' if result.get('logtransform', False) else 'regular'
    model_groups[f'{activation}_{transform}'].append(modelname)

# (group key, color, legend prefix, legend suffix), listed in drawing order
curve_specs = [
    ('linear_regular', colors[0], 'Linear', ''),
    ('relu_regular', colors[1], 'Nonlinear', ''),
    ('linear_log', colors[2], 'Linear', ' (Log-Normalized)'),
    ('relu_log', colors[3], 'Nonlinear', ' (Log-Normalized)'),
]
# (key in the result dict, line style, legend word)
loss_curves = [
    ('trainlosses', '--', 'Training'),
    ('validlosses', '-', 'Validation'),
]

# Plot the results
# Each legend label is attached only the first time its curve type is drawn,
# so the figure legend has no duplicates and still lists a curve type that
# appears only in the second panel.
labeled = set()
for ax, loss_type in zip(axs, ['mse', 'mae']):
    for group, color, prefix, suffix in curve_specs:
        # Match the loss type case-insensitively, consistent with the
        # activation matching above.
        matches = [m for m in model_groups[group] if loss_type in m.lower()]
        if not matches:
            continue
        # Only the first matching model of each group is drawn.
        modelname = matches[0]
        for key, linestyle, kind in loss_curves:
            label = f'{prefix} {kind}{suffix}'
            ax.plot(mlpresults[modelname][key], color=color, linestyle=linestyle,
                    label=None if label in labeled else label)
            labeled.add(label)

fig.legend(loc='b', ncols=2)
pplt.show()

# Preprocess all data for plotting
# For every model, pair the flattened test targets with the flattened test
# predictions and keep only the pairs where both values are positive and not
# NaN. MLP entries are added first, baseline entries second.
processed_data = {}
ytrue_raw = ytest['pr'].values.flatten()

for results, check_logtransform in ((mlpresults, True), (baselineresults, False)):
    for modelname, result in results.items():
        ypred_raw = result['testoutputs']
        # Only MLP results are checked for the 'logtransform' flag; flagged
        # outputs are mapped back with inverse_log_normalize first.
        if check_logtransform and result.get('logtransform', False):
            ypred_raw = inverse_log_normalize(ypred_raw, c=1.0)
        ypred_raw = ypred_raw.flatten()

        mask = (ytrue_raw > 0) & (ypred_raw > 0) & ~np.isnan(ytrue_raw) & ~np.isnan(ypred_raw)

        processed_data[modelname] = dict(
            ytrue=ytrue_raw[mask],
            ypred=ypred_raw[mask],
            description=result['description'],
        )

# Calculate global limits
# Shared axis limits for every panel: the overall extremes across all models,
# widened on both ends by 10% of the data's span in log10 space.
all_ytrue = np.concatenate([data['ytrue'] for data in processed_data.values()])
all_ypred = np.concatenate([data['ypred'] for data in processed_data.values()])
datamin = min(all_ytrue.min(), all_ypred.min())
datamax = max(all_ytrue.max(), all_ypred.max())
# The padding factor is computed once from the unpadded extremes. Reusing an
# already-padded minimum for the upper bound would make the upper padding
# larger than the lower one. log10 requires both extremes to be positive.
padfactor = 10**(0.1*(np.log10(datamax) - np.log10(datamin)))
globalmin = datamin / padfactor
globalmax = datamax * padfactor

# Plot actual vs predicted
# One log-log 2-D histogram panel per model, on a grid that is ncols wide.
totalmodels = len(processed_data)
ncols = 4
nrows = -(-totalmodels // ncols)  # ceiling division

fig, axs = pplt.subplots(nrows=nrows, ncols=ncols, refwidth=2, share=True)
axs.format(
    xlabel='True Precipitation (mm/day)', xscale='log', xformatter='log', xlim=[globalmin, globalmax],
    ylabel='Predicted Precipitation (mm/day)', yscale='log', yformatter='log', ylim=[globalmin, globalmax],
)

# Both axes use the same log-spaced bin edges spanning the global limits
bins = 100
logmin, logmax = np.log10(globalmin), np.log10(globalmax)
xedges = np.logspace(logmin, logmax, bins + 1)
yedges = np.logspace(logmin, logmax, bins + 1)

for plotidx, (modelname, data) in enumerate(processed_data.items()):
    row, col = divmod(plotidx, ncols)
    panel = axs[row, col]

    # Trim both series to a common length before binning
    ytrue, ypred = data['ytrue'], data['ypred']
    minlen = min(len(ytrue), len(ypred))
    ytrue, ypred = ytrue[:minlen], ypred[:minlen]

    # Bin the pairs and mask the empty (zero-count) bins
    counts = np.histogram2d(ytrue, ypred, bins=(xedges, yedges))[0]
    hist = np.ma.masked_where(counts == 0, counts)

    mesh = panel.pcolormesh(xedges, yedges, hist.T, cmap='ColdHot', norm='log', levels=100)
    # Dashed 1:1 reference line
    panel.plot([globalmin, globalmax], [globalmin, globalmax], color='k', linestyle='--')
    panel.format(title=data['description'])

    r2 = r2_score(ytrue, ypred)
    panel.text(0.05, 0.90, f'R$^2$ = {r2:.3f}', transform=panel.transAxes)

# Hide unused subplots
for unused in range(totalmodels, nrows * ncols):
    row, col = divmod(unused, ncols)
    axs[row, col].axis('off')

# NOTE(review): `mesh` is whatever the last loop iteration produced, so the
# colorbar is built from the last panel only; confirm the panels share a
# color scale before reading counts off the other panels.
fig.colorbar(mesh, loc='r', label='Counts', ticks=[0.1, 1, 10, 100, 1000, 10000, 100000])
pplt.show()

# Calculate metrics for all models
# The scores are computed on the arrays stored in processed_data, i.e. on
# whatever subset of test pairs was kept for plotting.
allmetrics = {}
for modelname, data in processed_data.items():
    observed, predicted = data['ytrue'], data['ypred']
    mse = mean_squared_error(observed, predicted)
    allmetrics[modelname] = dict(
        MSE=mse,
        RMSE=np.sqrt(mse),
        MAE=mean_absolute_error(observed, predicted),
        R2=r2_score(observed, predicted),
        Title=data['description'],
    )

# One row per model, one column per metric
metricsdf = pd.DataFrame.from_dict(allmetrics, orient='index')
print("Model Performance Metrics:")
print(metricsdf.round(3))