In [1]:
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score

from shapiq_student.knn_explainer2 import KNNExplainer
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from shapiq_student.coalition_finding import greedy_coalition_finding

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the movie dataset and hold out 20% of the rows for testing.
# random_state is fixed so the split (and everything downstream) is reproducible.
FEATURES = ["fight_scenes", "kiss_scenes"]  # single source of truth for model inputs
TARGET = "label"

df = pd.read_csv("shapiq_student/movie_dataset.csv")
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

X_train = train_df[FEATURES].to_numpy()
y_train = train_df[TARGET].to_numpy()
X_test = test_df[FEATURES].to_numpy()
y_test = test_df[TARGET].to_numpy()

# Sanity check: the split neither drops nor duplicates rows.
assert len(train_df) + len(test_df) == len(df)

In [3]:
# Fit a 5-nearest-neighbour classifier on the training split and report its
# hold-out accuracy (shown as the cell's output via the last expression).
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
prediction = knn.predict(X_test)
accuracy = accuracy_score(y_test, prediction)
accuracy

In [4]:
# Pick the movie whose prediction we want to explain.
# NOTE(review): the row is looked up in the full `df`, so it may also be part of
# `train_df` (i.e. be its own neighbour). If the intent is to explain an unseen
# point, look it up in `test_df` instead — confirm.
QUERY_TITLE = 'Love in Paris'
matches = df[df['title'] == QUERY_TITLE]
if matches.empty:
    # Fail loudly instead of the opaque IndexError that `.iloc[0]` would raise.
    raise ValueError(f"No movie titled {QUERY_TITLE!r} found in the dataset")
row = matches.iloc[0]  # first match if the title is duplicated
x_val = [row['fight_scenes'], row['kiss_scenes']]
y_val = row['label']

In [5]:
# Wrap the fitted classifier together with its training data.
# NOTE(review): the string "weighted" presumably selects the weighted KNN-Shapley
# variant (matching `weighted_knn_shapley` used below) — confirm against KNNExplainer.
explainer = KNNExplainer(
    knn,
    X_train,
    y_train,
    "weighted",
)

In [6]:
# Compute weighted KNN-Shapley values for the query (x_val, y_val) — presumably
# one value per training point; confirm against KNNExplainer.
# K is tied to the fitted classifier's neighbour count so the explanation and
# the model cannot silently disagree if n_neighbors is changed in the fit cell.
# NOTE(review): gamma=5 is a hand-picked hyperparameter — document its meaning.
phi = explainer.weighted_knn_shapley(
    x_val=x_val,
    y_val=y_val,
    gamma=5,
    K=knn.n_neighbors,
)

In [7]:
# Display the Shapley values via the notebook's rich repr instead of print().
phi

[ 1.27949824  1.27949824  1.27949824  1.27949824  1.27949824  0.27949824
  0.11283157  0.06521252  0.04735538  0.03941887  0.03545062  0.03328612
  0.03202349  0.03124649  0.03074699 -0.05550826  0.0557372   0.0558988
  0.05601551  0.05610151  0.05616601  0.05621515  0.05625313  0.05628285
  0.05630637  0.05632519  0.0563404   0.05635278  0.05636296  0.05637138
  0.0563784   0.05638428  0.05638925  0.05639346  0.05639706 -0.02952211]


In [8]:
# Hypergraph edge weights: each coalition (a frozenset of player indices) maps to
# its weight. Entries are listed in a fixed order so the dict's insertion order is
# unchanged, in case the coalition-finding routine is order-sensitive.
e_weights = dict([
    (frozenset(), 0.0),           # base value e₀
    (frozenset({0}), 0.5),
    (frozenset({1}), 1.0),
    (frozenset({2}), -0.2),
    (frozenset({3}), 0.3),
    (frozenset({0, 1}), 0.4),
    (frozenset({1, 2}), -0.5),
    (frozenset({0, 1, 2}), 0.7),  # hyperedge (three players)
])

# Player set: players 0, 1, 2 and 3.
N = list(range(4))

In [9]:
# Run greedy coalition finding for coalition sizes up to k_max.
# Tying k_max to len(N) (4 here) covers every possible coalition size for the
# player set above and stays correct if N is edited.
result = greedy_coalition_finding(N=N, e_weights=e_weights, k_max=len(N))

In [None]:
# Display the coalitions found per size via the notebook's rich repr.
result

{1: ({1}, {1}), 2: ({0, 1}, {0, 1}), 3: ({0, 1, 3}, {0, 1, 3}), 4: ({0, 1, 2, 3}, {0, 1, 2, 3})}
