In [12]:
%%capture run_output
%matplotlib inline

import os
import sys
import datetime
import logging
# import matplotlib.pyplot as plt
import time
import numpy as np

# 1. Import the global logger from logging_setup
from log.config import logger  # pre-defined global logger

# 2. Imports from your code
from core.data_preprocessing import load_mnist, preprocess_data
from core.multi_class_perceptron_unified import MultiClassPerceptron
from analysis.evaluation_functions import evaluate_model, analyze_confusion_matrix, plot_confusion_matrix, plot_confusion_matrix_annotated, plot_history, plot_class_metrics
# 3. Create a timestamped folder for this run's artifacts.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = f"../outputs/{timestamp}"  # relative to the notebook's working directory
os.makedirs(run_dir, exist_ok=True)

# Define paths for the log file and the figures saved later in this cell.
log_filename = os.path.join(run_dir, "log.txt")
conf_mat_path = os.path.join(run_dir, "conf_mat.png")
conf_mat_annot_path = os.path.join(run_dir, "conf_mat_annot.png")
history_path = os.path.join(run_dir, "train_history.png")
class_metrics_path = os.path.join(run_dir, "class_metrics.png")

# 4. Configure the global logger we imported.
# Detach AND close any handlers already attached (e.g. from a previous
# execution of this cell). Assigning `logger.handlers = []` dropped them
# without closing, so every re-run leaked the previous run's open log file.
for old_handler in list(logger.handlers):
    logger.removeHandler(old_handler)
    old_handler.close()
logger.setLevel(logging.INFO)

# File handler: mode "w" starts this run's log.txt from empty.
fh = logging.FileHandler(log_filename, mode="w")
fh.setLevel(logging.INFO)
# Console handler: bound to whatever sys.stdout is at this moment, i.e. the
# %%capture buffer when the cell runs under that magic.
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
ch.setFormatter(formatter)

logger.addHandler(fh)
logger.addHandler(ch)

# 5. Begin the run: note the start time, then announce where artifacts go.
start_time = time.time()
for startup_message in (
    "=== Starting run in Jupyter with global logger ===",
    f"Run directory: {run_dir}",
    f"Log file: {log_filename}",
):
    logger.info(startup_message)

# 6. Load MNIST, then run the project's preprocessing over the features only.
logger.info("Loading and Preprocessing MNIST...")
X, y = load_mnist()
X = preprocess_data(X)

# The first n_train samples become the training set; everything after is test.
logger.info("Splitting data")
n_train = 60000
X_train, y_train = X[:n_train], y[:n_train]
X_test, y_test = X[n_train:], y[n_train:]

# 7. Fit a MultiClassPerceptron built with use_pocket=True, then score it on
#    the held-out samples. evaluate_model's result is unpacked as (cm, acc);
#    cm is reused for the plots below and acc is logged as a percentage.
logger.info("Training MultiClassPerceptron (pocket)")
mcp_pocket = MultiClassPerceptron(use_pocket=True)
mcp_pocket.fit(X_train, y_train)

logger.info("Evaluating on test set (pocket)")
cm, acc = evaluate_model(mcp_pocket, X_test, y_test)
analyze_confusion_matrix(cm)

# 8a. Plot & save the basic confusion matrix
plot_confusion_matrix(cm, title="Confusion Matrix", save_path=conf_mat_path)
logger.info(f"Figure saved to {conf_mat_path}")

# 8b. Plot & save the annotated confusion matrix.
# NOTE(review): the last recorded execution of this cell failed at import time
# with "ImportError: cannot import name 'plot_confusion_matrix_annotated' from
# 'analysis.evaluation_functions'". Either add that function to
# analysis/evaluation_functions.py or drop this step together with its import.
plot_confusion_matrix_annotated(cm, title="Confusion Matrix (Annotated)", save_path=conf_mat_annot_path)
logger.info(f"Annotated confusion matrix saved to {conf_mat_annot_path}")

# 8c. Training-history plot.
# NOTE: these error curves are hard-coded placeholder numbers, NOT values
# recorded while fitting `mcp_pocket`, yet the figure is saved into this run's
# output folder. Replace them with real per-epoch errors before treating
# train_history.png as a result of the run.
train_err = [0.4, 0.25, 0.15, 0.10]
val_err = [0.45, 0.30, 0.20, 0.15]
plot_history(train_err, val_err, title="Train vs Val Error", ylabel="Error", save_path=history_path)
logger.info(f"Training history plot saved to {history_path}")

# 8d. Per-class true-positive rate, computed from the confusion matrix.
# Rows of `cm` are treated as the true class: TPR_k = cm[k, k] / sum_j cm[k, j]
# (TP / (TP + FN)). Classes whose row sums to zero get 0.0 instead of a
# division by zero. .tolist() hands plot_class_metrics a plain list, as before.
cm_arr = np.asarray(cm)
row_totals = cm_arr.sum(axis=1)
tpr_values = np.divide(
    np.diag(cm_arr),
    row_totals,
    out=np.zeros(row_totals.shape, dtype=float),
    where=row_totals > 0,
).tolist()

plot_class_metrics(tpr_values, metric_name="TPR", classes=range(cm.shape[0]), save_path=class_metrics_path)
logger.info(f"Class metrics plot (TPR) saved to {class_metrics_path}")

logger.info(f"Accuracy: {acc*100:.2f}%")

# 9. End the run
duration = time.time() - start_time
logger.info(f"Run complete. Total runtime: {duration:.2f} sec.")

# 10. Append captured cell output to log.
# `%%capture run_output` binds `run_output` only AFTER the cell has finished,
# so reading it here raised NameError on a first run and, on re-runs, copied
# the PREVIOUS run's output into this run's log. While the cell executes,
# %%capture has swapped sys.stdout for an in-memory StringIO, so read the text
# captured so far straight from it (empty string when not under %%capture).
# NOTE: the console handler writes into that same buffer, so the log records
# above appear a second time in the appended section.
captured_stdout = sys.stdout.getvalue() if hasattr(sys.stdout, "getvalue") else ""

# Detach and close the file handler before appending. It holds log.txt open in
# "w" mode at its own write offset, so any record it emitted afterwards (e.g.
# from a later cell) would overwrite the section appended below.
logger.removeHandler(fh)
fh.close()

with open(log_filename, "a", encoding="utf-8") as f:
    f.write("\n--- Captured Jupyter Cell Output ---\n")
    f.write(captured_stdout)

print(f"All cell output appended to {log_filename}")


ImportError: cannot import name 'plot_confusion_matrix_annotated' from 'analysis.evaluation_functions' (e:\dev\ml_intro_course\mmn11\perceptron\code\analysis\evaluation_functions.py)