In [None]:
%%capture run_output
%matplotlib inline

import os
import sys
import datetime
import logging
import matplotlib.pyplot as plt
import time
import numpy as np

from logger.config import logger
from perceptron.data_preprocessing import load_mnist, preprocess_data
# IMPORTANT: This must be the updated version that records iteration-level losses in `loss_history`.
from perceptron.multi_class_perceptron_unified import MultiClassPerceptron  
from analysis.evaluation_functions import (
    evaluate_model,
    analyze_confusion_matrix,
    plot_confusion_matrix,
    plot_confusion_matrix_annotated,
    plot_class_metrics,
    # We'll import plot_history so we can visualize the real iteration-based train vs. test curves
    plot_history
)

def _pad_history(history, target_len):
    """Return *history* as a plain list of ``target_len`` values.

    A shorter history is extended by repeating its final value (the class
    stopped recording early). An empty history has no final value to repeat,
    so it becomes all zeros.
    """
    # list() also accepts numpy arrays, for which `+` would mean element-wise
    # addition rather than concatenation.
    values = list(history)
    if not values:
        return [0] * target_len
    return values + [values[-1]] * (target_len - len(values))


def aggregate_iteration_losses(mcp, mode_name="pocket"):
    """
    Average per-class, iteration-level train/test losses into two curves.

    Reads ``mcp.num_classes`` and, for each class index ``i``,
    ``mcp.loss_history[i]["train"]`` and ``mcp.loss_history[i]["test"]``
    (sequences with one value per recorded iteration).

    Histories can differ in length between classes, so each one is padded
    with its last value up to the longest history seen (train or test).
    A class with no recorded values contributes zeros.

    Args:
      mcp: object exposing ``num_classes`` and ``loss_history`` as described.
      mode_name (str): unused here; kept so existing callers keep working.

    Returns:
      (train_curve, test_curve): two lists of floats of equal length, each
      the mean across classes at every iteration. Both are empty when there
      are no classes or nothing was recorded.
    """
    histories = [mcp.loss_history[i] for i in range(mcp.num_classes)]
    if not histories:
        return [], []

    # Use the longest history of either kind so both curves share one x-axis
    # and the stacked arrays below are never ragged.
    max_len = max(max(len(h["train"]), len(h["test"])) for h in histories)
    if max_len == 0:
        return [], []

    # Shape of each array: (num_classes, max_len)
    all_train = np.array([_pad_history(h["train"], max_len) for h in histories])
    all_test = np.array([_pad_history(h["test"], max_len) for h in histories])

    # Average across classes
    train_curve = np.mean(all_train, axis=0).tolist()
    test_curve = np.mean(all_test, axis=0).tolist()

    return train_curve, test_curve

def _reset_run_logger(log_filename):
    """Route the shared project logger to *log_filename* and stdout only."""
    # Detach AND close whatever was attached before. Merely clearing the
    # handler list would leave the previous run's log file open.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    run_handlers = (
        logging.FileHandler(log_filename, mode="w"),
        logging.StreamHandler(sys.stdout),
    )
    for handler in run_handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)


def run_perceptron(mode_name, use_pocket, X_train, y_train, X_test, y_test, parent_dir):
    """
    Train and evaluate a MultiClassPerceptron once and save logs and figures.

    Output goes to a fresh subfolder named <mode_name>_<timestamp> under
    'parent_dir': log.txt, two confusion-matrix images, a per-class TPR plot
    and an averaged train vs. test curve.

    Side effect: the shared `logger` is reconfigured to write to this run's
    log.txt and to stdout. Those handlers stay attached after the function
    returns, until the next call replaces them.

    Args:
      mode_name (str): "pocket" or "clean" (label used for folder name & logs).
      use_pocket (bool): forwarded to MultiClassPerceptron(use_pocket=...).
      X_train, y_train: training data/labels passed to `fit`.
      X_test, y_test: test data/labels passed to `fit` and to `evaluate_model`.
      parent_dir (str): top-level directory for outputs (e.g. "../outputs/ComparePocketClean").

    Returns:
      run_dir (str): path to the subfolder containing logs & figures.
    """

    # 1. Create subfolder with timestamp (one-second resolution)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(parent_dir, f"{mode_name}_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)

    # 2. Filenames
    log_filename = os.path.join(run_dir, "log.txt")
    cm_path = os.path.join(run_dir, "conf_mat.png")
    cm_annot_path = os.path.join(run_dir, "conf_mat_annot.png")
    class_metrics_path = os.path.join(run_dir, "class_metrics.png")
    train_vs_test_path = os.path.join(run_dir, "train_vs_test_curve.png")

    # 3. Configure logger for this run
    _reset_run_logger(log_filename)

    # 4. Start
    start_time = time.time()
    logger.info(f"=== Starting {mode_name.upper()} run ===")
    logger.info(f"Run directory: {run_dir}")
    logger.info(f"Log file: {log_filename}")

    # 5. Train perceptron
    # X_test/y_test are handed to fit as well; presumably this is what lets
    # the model record per-iteration test error in `loss_history` (confirm
    # against MultiClassPerceptron.fit).
    logger.info(f"Training MultiClassPerceptron with pocket={use_pocket}")
    mcp = MultiClassPerceptron(use_pocket=use_pocket)
    mcp.fit(X_train, y_train, X_val=None, y_val=None, X_test=X_test, y_test=y_test)

    # 6. Evaluate
    logger.info(f"Evaluating on test set ({mode_name} mode)")
    cm, acc = evaluate_model(mcp, X_test, y_test)
    analyze_confusion_matrix(cm)

    # Optional per-digit stats: only logged when the model exposes them.
    # Each mapping is looked up defensively so that a model providing
    # `converged_iterations` but not every final_* mapping logs "N/A"
    # instead of raising AttributeError.
    if hasattr(mcp, "converged_iterations"):
        final_train = getattr(mcp, "final_train_error", {})
        final_val = getattr(mcp, "final_val_error", {})
        final_test = getattr(mcp, "final_test_error", {})
        logger.info("=== Additional training stats per digit ===")
        for cls in range(mcp.num_classes):
            iters = mcp.converged_iterations.get(cls, "N/A")
            ftrain_err = final_train.get(cls, "N/A")
            fval_err = final_val.get(cls, "N/A")
            ftest_err = final_test.get(cls, "N/A")
            logger.info(
                f"Digit {cls}: iters={iters}, "
                f"final_train_err={ftrain_err}, val_err={fval_err}, test_err={ftest_err}"
            )

    # 7. Plot confusion matrix
    plot_confusion_matrix(
        cm,
        title=f"Confusion Matrix ({mode_name})",
        save_path=cm_path
    )
    logger.info(f"Basic confusion matrix saved to {cm_path}")

    # 7b. Annotated confusion matrix
    plot_confusion_matrix_annotated(
        cm,
        title=f"Confusion Matrix Annotated ({mode_name})",
        save_path=cm_annot_path
    )
    logger.info(f"Annotated confusion matrix saved to {cm_annot_path}")

    # 8. Class metrics (TPR)
    # TPR = diagonal entry / row total (TP + FN is the whole row).
    # A class with an empty row gets 0 rather than dividing by zero.
    num_labels = cm.shape[0]
    row_totals = np.sum(cm, axis=1)
    tpr_values = [
        cm[cls, cls] / row_totals[cls] if row_totals[cls] > 0 else 0
        for cls in range(num_labels)
    ]

    plot_class_metrics(
        values=tpr_values,
        metric_name=f"TPR ({mode_name})",
        classes=range(num_labels),
        save_path=class_metrics_path
    )
    logger.info(f"Class metrics plot (TPR) saved to {class_metrics_path}")
    logger.info(f"Overall Accuracy: {acc*100:.2f}%")

    # 9. REAL iteration-based train vs. test curves across classes
    train_curve, test_curve = aggregate_iteration_losses(mcp, mode_name=mode_name)

    # The averaged test curve is passed through plot_history's `val_values`
    # parameter.
    plot_history(
        train_values=train_curve,
        val_values=test_curve,
        title=f"Real Train vs Test Curves ({mode_name})",
        ylabel="#Misclassifications (Avg across classes)",
        save_path=train_vs_test_path
    )
    logger.info(f"Iteration-based train vs. test curve saved to {train_vs_test_path}")

    # 10. End
    duration = time.time() - start_time
    logger.info(f"{mode_name.upper()} run complete. Total runtime: {duration:.2f} sec.")

    return run_dir


# -------------------------------------------------------------------
# Main single cell logic
# -------------------------------------------------------------------
import contextlib
import io

top_parent_dir = "../outputs/ComparePocketClean"
os.makedirs(top_parent_dir, exist_ok=True)

# `%%capture run_output` binds `run_output` only AFTER this cell finishes,
# so reading it from inside the cell raises NameError on a first run and
# yields the previous execution's text on a re-run. Capture stdout
# explicitly here so the saved text is always this run's output.
cell_stdout = io.StringIO()
try:
    with contextlib.redirect_stdout(cell_stdout):
        start_global = time.time()
        logger.info("=== Setting up environment for Pocket vs Clean comparison with REAL iteration-based curves ===")

        # Load & preprocess MNIST
        logger.info("Loading MNIST once for both runs")
        X, y = load_mnist()
        X = preprocess_data(X)
        # First 60000 samples are used for training, the remainder for testing.
        X_train, X_test = X[:60000], X[60000:]
        y_train, y_test = y[:60000], y[60000:]

        # Run in Pocket mode
        pocket_dir = run_perceptron(
            mode_name="pocket",
            use_pocket=True,
            X_train=X_train, y_train=y_train,
            X_test=X_test, y_test=y_test,
            parent_dir=top_parent_dir
        )

        # Run in Clean mode
        clean_dir = run_perceptron(
            mode_name="clean",
            use_pocket=False,
            X_train=X_train, y_train=y_train,
            X_test=X_test, y_test=y_test,
            parent_dir=top_parent_dir
        )

        end_global = time.time()
        logger.info(f"Both runs complete. Total time: {end_global - start_global:.2f} seconds.")
finally:
    # Forward what was captured to the real stdout, even on failure, so it
    # is still displayed (or still collected by an enclosing %%capture).
    sys.stdout.write(cell_stdout.getvalue())

# Write the captured output to a single file in top_parent_dir.
# Mode "w" replaces any file left by an earlier execution.
compare_path = os.path.join(top_parent_dir, "Compare_output.txt")
with open(compare_path, "w", encoding="utf-8") as f:
    f.write("--- Combined cell output from Pocket vs Clean runs (with iteration-based curves) ---\n\n")
    f.write(cell_stdout.getvalue())

print(f"Comparison done. Pocket logs/plots in: {pocket_dir}")
print(f"Clean logs/plots in: {clean_dir}")
print(f"All cell output written to: {compare_path}")
