Train a simple gradient-based quadratic classification model and extract its learned weight matrix $W$.

In [1]:
# Uncomment to install the third-party packages this notebook imports into the
# kernel's environment (`%pip` targets the running kernel, unlike `!pip`).
# %pip install -q torch scikit-learn matplotlib pandas seaborn

In [10]:
import pickle
import re
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

# Add the parent directory to the import path so the `src.*` imports below resolve.
sys.path.append("..")

# Project-local model definitions, training loop and configuration.
from src.models.quadratic import QuadraticModel, QuadMLP
from src.models.trainer import ModelTrainer
from src.config.config import cfg

# Compact array printing: two decimals, no scientific notation, 120-character lines.
np.set_printoptions(precision=2, suppress=True, linewidth=120)

In [4]:
# Run configuration. Every later cell reads its settings from `hyp`; the whole
# dict is also handed to ModelTrainer.
hyp = dict(
    # --- data -----------------------------------------------------------
    data_path="../data/synth/encoded_founders_composites.csv",
    target_column="success",
    # --- train / validation / test split ----------------------------------
    test_size=0.2,
    val_size=0.25,  # fraction of what is left after the test split
    random_state=4,
    # --- optimisation (presumably consumed by ModelTrainer; confirm there) --
    batch_size=64,
    lr=1e-3,
    weight_decay=0.01,
    epochs=200,
    patience=10,
    lr_decay_factor=0.5,
    # --- hardware ---------------------------------------------------------
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [7]:
# Load the dataset. Every column except the last three is used as a feature;
# the label is the column named by hyp["target_column"].
df = pd.read_csv(hyp["data_path"])

feature_columns = df.columns[:-3]

X = df[feature_columns].to_numpy()  # raw (unscaled) feature matrix
y = df[hyp["target_column"]].values

# Split BEFORE scaling. The split only depends on the number of rows and the
# seed, so the same rows land in each subset as when splitting scaled data.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=hyp["test_size"], random_state=hyp["random_state"], shuffle=True
)
# hyp["val_size"] is a fraction of the remaining (non-test) rows.
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=hyp["val_size"], random_state=hyp["random_state"], shuffle=True
)

# Fit the scaler on the training rows only so that validation/test statistics
# never leak into training, then apply that single transform everywhere.
# (Previously the features were standardised twice, both times on the full data.)
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)
X_scaled = scaler.transform(X)  # full dataset under the train-fitted scaling

# float32 tensors for the model.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

# Class-imbalance weight: negatives per positive in the training split.
n_negative = int((y_train == 0).sum())
n_positive = int((y_train == 1).sum())
if n_positive == 0:
    raise ValueError("Training split contains no positive samples; cannot compute pos_weight.")
# Explicit float32 so the weight matches the dtype of the tensors above.
pos_weight = torch.tensor([n_negative / n_positive], dtype=torch.float32)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
# Reshuffle the mini-batches every epoch; the seeded generator keeps runs reproducible.
train_loader = DataLoader(
    train_dataset,
    batch_size=hyp["batch_size"],
    shuffle=True,
    generator=torch.Generator().manual_seed(hyp["random_state"]),
)

# NOTE(review): these lists are not read by any visible cell (the plots use the
# histories stored on the trainer); kept in case something else relies on them.
train_loss_history, val_loss_history, test_loss_history = [], [], []
train_acc_history, val_acc_history, test_acc_history = [], [], []

In [None]:
# --- Model ---------------------------------------------------------------
# The input width is the number of feature columns in the training matrix.
input_dim = X_train.shape[1]
model = QuadMLP(input_dim, hidden_dim=32, rand_init=False)
# Alternative model:
# model = QuadraticModel(input_dim=input_dim)

# Snapshot of the weight matrix before training, for the before/after plots.
# NOTE(review): this reads `model.W` directly while the trained matrix below is
# read through `model.get_W()` -- confirm both expose the same quantity.
W_init = model.W.detach().clone()

# --- Training ------------------------------------------------------------
trainer = ModelTrainer(model, hyp, pos_weight)
trainer.train(train_loader, X_val_tensor, y_val_tensor, X_test_tensor, y_test_tensor)

# --- Evaluation on the held-out test set -----------------------------------
model.eval()
with torch.no_grad():
    test_outputs = model(X_test_tensor.to(hyp["device"]))
# A positive raw output is mapped to class 1.
test_preds = (test_outputs > 0).int().cpu().numpy()

# Trained parameters as plain Python / NumPy values for the analysis cells.
W_final = model.get_W().detach().cpu().numpy()
b_final = model.b.item()

In [None]:
# Training curves: one figure per metric, one line per data split, drawn from
# the histories recorded on the trainer.
for metric, y_label, histories in (
    (
        "Loss",
        "Cross-Entropy Loss",
        (trainer.train_loss_history, trainer.val_loss_history, trainer.test_loss_history),
    ),
    (
        "Accuracy",
        "Accuracy",
        (trainer.train_acc_history, trainer.val_acc_history, trainer.test_acc_history),
    ),
):
    fig, ax = plt.subplots(figsize=(8, 5))
    for split, history in zip(("Train", "Validation", "Test"), histories):
        ax.plot(history, label=f"{split} {metric}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(f"{metric} Over Epochs")
    ax.legend()
    plt.show()

# Classification report on the test set, shown as an annotated heatmap.
# The "support" column is dropped so only the score columns are drawn.
report = classification_report(y_test, test_preds, output_dict=True)
report_df = pd.DataFrame(report).round(3).T
report_df = report_df.drop(columns="support")

fig, ax = plt.subplots(figsize=(10, 4))
sns.heatmap(report_df, annot=True, cmap="YlOrRd", fmt=".3f", cbar=False, ax=ax)
ax.set_title("Classification Report Metrics")
fig.tight_layout()
plt.show()

In [None]:
# Side-by-side view of the weight matrix before training, after training, and
# the element-wise change between the two.

# Shared colour limits so the "initial" and "final" panels are directly comparable.
vmin, vmax = -16, 16

W_init_np = W_init.cpu().numpy()

fig, axs = plt.subplots(1, 3, figsize=(18, 5))
im0 = axs[0].imshow(W_init_np, aspect="auto", cmap="viridis", vmin=vmin, vmax=vmax)
axs[0].set_title("Initial Weight Matrix")
plt.colorbar(im0, ax=axs[0])

im1 = axs[1].imshow(W_final, aspect="auto", cmap="viridis", vmin=vmin, vmax=vmax)
axs[1].set_title("Final Weight Matrix")
plt.colorbar(im1, ax=axs[1])

W_diff = W_final - W_init_np
# RdBu is a diverging colormap: use limits symmetric about zero so the neutral
# colour means "no change" and the two hues mean decrease / increase.
# Falls back to 1.0 when nothing changed, to avoid a degenerate colour range.
diff_limit = float(np.abs(W_diff).max()) or 1.0
im2 = axs[2].imshow(W_diff, aspect="auto", cmap="RdBu", vmin=-diff_limit, vmax=diff_limit)
axs[2].set_title("Change in Weight Matrix")
plt.colorbar(im2, ax=axs[2])

fig.tight_layout()
plt.show()

In [11]:
# Persist the trained weight matrix under the next unused "W_<n>.pkl" name.
# (`pickle`, `re` and `Path` are imported in the notebook's import cell.)


def _bump_index(path):
    """Return `path` with the number in its "W_<n>" component increased by one."""
    return re.sub(r"W_(\d+)", lambda x: f"W_{int(x.group(1)) + 1}", path)


filename = "../models/26x26/W_1.pkl"
new_filename = _bump_index(filename)
# Keep incrementing until the name is free, so re-running the notebook never
# silently overwrites a matrix saved by an earlier run.
while Path(new_filename).exists():
    new_filename = _bump_index(new_filename)

# Create the target directory on first use instead of failing in open().
Path(new_filename).parent.mkdir(parents=True, exist_ok=True)

with open(new_filename, "wb") as f:
    pickle.dump(W_final, f)

print(f"Saved W to {new_filename}")

In [None]:
# Plot Feature Importance
feature_importance = np.abs(np.diag(W_final))
feature_pairs = np.abs(W_final - np.diag(np.diag(W_final)))
feature_names = df[feature_columns].columns

feature_importance_abs = np.abs(np.diag(W_final))
feature_importance_real = np.diag(W_final)
feature_importance_normalized = feature_importance_real / np.max(np.abs(feature_importance_real))

importance_df = pd.DataFrame(
    {
        "Feature": feature_names,
        "Importance": feature_importance_normalized,
        "Importance_abs": np.abs(feature_importance_normalized),
    }
).sort_values("Importance_abs", ascending=True)

plt.figure(figsize=(12, 6))
sns.barplot(data=importance_df, y="Feature", x="Importance")
plt.title("Normalized Feature Importance (Diagonal Values)")
plt.axvline(x=0, color="black", linestyle="-", alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Strongest pairwise interactions between features that belong to DIFFERENT
# categories, read from the strict upper triangle of the trained weight matrix.

# Map each feature name to its category: cfg.MATRIX is iterated in order and
# each category claims the next DIMENSION consecutive feature columns.
MATRIX = cfg.MATRIX
category_map = {}
start_idx = 0
for cat in MATRIX:
    dim = MATRIX[cat]["DIMENSION"]
    end_idx = start_idx + dim
    for i in range(start_idx, end_idx):
        category_map[feature_names[i]] = cat
    start_idx = end_idx

# Fail with a clear message instead of a KeyError further down.
assert start_idx == len(feature_names), (
    f"cfg.MATRIX covers {start_idx} features but the data has {len(feature_names)}"
)

# NOTE(review): only W[i, j] with i < j is used; if W is not symmetric the
# mirrored entry W[j, i] is ignored -- confirm what model.get_W() returns.
upper_triangle = np.triu(W_final, k=1)
n_feat = len(upper_triangle)

pairs = [
    (feature_names[i], feature_names[j], upper_triangle[i, j])
    for i in range(n_feat)
    for j in range(i + 1, n_feat)
    if category_map[feature_names[i]] != category_map[feature_names[j]]  # cross-category only
]
if not pairs:
    raise ValueError("No cross-category feature pairs found; check cfg.MATRIX against the feature columns.")

interactions_df = pd.DataFrame(pairs, columns=["Feature 1", "Feature 2", "Interaction Strength"])
# Normalise to [-1, 1]; if every interaction is exactly zero leave the values as-is.
max_abs_interaction = np.max(np.abs(interactions_df["Interaction Strength"]))
if max_abs_interaction == 0:
    max_abs_interaction = 1.0
interactions_df["Interaction Strength"] = interactions_df["Interaction Strength"] / max_abs_interaction
interactions_df["Abs_Strength"] = np.abs(interactions_df["Interaction Strength"])
interactions_df = interactions_df.sort_values("Abs_Strength", ascending=False)

top_n = 15
# .assign() returns a new frame, so the label columns are not written into a slice.
top_interactions = interactions_df.head(top_n).assign(
    Pair=lambda d: d["Feature 1"].astype(str) + " × " + d["Feature 2"].astype(str),
    Sign=lambda d: np.where(d["Interaction Strength"] < 0, "Negative", "Positive"),
)

fig, ax = plt.subplots(figsize=(12, 6))
# Colour by sign through `hue`: passing `palette` without `hue` is deprecated in seaborn.
sns.barplot(
    data=top_interactions,
    y="Pair",
    x="Interaction Strength",
    hue="Sign",
    palette={"Negative": "red", "Positive": "blue"},
    dodge=False,
    ax=ax,
)

ax.set_title(f"Top {top_n} Normalized Cross-Category Feature Interactions")
ax.set_xlabel("Normalized Interaction Strength")
ax.set_ylabel("Feature Pair")
ax.axvline(x=0, color="black", linestyle="-", alpha=0.3)  # zero reference line
fig.tight_layout()
plt.show()

In [None]:
# Annotated heatmap of the trained weight matrix. Cells strictly below the
# diagonal are masked, so the diagonal and the upper triangle are drawn.
fig, ax = plt.subplots(figsize=(12, 10))

# 1 marks a hidden cell, 0 a visible one (same dtype as W_final).
mask = np.tril(np.ones_like(W_final), k=-1)

heatmap_kwargs = dict(
    xticklabels=feature_names,
    yticklabels=feature_names,
    cmap="RdBu",
    center=0,  # neutral colour at zero weight
    mask=mask,
    annot=True,
    fmt=".2f",
    square=True,
)
sns.heatmap(W_final, ax=ax, **heatmap_kwargs)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
plt.setp(ax.get_yticklabels(), rotation=0)
ax.set_title("Feature Interaction Matrix (Upper Triangle)")
fig.tight_layout()
plt.show()

In [None]:
# Get model predictions and confidence
with torch.no_grad():
    X_test_tensor = X_test_tensor.to(hyp["device"])
    logits = model(X_test_tensor)
    probs = torch.sigmoid(logits).cpu().numpy()
    preds = (probs > 0.5).astype(int)

analysis_df = pd.DataFrame(X_test, columns=feature_names)
analysis_df["true_label"] = y_test
analysis_df["predicted"] = preds
analysis_df["confidence"] = np.abs(probs - 0.5) + 0.5
analysis_df["correct"] = analysis_df["true_label"] == analysis_df["predicted"]

n_features = len(feature_names)
n_rows = (n_features + 3) // 4

plt.figure(figsize=(20, 5 * n_rows))

for i, feature in enumerate(feature_names):
    plt.subplot(n_rows, 4, i + 1)
    sns.kdeplot(data=analysis_df[analysis_df["correct"]], x=feature, label="Correct", alpha=0.6)
    sns.kdeplot(data=analysis_df[~analysis_df["correct"]], x=feature, label="Incorrect", alpha=0.6)
    plt.title(f"{feature}")
    plt.xlabel("Standardized Value")
    if i == 0:
        plt.legend()

plt.tight_layout()
plt.show()

# Calculate and show means for all features
confidence_threshold = np.percentile(analysis_df["confidence"], 75)
high_conf_correct = analysis_df[(analysis_df["correct"]) & (analysis_df["confidence"] >= confidence_threshold)]
high_conf_incorrect = analysis_df[(~analysis_df["correct"]) & (analysis_df["confidence"] >= confidence_threshold)]

comparison_df = pd.DataFrame(
    {
        "High Conf Correct": high_conf_correct[feature_names].mean(),
        "High Conf Incorrect": high_conf_incorrect[feature_names].mean(),
        "All Data": analysis_df[feature_names].mean(),
    }
).round(3)

print("\nFeature Means for High Confidence Predictions:")
print(comparison_df)