In [None]:
from pathlib import Path
import wandb
from wandb.apis.public import Run 
import pandas as pd
import shap
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from tqdm import tqdm

# Client for the W&B public API; used below to look up evaluation runs by name.
api: wandb.Api = wandb.Api()

In [None]:
config_dir = Path("configs/hparam-search")
# Order matters: the i-th character of a config file's stem encodes the i-th feature.
features = [
    "larger_mlp",
    "gnn",
    "invariant_centroids",
    "pooled_by_upc",
    "initial_connection_indicator",
    "balanced_edge_sampling",
    "focal_loss",
]
# One-hot column names for the "initial_connection_indicator" digit, indexed by its code (0-3).
initial_connection_columns = [
    "initial_connection_none",
    "initial_connection_nearest",
    "initial_connection_nearest_below",
    "initial_connection_nearest_below_per_group",
]


def encode_config_name(name: str) -> dict[str, float]:
    """Decode a config file stem into a row of numeric feature columns.

    Each character of ``name`` is a digit giving the setting of the matching entry
    in ``features``. Every feature becomes one float column, except
    ``initial_connection_indicator``, whose code (0-3) is one-hot encoded into the
    ``initial_connection_columns``.

    Raises:
        RuntimeError: if ``name`` does not have exactly one digit per feature, or
            the initial-connection code is outside 0-3.
        ValueError: if a character of ``name`` is not a digit (raised by ``int``).
    """
    codes = [int(char) for char in name]
    if len(codes) != len(features):
        # zip() would silently truncate, producing rows with missing columns.
        raise RuntimeError(
            f"Config name {name!r} has {len(codes)} digits; expected {len(features)}."
        )
    row: dict[str, float] = {}
    for feature, code in zip(features, codes):
        if feature != "initial_connection_indicator":
            row[feature] = float(code)
            continue
        if not 0 <= code < len(initial_connection_columns):
            raise RuntimeError("Unexpected ternary encountered.")
        for option_index, column in enumerate(initial_connection_columns):
            row[column] = 1.0 if option_index == code else 0.0
    return row


def latest_eval_run(name: str) -> "Run | None":
    """Return the most recent ``EVAL``-tagged run whose display name is ``name``.

    Runs are looked up under the W&B path ``price-attribution``; recency is judged
    by each run summary's ``_timestamp``. Returns ``None`` when no run matches.
    """
    runs = api.runs(
        path="price-attribution",
        filters={"displayName": name, "tags": "EVAL"},
    )
    if len(runs) == 0:
        return None
    # max() keeps the first of any tied runs, same as a stable descending sort.
    return max(runs, key=lambda candidate: candidate.summary["_timestamp"])


rows = []
# Sorted so the row order of `results` does not depend on filesystem listing order.
for config_path in sorted(config_dir.iterdir()):
    ternary_name = config_path.stem
    row = encode_config_name(ternary_name)
    run = latest_eval_run(ternary_name)
    if run is None:
        print(f"No runs yet for {ternary_name}")
        continue
    # Take the F1 from the newest run, not whichever run the API happened to list first.
    row["f1"] = run.summary["test/f1_mean"]
    rows.append(row)
# Column order: the encoded feature columns first, then "f1" last.
results = pd.DataFrame(rows)

In [None]:
def get_feature_importances(
    X: np.ndarray,
    y: np.ndarray,
    num_bootstraps: int = 50,
    random_state: int = 1998,
    n_estimators: int = 100,
) -> np.ndarray:
    """Estimate mean |SHAP| feature importances over bootstrap resamples.

    For each of ``num_bootstraps`` iterations, rows of ``(X, y)`` are resampled with
    replacement, a ``RandomForestRegressor`` is fit on the resample, and the mean
    absolute SHAP value of each feature (computed on that same resample) is recorded.

    Args:
        X: Feature matrix of shape (n_samples, n_features). Anything accepted by
            ``np.asarray`` works, including a pandas DataFrame.
        y: Target vector of length n_samples (array-like, e.g. a pandas Series).
        num_bootstraps: Number of bootstrap resamples.
        random_state: Seed for the row resampling and for each forest's seed.
        n_estimators: Number of trees in each random forest.

    Returns:
        Array of shape (num_bootstraps, n_features); row i holds the importances
        from bootstrap i.
    """
    # Convert up front: positional row indexing below (X[idx]) would select
    # columns, not rows, on a pandas DataFrame.
    X = np.asarray(X)
    y = np.asarray(y)
    rng = np.random.default_rng(random_state)
    n_samples, n_features = X.shape
    importances = np.zeros((num_bootstraps, n_features))

    for i in tqdm(range(num_bootstraps)):
        idx = rng.choice(n_samples, n_samples, replace=True)
        X_boot, y_boot = X[idx], y[idx]

        # Integer (not float 1e9) upper bound for the per-forest seed.
        model = RandomForestRegressor(
            n_estimators=n_estimators,
            random_state=int(rng.integers(1_000_000_000)),
        )
        model.fit(X_boot, y_boot)

        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X_boot)
        importances[i] = np.mean(np.abs(shap_values), axis=0)

    return importances

In [None]:
# Select columns by name rather than position so the split does not depend on
# "f1" being the last column of `results`.
X = results.drop(columns="f1")
y = results["f1"]
# Fixed seed (same value as get_feature_importances' default) so the forest, and
# hence every SHAP plot below, is reproducible across kernel restarts.
model = RandomForestRegressor(random_state=1998)
model.fit(X, y)
explainer = shap.TreeExplainer(model)
# SHAP explanation for every configuration row in X; plotted in the cells below.
explanation = explainer(X)

In [None]:
# Beeswarm summary of the per-configuration SHAP values for each feature.
shap.plots.beeswarm(shap_values=explanation)

In [None]:
# Hierarchically cluster the feature columns of X (no target passed, so shap uses
# its default distance), then draw the SHAP bar chart alongside that dendrogram.
clustering = shap.utils.hclust(X)
shap.plots.bar(
    explanation,
    clustering=clustering,
    clustering_cutoff=0.5,  # controls how much of the dendrogram is drawn
)

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output


def update_plot(change):
    """Redraw the waterfall plot for the row currently selected on ``slider``.

    ``change`` is the widget change notification (or ``None`` for the initial
    draw); it is ignored and the index is read from ``slider`` directly.
    """
    with output:
        clear_output(wait=True)
        shap.plots.waterfall(explanation[slider.value])


# Slider over the row indices of `explanation`; each move redraws that row's waterfall.
slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(explanation) - 1,
    step=1,
    description='Index',
)
output = widgets.Output()
slider.observe(update_plot, names='value')

display(slider, output)
update_plot(None)  # draw the plot for the initial index