In [None]:
import numpy as np
from collections import defaultdict
from sklearn.metrics import f1_score
import scipy

def parse_features(lst):
    """
    Map each feature-group key to the column indices of its member features.

    The key for a feature name is chosen as follows:
      * the name contains no "="            -> the whole name is its own key;
      * the name looks like "<prefix>=<...>" -> the key is "<prefix>";
      * except that when "<prefix>" contains the substring "rel", the key is
        everything before the first "_" of the *whole* name (or the whole
        name if it has no "_").

    Returns a plain dict whose keys appear in order of first occurrence and
    whose values are ascending lists of positions in ``lst``.
    """
    def group_key(feature_name):
        if '=' not in feature_name:
            return feature_name
        prefix = feature_name.partition('=')[0]
        if "rel" not in prefix:
            return prefix
        return feature_name.partition("_")[0]

    index_by_group = {}
    for idx, feature_name in enumerate(lst):
        index_by_group.setdefault(group_key(feature_name), []).append(idx)
    return index_by_group

def _permute_group_rows(X, group_indices, perm):
    """
    Return a copy of ``X`` whose columns ``group_indices`` have their rows
    reordered by ``perm``; every other column is left untouched.

    ``X`` is either a row-indexable scipy sparse matrix (the caller passes CSR)
    or a 2-D numpy array. Sparse input yields a CSR result, dense input a numpy
    array. ``X`` itself is never modified.
    """
    if scipy.sparse.issparse(X):
        # Changing the sparsity structure of a CSR matrix in place is slow, so
        # assign into a LIL copy and convert back.
        X_permuted = X.tolil(copy=True)
        X_permuted[:, group_indices] = X[perm][:, group_indices].toarray()
        return X_permuted.tocsr()
    X_permuted = X.copy()
    X_permuted[:, group_indices] = X[np.ix_(perm, group_indices)]
    return X_permuted


def grouped_permutation_importance(
    model,
    X_test,
    y_test,
    feature_names,
    scoring_func=f1_score,
    n_repeats=10,
    random_state=None,
    scoring_kwargs=None,
):
    """
    Permutation importance computed per *group* of features.

    Features are grouped with ``parse_features(feature_names)``. For every
    group and every repeat, one random row permutation is drawn and applied
    jointly to all columns of the group. The importance of that repeat is
    ``baseline_score - permuted_score``.

    Parameters
    ----------
    model : object exposing ``predict(X)``
    X_test : scipy sparse matrix or 2-D array-like, shape (n_samples, n_features)
        Column ``j`` must correspond to ``feature_names[j]``.
    y_test : array-like of length n_samples
    feature_names : sequence of str
    scoring_func : callable ``(y_true, y_pred, **scoring_kwargs) -> float``
        Higher must mean better.
    n_repeats : int, >= 1
        Number of permutations per group.
    random_state : int, numpy Generator or None
        Passed to ``np.random.default_rng``.
    scoring_kwargs : dict or None
        Extra keyword arguments for ``scoring_func``. ``None`` means
        ``{"average": "macro"}``, the previously hard-coded behaviour.

    Returns
    -------
    dict with
        "importances_mean" : np.ndarray, sorted by decreasing mean,
        "importances_std"  : np.ndarray (population std, ddof=0), same order,
        "group_names"      : list of group keys, same order,
        "baseline_score"   : score of the unpermuted data.

    Raises
    ------
    ValueError
        If ``n_repeats < 1``.
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")
    if scoring_kwargs is None:
        scoring_kwargs = {"average": "macro"}

    feature_groups = parse_features(feature_names)
    group_names = list(feature_groups.keys())
    rng = np.random.default_rng(random_state)

    is_sparse_matrix = scipy.sparse.issparse(X_test)
    # Sparse input needs a row-indexable format (CSR); dense input becomes an
    # ndarray (the previous code crashed on dense input via .toarray()/.tocsr()).
    X_ref = X_test.tocsr() if is_sparse_matrix else np.asarray(X_test)
    n_samples = X_ref.shape[0]

    # The model and X_test never change, so a single baseline is enough.
    baseline_score = scoring_func(y_test, model.predict(X_test), **scoring_kwargs)

    raw_importances = {}
    for group_name in group_names:
        group_indices = feature_groups[group_name]
        if not group_indices:
            raw_importances[group_name] = [0.0] * n_repeats  # no columns -> no effect
            continue

        drops = []
        for _ in range(n_repeats):
            # One row permutation shared by every column of the group keeps each
            # row's within-group pattern intact. If the group is one-hot encoded
            # (as "class=value" names suggest), shuffling each column on its own
            # would create rows with several or no active values.
            perm = rng.permutation(n_samples)
            X_permuted = _permute_group_rows(X_ref, group_indices, perm)
            permuted_score = scoring_func(y_test, model.predict(X_permuted), **scoring_kwargs)
            drops.append(baseline_score - permuted_score)
        raw_importances[group_name] = drops

    importances_mean = np.array([np.mean(raw_importances[name]) for name in group_names])
    importances_std = np.array([np.std(raw_importances[name]) for name in group_names])

    sorted_indices = np.argsort(importances_mean)[::-1]
    return {
        "importances_mean": importances_mean[sorted_indices],
        "importances_std": importances_std[sorted_indices],
        "group_names": [group_names[i] for i in sorted_indices],
        "baseline_score": baseline_score,
    }

## Load data

In [None]:
from utils import build_occurrences, matrix

path = "data/SUD_French-GSD-r2.15/preprocessed_data"
patterns = "patterns/patterns_subject_inv_with_gov_lemmas.txt"
data = build_occurrences(path, patterns, "sud")
X, y, feature_names = matrix(data, max_degree=1)

# Later cells pair rows of X with y (train/test split) and map column j of X to
# feature_names[j] (grouped permutation importance), so fail fast here if either
# alignment is broken.
assert X.shape[0] == y.shape[0], f"X has {X.shape[0]} rows but y has {y.shape[0]} entries"
assert X.shape[1] == len(feature_names), (
    f"X has {X.shape[1]} columns but there are {len(feature_names)} feature names"
)
print(X.shape)
print(y.shape)

## Grouped Permutation Importance

In [None]:
from sklearn.model_selection import train_test_split
import skglm

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

model = skglm.SparseLogisticRegression(
    fit_intercept=True,
    max_iter=20,
    max_epochs=1000,
    alpha=0.001
)

model.fit(X_train, y_train)

# Grouped permutation importance on the held-out split, scored with macro F1.
importance_results = grouped_permutation_importance(
    model=model,
    X_test=X_test,
    y_test=y_test,
    feature_names=feature_names,
    scoring_func=f1_score,
    n_repeats=30,
    random_state=42
)

# The baseline is macro F1 on the test split (not accuracy, not a validation set).
print(f"\nBaseline Test Macro F1: {importance_results['baseline_score']:.4f}")
print("\nGrouped Permutation Importances (Mean Decrease in Macro F1):")
for group, mean_imp, std_imp in zip(
    importance_results['group_names'],
    importance_results['importances_mean'],
    importance_results['importances_std'],
):
    print(f"  {group}: {mean_imp:.4f} +/- {std_imp:.4f}")

## Plot

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.patches import Patch

def class_feature_to_request(s):
    """
    Turn a group name of the form "<ignored>:<name>:<node>:<feature>" into a
    short plot label "<name>[<node>].<feature>".

    * A feature of "len" is spelled out as "length".
    * A node of "own" is omitted: "c:Number:own:len" -> "Number.length".
    * Any other node is appended to the name: "c:Number:gov:len" -> "Numbergov.length".

    Names that do not split into exactly four ":"-separated fields are returned
    unchanged. Groups that do not follow this convention therefore still get a
    label instead of raising ValueError and aborting the whole plot.
    """
    parts = s.split(":")
    if len(parts) != 4:
        return s
    _, name, node, feature = parts
    if feature == "len":
        feature = "length"
    prefix = name if node == "own" else f"{name}{node}"
    return f"{prefix}.{feature}"

def _importance_category(mean_val, std_val, reliability_std_factor, epsilon_threshold):
    """
    Classify one (mean, std) importance estimate into a plot category.

    Returns "reliably_positive", "reliably_negative", "uncertain_positive",
    "uncertain_negative" or "other". The checks run in that order, so a
    reliable classification wins over an uncertain one.
    """
    margin = reliability_std_factor * std_val
    if mean_val - margin > epsilon_threshold:
        return "reliably_positive"
    if mean_val + margin < -epsilon_threshold:
        return "reliably_negative"
    if mean_val > epsilon_threshold:
        return "uncertain_positive"
    if mean_val < -epsilon_threshold:
        return "uncertain_negative"
    return "other"


def plot_permutation_importances(
    group_names,
    importances_mean,
    importances_std,
    title="Grouped Permutation Importances",
    metric_name="Decrease in Score",
    reliability_std_factor=2.0,
    epsilon_threshold=1e-8,
    save_path="permutation_test_subj.pdf",
):
    """
    Horizontal bar chart of grouped permutation importances with SD error bars.

    Groups whose |mean| and std are both below ``epsilon_threshold`` are
    dropped. The remaining groups are sorted so the largest mean is drawn at the
    top, and each bar is coloured by how clearly its importance differs from zero:

    * reliably positive / negative: mean -/+ ``reliability_std_factor`` * SD
      stays above ``epsilon_threshold`` / below ``-epsilon_threshold``;
    * uncertain positive / negative: only the mean is beyond +/- epsilon;
    * near zero: everything else.

    Parameters
    ----------
    group_names : sequence of str
        Raw group names; each becomes a tick label via ``class_feature_to_request``.
    importances_mean, importances_std : array-like of float
        Same length as ``group_names``.
    title : str
        Axes title.
    metric_name : str
        Shown in the x-axis label.
    reliability_std_factor : float
        Number of SDs used for the "reliably" categories.
    epsilon_threshold : float
        Tolerance around zero.
    save_path : str, path-like or None
        File the figure is saved to before being shown. The default keeps the
        historical file name; ``None`` skips saving.

    Raises
    ------
    ValueError
        If the three input sequences differ in length.
    """
    if not (len(group_names) == len(importances_mean) == len(importances_std)):
        raise ValueError("Input lists/arrays must have the same length.")

    means_orig = np.asarray(importances_mean, dtype=float)
    stds_orig = np.asarray(importances_std, dtype=float)
    names_orig = np.array([class_feature_to_request(name) for name in group_names])

    # Drop only groups that are negligible in BOTH mean (either sign) and std.
    # Filtering on `mean < eps` alone would discard every negative importance
    # and make the negative / near-zero categories below unreachable.
    is_negligible_mask = (np.abs(means_orig) < epsilon_threshold) & (stds_orig < epsilon_threshold)
    keep_mask = ~is_negligible_mask

    plot_means = means_orig[keep_mask]
    plot_stds = stds_orig[keep_mask]
    plot_names = names_orig[keep_mask]

    if len(plot_names) == 0:
        print(f"No features to plot after filtering those with mean and std < {epsilon_threshold:.1e}.")
        return

    # Ascending order: barh puts index 0 at the bottom, so the largest ends on top.
    sorted_indices = np.argsort(plot_means)
    plot_names_sorted = plot_names[sorted_indices]
    plot_means_sorted = plot_means[sorted_indices]
    plot_stds_sorted = plot_stds[sorted_indices]

    # Legend text per category, listed in the order entries appear in the legend.
    legend_labels_map = {
        "reliably_positive": f"Mean - {reliability_std_factor:.1f}*SD > {epsilon_threshold:.0e}",
        "uncertain_positive": f"Mean > {epsilon_threshold:.0e}",
        "uncertain_negative": f"Mean < -{epsilon_threshold:.0e}",
        "reliably_negative": f"Mean + {reliability_std_factor:.1f}*SD < -{epsilon_threshold:.0e}",
        "other": "Near Zero",
    }
    palette = sns.color_palette("Set2", n_colors=5)
    category_color_map = {
        "reliably_positive": palette[0],
        "uncertain_positive": palette[1],
        "reliably_negative": palette[2],
        "uncertain_negative": palette[3],
        "other": palette[4],
    }

    categories = [
        _importance_category(m, s, reliability_std_factor, epsilon_threshold)
        for m, s in zip(plot_means_sorted, plot_stds_sorted)
    ]
    colors_for_plot = [category_color_map[c] for c in categories]

    y_pos = np.arange(len(plot_names_sorted))
    fig, ax = plt.subplots(figsize=(12, max(6, len(plot_names_sorted) * 0.45)))
    ax.barh(y_pos, plot_means_sorted, xerr=plot_stds_sorted, align='center',
            color=colors_for_plot, capsize=4, ecolor='dimgray')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(plot_names_sorted)
    ax.set_xlabel(f"Mean Importance ({metric_name})")
    ax.set_title(title)
    ax.axvline(0, color='black', linewidth=0.8, linestyle='--')

    # One legend entry per category actually drawn, in legend_labels_map order.
    used_categories = set(categories)
    legend_handles = [
        Patch(facecolor=category_color_map[key], label=label)
        for key, label in legend_labels_map.items()
        if key in used_categories
    ]
    if legend_handles:
        ax.legend(handles=legend_handles, title="Importance Interpretation",
                  bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)
        # Reserve the right part of the figure for the legend placed outside the axes.
        fig.tight_layout(rect=[0, 0, 0.72, 1])
    else:
        fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

In [None]:
# Plot the grouped importances computed in the previous section.
plot_permutation_importances(
    title="Feature Group Importances (Permutation Test)\nScope: X-[1=subj]->Y\nQ: X << Y",
    metric_name="Decrease in Macro F1-score",
    reliability_std_factor=2.0,
    **{
        field: importance_results[field]
        for field in ("group_names", "importances_mean", "importances_std")
    },
)