# Notebook 03 — Machine Learning Models

This notebook trains predictive models on the engineered
attention microstructure dataset.

Models included:

## 1. Regime Classification (Multiclass)
Predict which behavioral regime the attention system is currently in.

## 2. 1-step Up/Down Attention Movement (Binary Classification)
Predict whether attention will increase or decrease in the next step.

## 3. 5-step Forward Attention Change (Regression)
Predict the magnitude of attention change over the next 5 steps.

We evaluate:
- Accuracy / F1 for the classifiers
- Confusion matrices
- Feature importance
- R² and MAE for regression
- Predicted vs. actual values

This notebook simulates a standard quant research workflow.


In [None]:
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split

# Use matplotlib's built-in default style for every figure in this notebook.
plt.style.use("default")

# Engineered ML dataset; the path is resolved relative to the current working directory.
data_path = Path("..", "data", "ml_dataset.csv")
df = pd.read_csv(data_path)

# Preview the first five rows (the trailing expression is rendered as the cell output).
df.head(5)


In [None]:
# Names of the target columns used (or reserved) by the models below.
label_regime, label_up1 = "regime_id", "label_up_1"
label_up5, label_fwd5 = "label_up_5", "fwd_return_5"

# Every column that is not one of the four targets is treated as a model feature.
# The original column order of `df` is preserved.
labels = [label_regime, label_up1, label_up5, label_fwd5]
feature_cols = df.columns[~df.columns.isin(labels)].tolist()

# Feature matrix shared by all three models.
X = df.loc[:, feature_cols]


In [None]:
# Target for model 1: the regime id of each row.
y_regime = df[label_regime]

# Shuffled 80/20 hold-out split with a fixed seed for reproducibility.
# NOTE(review): shuffling mixes rows from every position of the dataset; if the
# rows are time-ordered, confirm that a random split is intended here.
regime_split = train_test_split(
    X, y_regime, test_size=0.2, shuffle=True, random_state=42
)
X_train, X_test, y_train, y_test = regime_split


In [None]:
# Random-forest regime classifier: 200 trees, depth capped at 12, fixed seed.
# `fit` returns the estimator itself, so construction and training are chained.
clf_regime = RandomForestClassifier(
    n_estimators=200, max_depth=12, random_state=42
).fit(X_train, y_train)

# Hard class predictions on the held-out rows.
y_pred = clf_regime.predict(X_test)


In [None]:
# Overall accuracy followed by the per-class precision / recall / F1 table
# for the regime classifier on the held-out set.
print("Accuracy:", accuracy_score(y_test, y_pred))
print(
    "\nClassification Report:",
    classification_report(y_test, y_pred),
    sep="\n",
)


In [None]:
# Confusion matrix of the regime classifier on the held-out set.
#
# The rows/columns are pinned to the classes the model was trained on, and the
# same values are used as tick labels. Without this the heatmap axes show
# positional indices (0..k-1) rather than regime ids, and a class that happens
# to be absent from the test split would silently shift every later row/column.
regime_classes = clf_regime.classes_
cm = confusion_matrix(y_test, y_pred, labels=regime_classes)

fig, ax = plt.subplots(figsize=(6, 4))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=regime_classes,
    yticklabels=regime_classes,
    ax=ax,
)
ax.set_title("Regime Classifier Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
plt.show()


In [None]:
# Importance of each feature as reported by the fitted regime classifier,
# ordered from most to least important.
importances = pd.Series(clf_regime.feature_importances_, index=feature_cols)
importances = importances.sort_values(ascending=False)

# Horizontal bar chart of the 20 highest-ranked features; the y-axis is
# inverted so the most important feature appears at the top.
fig, ax = plt.subplots(figsize=(8, 10))
importances.head(20).plot(kind="barh", ax=ax)
ax.set_title("Top 20 Feature Importances — Regime Classifier")
ax.invert_yaxis()
plt.show()


In [None]:
# Target for model 2: direction of the next-step attention move.
y_up1 = df[label_up1]

# Chronological 80/20 split: the first 80% of rows train the model and the
# final 20% evaluate it. The target looks one step into the future, so a
# shuffled split would place test rows in between training rows and let the
# model score on periods it has effectively already seen (look-ahead leakage).
# Assumes the rows of `df` are in chronological order — TODO confirm.
# `random_state` is omitted because it has no effect when shuffle=False.
X_train2, X_test2, y_train2, y_test2 = train_test_split(
    X, y_up1, test_size=0.2, shuffle=False
)


In [None]:
# Random-forest up/down classifier: 200 trees, depth capped at 12, fixed seed.
# `fit` returns the estimator itself, so construction and training are chained.
clf_up1 = RandomForestClassifier(
    n_estimators=200, max_depth=12, random_state=42
).fit(X_train2, y_train2)

# Hard class predictions on the held-out rows.
y_pred2 = clf_up1.predict(X_test2)


In [None]:
# Accuracy and F1 (scikit-learn's default binary averaging) for the up/down
# classifier, followed by the full per-class report.
for metric_name, metric_fn in (("Accuracy:", accuracy_score), ("F1 Score:", f1_score)):
    print(metric_name, metric_fn(y_test2, y_pred2))
print(
    "\nClassification Report:",
    classification_report(y_test2, y_pred2),
    sep="\n",
)


In [None]:
# Confusion matrix of the 1-step up/down classifier on the held-out set.
#
# The rows/columns are pinned to the classes the model was trained on, and the
# same values are used as tick labels, so the axes show the actual label
# values instead of positional indices and stay aligned even if one class is
# missing from the test split.
updown_classes = clf_up1.classes_
cm2 = confusion_matrix(y_test2, y_pred2, labels=updown_classes)

fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    cm2,
    annot=True,
    fmt="d",
    cmap="Greens",
    xticklabels=updown_classes,
    yticklabels=updown_classes,
    ax=ax,
)
ax.set_title("1-step Up/Down — Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
plt.show()


In [None]:
# Target for model 3: attention change over the next 5 steps.
y_fwd5 = df[label_fwd5]

# Chronological 80/20 split: the first 80% of rows train the model and the
# final 20% evaluate it. A 5-step forward target overlaps heavily between
# neighbouring rows, so a shuffled split would put near-duplicate targets in
# both train and test and inflate R² (look-ahead leakage).
# Assumes the rows of `df` are in chronological order — TODO confirm.
# NOTE(review): the last few training rows still have targets that reach into
# the test window; consider dropping them (purging) if that matters.
# `random_state` is omitted because it has no effect when shuffle=False.
X_train3, X_test3, y_train3, y_test3 = train_test_split(
    X, y_fwd5, test_size=0.2, shuffle=False
)


In [None]:
# Random-forest regressor for the 5-step forward change: 300 trees, depth
# capped at 14, fixed seed. `fit` returns the estimator itself, so
# construction and training are chained.
reg_fwd5 = RandomForestRegressor(
    n_estimators=300, max_depth=14, random_state=42
).fit(X_train3, y_train3)

# Point forecasts on the held-out rows.
y_pred3 = reg_fwd5.predict(X_test3)


In [None]:
# Goodness of fit (R²) and mean absolute error of the 5-step regressor
# on the held-out set.
for metric_name, metric_fn in (("R² Score:", r2_score), ("MAE:", mean_absolute_error)):
    print(metric_name, metric_fn(y_test3, y_pred3))


In [None]:
# Scatter of held-out targets (x) against the regressor's forecasts (y).
fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(y_test3, y_pred3, alpha=0.5)
ax.set_title("5-step Forward Attention Forecast — Predicted vs Actual")
ax.set_xlabel("Actual")
ax.set_ylabel("Predicted")
fig.tight_layout()
plt.show()


In [None]:
# Persist the three fitted estimators so other notebooks/scripts can reuse them.
# (`joblib` is imported with the other dependencies in the first code cell.)
models_dir = Path("..") / "models"

# joblib.dump does not create missing parent directories; make sure the target
# folder exists so a fresh checkout does not fail at the very last cell.
models_dir.mkdir(parents=True, exist_ok=True)

joblib.dump(clf_regime, models_dir / "regime_classifier.pkl")
joblib.dump(clf_up1, models_dir / "updown_classifier.pkl")
joblib.dump(reg_fwd5, models_dir / "fwd5_regressor.pkl")

# Report the real destination: the old message said "/models" (filesystem
# root), while the files are written to ../models relative to the working dir.
print(f"Models saved to {models_dir.resolve()}")
