In [None]:
import numpy as np
import lightgbm as lgbm
import seaborn as sns
import matplotlib.pyplot as plt
from ilmart.utils import is_interpretable
from ilmart.ilmart_distill import IlmartDistill
%load_ext autoreload
%autoreload 2

# Load Models

In [None]:
# Locations of the best ILMART models, stored one `.lgbm` file per dataset.
models_no_inter_dir = "../best_models/ilmart/without_inter"  # models saved without interactions
models_inter_dir = "../best_models/ilmart/with_inter"        # models saved with interactions

# Datasets for which a model file `<name>.lgbm` is loaded from each directory.
datasets_name = [
    "web30k",
    "istella",
    "yahoo",
]

In [None]:
# Map each dataset name to the LightGBM booster stored in `models_no_inter_dir`.
main_effects_booster = {
    name: lgbm.Booster(model_file=f"{models_no_inter_dir}/{name}.lgbm")
    for name in datasets_name
}

In [None]:
# Map each dataset name to the LightGBM booster stored in `models_inter_dir`.
ilmart_booster = {
    name: lgbm.Booster(model_file=f"{models_inter_dir}/{name}.lgbm")
    for name in datasets_name
}

# Check interpretability

In [None]:
# For every booster loaded from `models_no_inter_dir`, report the result of
# `is_interpretable` and the number of trees in the ensemble.
for dataset, booster in main_effects_booster.items():
    print(f"Checking features used by {dataset}... ")
    interpretable = is_interpretable(booster)
    print(f"Is interpretable? {interpretable}")
    n_trees = booster.num_trees()
    print(f"Number of trees: {n_trees}")

In [None]:
# Same report as for the main-effects boosters, this time for the boosters
# loaded from `models_inter_dir`.
for dataset, booster in ilmart_booster.items():
    print(f"Checking features used by {dataset}... ")
    interpretable = is_interpretable(booster)
    print(f"Is interpretable? {interpretable}")
    n_trees = booster.num_trees()
    print(f"Number of trees: {n_trees}")


## Plot first components of WEB30K best model

In [None]:
# Load the WEB30K model trained with interactions. The path is built from the
# `models_inter_dir` constant (instead of repeating the literal directory) so
# this cell cannot drift from the configuration cell.
best_web30k = lgbm.Booster(model_file=f"{models_inter_dir}/web30k.lgbm")

# Wrap the booster in IlmartDistill; the plotting cell below reads its
# `splitting_values` and `hist` attributes.
distilled_web30k = IlmartDistill(best_web30k)

In [None]:
# Pair every feature index with its LightGBM importance: [(feature_idx, importance), ...].
feat_imp = list(enumerate(best_web30k.feature_importance()))

In [None]:
# Rank the (feature_idx, importance) pairs from most to least important.
feat_imp = sorted(feat_imp, key=lambda pair: pair[1], reverse=True)

In [None]:
# Display the dict of loaded boosters (dataset name -> lgbm.Booster) as the cell output.
# NOTE(review): looks like a leftover inspection cell; no later cell uses its output.
ilmart_booster

In [None]:
# `feat_imp` is sorted by decreasing importance and holds one tuple per model
# feature; show only the top entries instead of dumping the whole list.
feat_imp[:10]

In [None]:
# Human-readable names for selected WEB30K feature indices.
# NOTE(review): `feats_label` is reassigned with different labels for the same
# keys inside the plotting cell before it is ever read, so this definition has
# no effect on the figure. Confirm which labelling is correct (0- vs 1-based
# feature indices?) and keep a single definition.
feats_label = {
    129: "Outlink number",
    133: "QualityScore2",
    114: "LMIR.ABS-title",
    134: "Query-url click count"
}

In [None]:

# Figure: the three most important main effects of the WEB30K model (line
# plots) plus one pairwise interaction (heatmap) on a 2x2 grid.
# NOTE: this changes the global matplotlib font size for every later figure.
plt.rcParams.update({'font.size': 25})

# (lower, upper) quantiles of the split values used as x-axis window,
# one pair per main-effect panel.
x_lims = [(0.01, 0.80), (0.01, 0.50), (0.01, 0.99)]
fig, axs = plt.subplots(2, 2, figsize=(13, 10))

# Short labels used in panel titles and on the heatmap axes.
feats_label = {
    129: "PR",
    133: "QUCC",
    114: "LMIR",
    134: "UCC"
}

# --- Main effects: one panel per top-ranked feature -------------------------
# `axs.flat` walks the grid row by row, and zipping with `x_lims` limits the
# loop to as many features as there are quantile windows (three).
for current_ax, (feat, _), (q_low, q_high) in zip(axs.flat, feat_imp, x_lims):
    x = distilled_web30k.splitting_values[feat]
    contrib = distilled_web30k.hist[(feat,)]
    # Repeat the last value so that `y` has one more entry than `contrib`
    # (presumably matching the number of split values in `x` - TODO confirm).
    y = np.append(contrib, contrib[-1])
    # x/y must be passed by keyword: positional use is deprecated in
    # seaborn 0.11 and raises TypeError from seaborn 0.12 on.
    sns.lineplot(x=x, y=y, drawstyle='steps-pre', ax=current_ax)
    current_ax.set_ylabel(r"$\tau(x_j)$", rotation=0)
    current_ax.yaxis.set_label_coords(-0.1,1.02)
    current_ax.set_xlim(np.quantile(x, q_low), np.quantile(x, q_high))
    # Fall back to a generic title instead of failing with KeyError when a
    # top-ranked feature has no entry in `feats_label`.
    current_ax.set_title(feats_label.get(feat, f"Feature {feat}"))
    current_ax.set_xlabel("$x_j$")

# --- Interaction effect: bottom-right panel ----------------------------------
current_ax = axs[1][1]

# Annotation placed in data coordinates (x=300, y=0); presumably tuned by hand
# for this particular matrix size - verify if the model changes.
current_ax.text(300, 0, r"$\tau_{ij}(x_i, x_j)$")

feat1 = 133
feat2 = 134
# Rotate the interaction matrix by 90 degrees counter-clockwise before drawing.
cropped_matrix = np.rot90(distilled_web30k.hist[(feat1, feat2)])
splits_feat1 = distilled_web30k.splitting_values[feat1]
splits_feat2 = distilled_web30k.splitting_values[feat2]
sns.heatmap(cropped_matrix,
            ax=current_ax,
            cmap="Blues",
            cbar_kws = dict(location="right", anchor=(0.5, .4), pad=0))
current_ax.set_xlabel(f"$x_i$\t{feats_label[feat1]}")
current_ax.set_ylabel(f"$x_j$\t{feats_label[feat2]}", rotation=0)
current_ax.yaxis.set_label_coords(-0.1,1.02)

# Only two ticks per axis, at the 10% and 90% positions.
crop = [0.1, 0.9]

# NOTE(review): the row (y) tick labels are taken from `splits_feat1` and the
# column (x) tick labels from `splits_feat2`, while the axis labels above name
# feat1 on x and feat2 on y. One of the two pairings looks swapped; which one
# depends on the axis order of `hist[(feat1, feat2)]` - confirm against
# IlmartDistill before publishing the figure.
reduced_yticks = np.quantile(range(cropped_matrix.shape[0]), crop)
reduced_ylabels = np.quantile(splits_feat1[:-1], crop)[::-1]

reduced_xticks = np.quantile(range(cropped_matrix.shape[1]), crop)
reduced_xlabels = np.quantile(splits_feat2[:-1], crop)

current_ax.set_yticks(reduced_yticks, labels=[f"{label:.1E}" for label in reduced_ylabels])
current_ax.set_xticks(reduced_xticks, labels=[f"{label:.1E}" for label in reduced_xlabels])

current_ax.xaxis.set_tick_params(rotation=0)
current_ax.yaxis.set_tick_params(rotation=0)


plt.tight_layout()
#plt.savefig("plots/function_plots.pdf")