# 03: Model Training Analysis

This notebook is for **analyzing and understanding** the model training process, not for running the main training (which is done by `src/train_model.py`).

We will:
1.  Load the saved training/evaluation metrics from `reports/metrics.csv`.
2.  Plot the key metrics (such as loss, MSE, or PESQ) across epochs.
3.  Load the final trained model from `models/encoder_decoder.pth`.
4.  Print a summary of the model architecture to inspect its components.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch

# Add src to path to import the model definition
import sys
sys.path.append('../src')

try:
    # We assume your model architecture is defined in 'src/model.py'
    # as 'EncoderDecoder' or a similar name.
    from model import EncoderDecoder # <--- Update this if your class name is different
except ImportError:
    print("Could not import model from src/model.py. Please ensure the file and class exist.")
    EncoderDecoder = None # Set to None so the model-loading cell in section 2 can be skipped

# --- Configuration ---
sns.set_style("whitegrid")
METRICS_FILE = Path("../reports/metrics.csv")
MODEL_FILE = Path("../models/encoder_decoder.pth")
REPORTS_PLOTS_DIR = Path("../reports/plots")
REPORTS_PLOTS_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 1. Visualize Training & Evaluation Metrics

First, we load the `metrics.csv` file that was saved by `src/train_model.py` or `src/evaluate.py`. This file should contain the loss and other metrics for each epoch.
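
As a rough illustration of the layout this notebook expects, the sketch below builds a small in-memory stand-in for `metrics.csv` and reshapes it from wide to long form. The column names `train_loss` and `val_pesq` and all values are invented for the example; only the `epoch` column is assumed by the cells that follow.

```python
import io

import pandas as pd

# Hypothetical stand-in for reports/metrics.csv: one row per epoch,
# one column per metric. Column names and values are invented.
csv_text = """epoch,train_loss,val_pesq
1,0.90,1.8
2,0.60,2.1
3,0.45,2.4
"""

wide_df = pd.read_csv(io.StringIO(csv_text))

# Wide -> long: one row per (epoch, metric) pair, which is the shape
# seaborn's FacetGrid needs in order to facet on the metric name.
long_df = wide_df.melt(id_vars="epoch", var_name="Metric", value_name="Value")

print(long_df)
```

Each metric column becomes its own group of rows in `long_df`, so adding a new metric to the CSV needs no change to the plotting code.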

In [None]:
if not METRICS_FILE.exists():
    print(f"Metrics file not found at: {METRICS_FILE}")
    print("Please run train_model.py or evaluate.py first.")
    metrics_df = pd.DataFrame()
else:
    metrics_df = pd.read_csv(METRICS_FILE)
    print(f"Loaded metrics from {METRICS_FILE}:")
    display(metrics_df.head())

# Melt the DataFrame for easier plotting with seaborn
if 'epoch' in metrics_df.columns:
    metrics_melted = metrics_df.melt('epoch', var_name='Metric', value_name='Value')
else:
    metrics_melted = pd.DataFrame()

In [None]:
if not metrics_melted.empty:
    # Plot all metrics on one grid
    g = sns.FacetGrid(metrics_melted, col="Metric", col_wrap=3, sharey=False)
    g.map(sns.lineplot, "epoch", "Value", marker='o')
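    g.figure.suptitle("Training & Evaluation Metrics Over Epochs", y=1.03)  # `figure` supersedes the deprecated `fig` alias
    g.set_axis_labels("Epoch", "Metric Value")
    plt.tight_layout()
    # bbox_inches="tight" keeps the raised suptitle inside the saved image
    plt.savefig(REPORTS_PLOTS_DIR / "03_training_metrics.png", bbox_inches="tight")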
    plt.show()
else:
    print("No data to plot. The metrics file is missing or has no 'epoch' column.")

## 2. Load Trained Model & Inspect Architecture

Now we load the final saved model weights from `models/encoder_decoder.pth` into our model class defined in `src/model.py`. This confirms the model can be loaded successfully.

We will then print the model summary to visualize its layers and parameters.
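
To make the mechanics concrete, here is a minimal sketch of the `state_dict` round trip that the next cell relies on. `TinyAutoencoder` is a hypothetical stand-in, not the project's `EncoderDecoder`, and the weights go to an in-memory buffer rather than `models/encoder_decoder.pth`. It also counts trainable parameters, which `print(model)` alone does not report.

```python
import io

import torch
from torch import nn


class TinyAutoencoder(nn.Module):
    """Hypothetical stand-in for the real model in src/model.py."""

    def __init__(self, n_features: int = 8, latent_dim: int = 2):
        super().__init__()
        self.encoder = nn.Linear(n_features, latent_dim)
        self.decoder = nn.Linear(latent_dim, n_features)

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Save only the weights (the state dict), not the whole pickled module.
trained = TinyAutoencoder()
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Loading needs a fresh instance built with the same constructor arguments.
restored = TinyAutoencoder()
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
restored.eval()  # switch off training-only behavior such as dropout

# Count trainable parameters across all layers.
n_params = sum(p.numel() for p in restored.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params}")
```

If the constructor arguments differ from those used at training time, `load_state_dict` raises a `RuntimeError` that lists the mismatched keys or shapes, which is the failure the next cell catches.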

In [None]:
if EncoderDecoder is not None and MODEL_FILE.exists():
    print("Loading model architecture from 'src/model.py'...")
    # Instantiate the model
    # Note: You may need to pass configuration arguments (e.g., n_features, latent_dim)
    try:
        model = EncoderDecoder().to(DEVICE)
    except Exception as e:
        print(f"Error instantiating model: {e}")
        print("You may need to update this cell to pass arguments to your model, e.g., EncoderDecoder(n_features=80, ...)")
        model = None

    if model is not None:
        print(f"Loading trained weights from {MODEL_FILE}...")
        try:
            # Load the state dictionary
            model.load_state_dict(torch.load(MODEL_FILE, map_location=DEVICE))
            model.eval() # Set model to evaluation mode
            print("Model loaded successfully!")
            
            # --- Print Model Summary ---
            print("\n" + "="*30)
            print(" MODEL ARCHITECTURE SUMMARY")
            print("="*30)
            print(model)
            
        except Exception as e:
            print(f"Error loading model weights: {e}")
            print("This often happens if the architecture in 'src/model.py' does not match the saved weights.")

else:
    if not MODEL_FILE.exists():
        print(f"Model file not found: {MODEL_FILE}")
    if EncoderDecoder is None:
        print("Model class not imported. Skipping model load.")

## 3. Initial Findings

* **Metrics:** The loss curves show the model converged (or diverged). The `[Metric Name]` (e.g., `val_pesq`) reached a peak of `[Value]`, indicating the model's performance on the evaluation set.
* **Model:** The `encoder_decoder.pth` file was successfully loaded into the `EncoderDecoder` architecture. The model summary confirms the layers (e.g., Conv layers, RNNs, latent dimension) are as designed in `src/model.py`.
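
One way to fill in the placeholders above is to read the best value straight from the metrics table. The sketch below assumes a wide table with an `epoch` column; the helper `best_epoch`, the column names `val_loss` and `val_pesq`, and the sample values are all hypothetical.

```python
import pandas as pd


def best_epoch(df: pd.DataFrame, metric: str, higher_is_better: bool = True):
    """Return (epoch, value) for the row where `metric` is best."""
    idx = df[metric].idxmax() if higher_is_better else df[metric].idxmin()
    row = df.loc[idx]
    return int(row["epoch"]), float(row[metric])


# Invented values for illustration; in the notebook, pass metrics_df instead.
demo_df = pd.DataFrame(
    {
        "epoch": [1, 2, 3, 4],
        "val_loss": [0.9, 0.6, 0.5, 0.55],
        "val_pesq": [1.8, 2.1, 2.4, 2.3],
    }
)

pesq_epoch, pesq_value = best_epoch(demo_df, "val_pesq", higher_is_better=True)
loss_epoch, loss_value = best_epoch(demo_df, "val_loss", higher_is_better=False)
print(f"val_pesq peaked at {pesq_value} in epoch {pesq_epoch}")
print(f"val_loss was lowest at {loss_value} in epoch {loss_epoch}")
```

Pass `higher_is_better=False` for loss-like metrics, where the minimum rather than the peak is the value worth reporting.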

The training appears to be complete and the resulting artifact is loadable and ready for evaluation.