In [2]:
%%capture run_output
%matplotlib inline

import os
import sys
import datetime
import logging
import matplotlib.pyplot as plt
import time
import numpy as np
import shutil  # to copy the compare file into each subfolder

from core.logger.config import logger
from core.perceptron.data_preprocessing import load_mnist, preprocess_data
from core.perceptron.multi_class_perceptron_unified import MultiClassPerceptron
from core.analysis.evaluation_functions import (
    evaluate_model,
    analyze_confusion_matrix,
    plot_confusion_matrix,
    plot_confusion_matrix_annotated,
    plot_class_metrics,
    plot_history
)

def aggregate_iteration_losses(mcp):
    """
    Average the iteration-level train/test losses of all classes into one
    overall 'train_curve' and one overall 'test_curve'.

    Reads ``mcp.num_classes`` and, for every ``i`` in
    ``range(mcp.num_classes)``, ``mcp.loss_history[i]["train"]`` and
    ``mcp.loss_history[i]["test"]``.

    Histories of different lengths are aligned before averaging:
      * a history shorter than the longest one is extended by repeating its
        last recorded value (e.g. a class that stopped iterating early);
      * an empty history is replaced by zeros.

    Train histories are aligned to the longest train history. Test histories
    are aligned to the longest train *or* test history, so ragged test
    histories can no longer break the averaging step.

    Returns:
        (train_curve, test_curve): two lists of floats with the element-wise
        mean across classes. Both are empty lists when there are no classes.
    """
    num_classes = mcp.num_classes
    if num_classes <= 0:
        # Nothing to average; avoid np.mean over an empty array (nan + warning).
        return [], []

    # list(...) takes a real copy, so padding never mutates the model's own
    # histories and array-like histories are handled the same way as lists.
    train_runs = [list(mcp.loss_history[i]["train"]) for i in range(num_classes)]
    test_runs = [list(mcp.loss_history[i]["test"]) for i in range(num_classes)]

    train_len = max(len(run) for run in train_runs)
    test_len = max(train_len, max(len(run) for run in test_runs))

    def _pad(run, target_len):
        # An empty history has no "last value" to repeat, so fall back to zeros.
        if not run:
            return [0] * target_len
        # A non-positive repeat count yields [], leaving full-length runs as-is.
        return run + [run[-1]] * (target_len - len(run))

    all_train = np.array([_pad(run, train_len) for run in train_runs])  # (num_classes, train_len)
    all_test = np.array([_pad(run, test_len) for run in test_runs])     # (num_classes, test_len)

    train_curve = np.mean(all_train, axis=0).tolist()
    test_curve = np.mean(all_test, axis=0).tolist()
    return train_curve, test_curve

def run_perceptron(mode_name, use_pocket, X_train, y_train, X_test, y_test, parent_dir):
    """
    Runs the perceptron in pocket/clean mode, logs & plots to a subfolder <mode_name>_<timestamp>.

    Args:
        mode_name: Label for the run; used in the sub-folder name, plot titles
            and log messages.
        use_pocket: Forwarded as ``MultiClassPerceptron(use_pocket=...)``.
        X_train, y_train: Training data passed to ``fit``.
        X_test, y_test: Test data passed to ``fit`` and to ``evaluate_model``.
        parent_dir: Directory under which the run sub-folder is created.

    Returns:
        Path of the run directory. The function writes ``log.txt`` there and
        hands the plot helpers the paths ``conf_mat.png``,
        ``conf_mat_annot.png``, ``class_metrics.png`` and
        ``train_vs_test_curve.png``.

    Logging:
        For the duration of the run the shared ``logger`` writes only to the
        run's ``log.txt`` and to ``sys.stdout``. The handlers that were
        attached before the call are put back afterwards, and the run's own
        handlers are always closed -- even when training, evaluation or
        plotting raises -- so no log file is left open and later log calls
        are not silently dropped. The logger level is left at INFO.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(parent_dir, f"{mode_name}_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)

    # Filenames
    log_filename = os.path.join(run_dir, "log.txt")
    conf_mat_path = os.path.join(run_dir, "conf_mat.png")
    conf_mat_annot_path = os.path.join(run_dir, "conf_mat_annot.png")
    class_metrics_path = os.path.join(run_dir, "class_metrics.png")
    train_vs_test_path = os.path.join(run_dir, "train_vs_test_curve.png")

    # Logger setup: build the run's handlers first, so a failure here (e.g. the
    # log file cannot be opened) leaves the shared logger untouched.
    fmt = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    fh = logging.FileHandler(log_filename, mode="w")
    ch = logging.StreamHandler(sys.stdout)
    for handler in (fh, ch):
        handler.setLevel(logging.INFO)
        handler.setFormatter(fmt)

    # Remember what was attached so it can be restored once the run is over.
    previous_handlers = list(logger.handlers)
    logger.handlers = []
    logger.setLevel(logging.INFO)
    logger.addHandler(fh)
    logger.addHandler(ch)

    try:
        start_time = time.time()
        logger.info(f"=== Starting {mode_name.upper()} run ===")
        logger.info(f"Run directory: {run_dir}")
        logger.info(f"Log file: {log_filename}")

        # Train
        logger.info(f"Training MultiClassPerceptron with pocket={use_pocket}")
        mcp = MultiClassPerceptron(use_pocket=use_pocket)
        mcp.fit(X_train, y_train, X_val=None, y_val=None, X_test=X_test, y_test=y_test)

        # Evaluate
        logger.info(f"Evaluating on test set ({mode_name})")
        cm, acc = evaluate_model(mcp, X_test, y_test)
        analyze_confusion_matrix(cm)

        # Additional stats (only when the model exposes them)
        if hasattr(mcp, "converged_iterations"):
            logger.info("=== Additional training stats per digit ===")
            for cls in range(mcp.num_classes):
                iters = mcp.converged_iterations.get(cls, "N/A")
                ftrain_err = mcp.final_train_error.get(cls, "N/A")
                fval_err = mcp.final_val_error.get(cls, "N/A")
                ftest_err = mcp.final_test_error.get(cls, "N/A")
                logger.info(
                    f"Digit {cls}: iters={iters}, "
                    f"final_train_err={ftrain_err}, val_err={fval_err}, test_err={ftest_err}"
                )

        # Plot confusion matrices
        plot_confusion_matrix(cm, f"Confusion Matrix ({mode_name})", conf_mat_path)
        logger.info(f"Confusion matrix => {conf_mat_path}")

        plot_confusion_matrix_annotated(cm, f"Conf Mat Annot ({mode_name})", conf_mat_annot_path)
        logger.info(f"Annotated conf => {conf_mat_annot_path}")

        # TPR per class: diagonal entry over its row total (TP + FN);
        # a row with no samples gets 0 instead of dividing by zero.
        tpr_vals = []
        for cls in range(cm.shape[0]):
            TP = cm[cls, cls]
            FN = np.sum(cm[cls, :]) - TP
            tpr = TP / (TP + FN) if (TP + FN) > 0 else 0
            tpr_vals.append(tpr)
        plot_class_metrics(
            tpr_vals, metric_name=f"TPR ({mode_name})",
            classes=range(cm.shape[0]),
            save_path=class_metrics_path
        )
        logger.info(f"Class metrics => {class_metrics_path}")
        logger.info(f"Overall Accuracy: {acc*100:.2f}%")

        # Iteration-based train vs test curve
        train_curve, test_curve = aggregate_iteration_losses(mcp)
        plot_history(
            train_values=train_curve,
            val_values=test_curve,
            title=f"Train vs Test Curves ({mode_name})",
            ylabel="#Misclass (avg classes)",
            save_path=train_vs_test_path
        )
        logger.info(f"Iteration-based train vs test => {train_vs_test_path}")

        duration = time.time() - start_time
        logger.info(f"{mode_name.upper()} run done. Total time={duration:.2f} sec.")
    finally:
        # Always detach and close the run's handlers, then restore the
        # handlers the logger had before this run.
        for handler in (fh, ch):
            logger.removeHandler(handler)
            handler.close()
        for handler in previous_handlers:
            logger.addHandler(handler)

    return run_dir

# ----------------- main cell ------------------
import contextlib
import io

parent_dir = "./results/ComparePocketClean"
os.makedirs(parent_dir, exist_ok=True)

# NOTE: IPython's `%%capture run_output` binds `run_output` only after the
# cell has finished, so it cannot be read from inside this same cell (first
# run: NameError; re-run: output of the previous execution). The text for the
# compare file is therefore captured here, in-cell, via redirect_stdout.
captured_stdout = io.StringIO()
try:
    with contextlib.redirect_stdout(captured_stdout):
        start_global = time.time()
        logger.info("=== Running Pocket vs Clean with iteration-based curves (pandas-based fetch_openml) ===")

        # 1. Load & preprocess MNIST
        X, y = load_mnist()     # requires pandas
        X = preprocess_data(X)  # adds bias, normalizes
        # Fixed split at index 60000: first part for training, the rest for testing.
        X_train, X_test = X[:60000], X[60000:]
        y_train, y_test = y[:60000], y[60000:]

        # 2. Pocket mode
        pocket_dir = run_perceptron("pocket", True, X_train, y_train, X_test, y_test, parent_dir)
        # 3. Clean mode
        clean_dir  = run_perceptron("clean",  False, X_train, y_train, X_test, y_test, parent_dir)

        end_global = time.time()
        logger.info(f"Both runs complete. total time={end_global-start_global:.2f} sec.")
finally:
    # Echo everything back to the outer stdout (even on failure) so whatever
    # captures or displays this cell's output still receives it.
    sys.stdout.write(captured_stdout.getvalue())

compare_path = os.path.join(parent_dir, "Compare_output.txt")
with open(compare_path, "w", encoding="utf-8") as f:
    f.write("--- Combined cell output (pandas-based) ---\n\n")
    f.write(captured_stdout.getvalue())

# Copy the compare file into each run folder
shutil.copyfile(compare_path, os.path.join(pocket_dir, "Compare_output.txt"))
shutil.copyfile(compare_path, os.path.join(clean_dir,  "Compare_output.txt"))

print(f"Done.\nPocket => {pocket_dir}\nClean => {clean_dir}\nCompare => {compare_path}")
