# Post-hoc XAI - Training a Neural Network and explaining it with ANCHOR

In [66]:
import pandas as pd
df = pd.read_csv("wnba_clean.csv")

In [68]:
df

Unnamed: 0,shot_type,made_shot,shot_value,coordinate_x,coordinate_y,home_score,away_score,qtr,quarter_seconds_remaining,game_seconds_remaining,shot_group,distance,shot_group_encoded,shot_type_encoded
0,Jump Shot,False,0,-13,9,0,0,1,571,2371,Jump Shot,15.811388,3,29
1,Turnaround Bank Jump Shot,False,0,0,0,0,0,1,551,2351,Jump Shot,0.000000,3,8
2,Cutting Layup Shot,True,2,-21,2,0,2,1,538,2338,Layup,21.095023,1,25
3,Driving Layup Shot,True,2,0,0,2,2,1,524,2324,Layup,0.000000,1,14
4,Jump Shot,True,3,0,21,2,5,1,512,2312,Jump Shot,21.000000,3,29
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
41492,Turnaround Fade Away Jump Shot,False,0,-2,5,71,77,4,24,24,Jump Shot,5.385165,3,6
41493,Jump Shot,False,0,0,3,71,77,4,19,19,Jump Shot,3.000000,3,29
41494,Free Throw - 1 of 2,False,0,0,15,71,77,4,16,16,Free Throw,29.154759,6,53
41495,Free Throw - 2 of 2,True,1,0,15,71,78,4,16,16,Free Throw,29.154759,6,52


In [37]:
# Step 1: Load and Prepare the Dataset
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
from alibi.explainers import AnchorTabular

# Load the WNBA dataset
data_path = 'wnba_clean.csv'
df = pd.read_csv(data_path)

# Define target and features
target = 'made_shot'
features = [
    'coordinate_x', 'coordinate_y', 'distance', 'shot_type_encoded','shot_group_encoded',
    'home_score', 'away_score', 'qtr', 
    'quarter_seconds_remaining', 'game_seconds_remaining'
]

# Split dataset into train and test sets
X = df[features]
y = df[target]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize numerical features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Convert data to PyTorch tensors
X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32).view(-1, 1)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32).view(-1, 1)

# Debugging: Check dataset shapes and features
print("X_train shape:", X_train_tensor.shape)
print("X_test shape:", X_test_tensor.shape)
print("Features used:", features)


X_train shape: torch.Size([33197, 10])
X_test shape: torch.Size([8300, 10])
Features used: ['coordinate_x', 'coordinate_y', 'distance', 'shot_type_encoded', 'shot_group_encoded', 'home_score', 'away_score', 'qtr', 'quarter_seconds_remaining', 'game_seconds_remaining']


In [38]:
class AdvancedNN(nn.Module):
    def __init__(self, input_dim):
        super(AdvancedNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.3)
        
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        
        self.fc3 = nn.Linear(64, 32)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
        x = self.dropout2(self.relu2(self.bn2(self.fc2(x))))
        x = self.relu3(self.fc3(x))
        x = self.sigmoid(self.fc4(x))
        return x


# Initialize model
input_dim = X_train_tensor.shape[1]  # 10 input features
model = AdvancedNN(input_dim)

# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
epochs = 20  # Increase epochs for better learning
batch_size = 32

for epoch in range(epochs):
    model.train()
    permutation = torch.randperm(X_train_tensor.size(0))  # Shuffle training data
    for i in range(0, X_train_tensor.size(0), batch_size):
        indices = permutation[i:i + batch_size]
        batch_X, batch_y = X_train_tensor[indices], y_train_tensor[indices]

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(batch_X)

        # Compute loss
        loss = criterion(outputs, batch_y)

        # Backward pass
        loss.backward()

        # Optimize
        optimizer.step()

    # Print loss for each epoch
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")


Epoch 1/20, Loss: 0.6986
Epoch 2/20, Loss: 0.6599
Epoch 3/20, Loss: 0.6118
Epoch 4/20, Loss: 0.6727
Epoch 5/20, Loss: 0.6060
Epoch 6/20, Loss: 0.6473
Epoch 7/20, Loss: 0.7077
Epoch 8/20, Loss: 0.5713
Epoch 9/20, Loss: 0.7823
Epoch 10/20, Loss: 0.7451
Epoch 11/20, Loss: 0.8162
Epoch 12/20, Loss: 0.5879
Epoch 13/20, Loss: 0.6633
Epoch 14/20, Loss: 0.5392
Epoch 15/20, Loss: 0.6609
Epoch 16/20, Loss: 0.7169
Epoch 17/20, Loss: 0.5934
Epoch 18/20, Loss: 0.5221
Epoch 19/20, Loss: 0.5922
Epoch 20/20, Loss: 0.5061


In [39]:
# Evaluate the model on the test set
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation for evaluation
    predictions = model(X_test_tensor)
    predictions = (predictions > 0.5).float()  # Convert probabilities to binary predictions
    accuracy = (predictions == y_test_tensor).float().mean().item()  # Calculate accuracy

print(f"Test Set Accuracy: {accuracy:.4f}")


Test Set Accuracy: 0.6687


# Now lets use some post-hoc techniques for this model


In [40]:
from alibi.explainers import AnchorTabular

# Define a prediction function for the trained neural network
def predict_fn(x):
    preds = model(torch.tensor(x, dtype=torch.float32)).detach().numpy()
    return np.hstack((1 - preds, preds))  # Return probabilities for both classes

# Initialize categorical_names dictionary
categorical_names = {}

# Map encoded categorical features to their categories
categorical_names[3] = list(df['shot_type'].unique())  # Replace 3 with index of 'shot_type_encoded'
categorical_names[4] = list(df['shot_group'].unique())  # Replace 4 with index of 'shot_group_encoded'

# Feature names
feature_names = ['coordinate_x', 'coordinate_y', 'distance', 'shot_type_encoded', 'shot_group_encoded',
                 'home_score', 'away_score', 'qtr', 'quarter_seconds_remaining', 'game_seconds_remaining']

# Initialize the AnchorTabular explainer
explainer = AnchorTabular(predict_fn, feature_names, seed=420)

# Fit the explainer on the training data
explainer.fit(X_train_scaled, categorical_names=categorical_names)

AnchorTabular(meta={
  'name': 'AnchorTabular',
  'type': ['blackbox'],
  'explanations': ['local'],
  'params': {'seed': 420, 'disc_perc': (25, 50, 75)},
  'version': '0.9.6'}
)

## Let's start by explaining the first 5 instances

In [41]:
for i in range(5):  # Explain first 5 instances
    instance = X_test_scaled[i].reshape(1, -1)
    explanation = explainer.explain(instance)
    print(f"Instance {i + 1} Explanation:")
    print(f"Anchor: {explanation.anchor}")
    print(f"Precision: {explanation.precision}")
    print(f"Coverage: {explanation.coverage}\n")

Instance 1 Explanation:
Anchor: ['-1.22 < shot_group_encoded <= -0.05']
Precision: 0.9661865998747652
Coverage: 0.5232

Instance 2 Explanation:
Anchor: ['coordinate_y > 0.51', 'coordinate_x > -0.33']
Precision: 0.9670710571923743
Coverage: 0.1825

Instance 3 Explanation:
Anchor: ['shot_group_encoded <= -0.05', 'shot_type_encoded <= -0.55', 'quarter_seconds_remaining <= -0.89', 'home_score <= -0.85']
Precision: 0.9748743718592965
Coverage: 0.0092

Instance 4 Explanation:
Anchor: ['shot_group_encoded <= -1.22', 'coordinate_x <= -1.05', 'game_seconds_remaining > 0.01', 'away_score > -0.87', 'shot_type_encoded > -0.55']
Precision: 0.9622926093514329
Coverage: 0.006

Instance 5 Explanation:
Anchor: ['-1.22 < shot_group_encoded <= -0.05']
Precision: 0.9663488502523836
Coverage: 0.526



## We can see that instance 1 and 5 are explained by a 52% Coverage anchor which means it explains with 96.6% precision 52% of the test set scenarios, generalizing very well for that subgroup
### Let's see what that istance is and what it represents by inverting the scaling of the features:

In [42]:
# Lets explore deeply the instance with best coverage we found
instance = X_test_scaled[4].reshape(1, -1)
original_values = scaler.inverse_transform(instance)
# Map the feature names to their original values
real_values = dict(zip(features, original_values[0]))

# Print the real values of the instance
print("\nOriginal Feature Values:")
for feature, value in real_values.items():
    print(f"{feature}: {value:.2f}")



Original Feature Values:
coordinate_x: 0.00
coordinate_y: 4.00
distance: 4.00
shot_type_encoded: 17.00
shot_group_encoded: 3.00
home_score: 30.00
away_score: 39.00
qtr: 2.00
quarter_seconds_remaining: 211.00
game_seconds_remaining: 1411.00


In [19]:
feature_index = features.index('shot_group_encoded')

scaled_min = -1.22
scaled_max = -0.05

# Get the mean and std from the scaler
mu = scaler.mean_[feature_index]
sigma = scaler.scale_[feature_index]

# Compute the original thresholds
real_min = scaled_min * sigma + mu
real_max = scaled_max * sigma + mu
print(f"Real Value Range for 'shot_group_encoded': {real_min:.2f} to {real_max:.2f}")

Real Value Range for 'shot_group_encoded': 1.00 to 3.01


### The results suggest that shot_group_encoded is from 1 to 3
### Let's see what shot type is the equivalent of our encoding:

In [44]:
# Add counts for shot_type and shot_type_encoded with a clean index
unique_shot_types = (
    df.groupby(['shot_group', 'shot_group_encoded'])
    .size()
    .reset_index(name='count')
    .sort_values('shot_group_encoded')
    .set_index(['shot_group', 'shot_group_encoded'])
)
print("Unique Shot Types and Encoded Values with Counts:")
print(unique_shot_types)

Unique Shot Types and Encoded Values with Counts:
                               count
shot_group shot_group_encoded       
Dunk       0                       3
Layup      1                   10820
Hook Shot  2                     675
Jump Shot  3                   21216
Other      4                       1
Tip Shot   5                     149
Free Throw 6                    8633


### It's covering all of the shots that are a layup, hook shot, and jump shot!
#### So it's safe to conclude that for every shot of this kind, the model predicts correctly 99.6% of the time, given this test set

## Let's try to select an instance close to the decision boundary, it should be explained by a very complex anchor:

In [46]:
with torch.no_grad():
    predictions = model(X_test_tensor).squeeze().numpy()  # Get probabilities
    selected_instance_idx = np.argmin(np.abs(predictions - 0.5))  # Closest to 0.5

instance = X_test_scaled[selected_instance_idx].reshape(1, -1)

explanation = explainer.explain(instance)

# Print the explanation
print("\nAnchor Explanation:")
print(f"Anchor: {explanation.anchor}")
print(f"Precision: {explanation.precision}")
print(f"Coverage: {explanation.coverage}")


Anchor Explanation:
Anchor: ['shot_group_encoded <= -0.05', 'shot_type_encoded <= -0.55', 'home_score <= -0.85', 'coordinate_x <= -0.33', 'distance <= 0.34']
Precision: 0.9505703422053232
Coverage: 0.0138


### Coverage below 5%, as expected

## Let's try now a shot with a very small distance to the hoop

In [58]:
# Define a helper function to scale and explain an instance
def explain_instance(instance, explainer, scaler, feature_names):
    # Scale the instance
    instance_scaled = scaler.transform(instance[feature_names].values.reshape(1, -1))
    
    # Generate explanation
    explanation = explainer.explain(instance_scaled)
    
    # Reverse scale the instance for better understanding
    original_values = scaler.inverse_transform(instance_scaled)
    real_values = dict(zip(feature_names, original_values[0]))
    
    # Print explanation and original values
    print("\nOriginal Feature Values:")
    for feature, value in real_values.items():
        print(f"{feature}: {value:.2f}")
    
    print("\nAnchor Explanation:")
    print(f"Anchor: {explanation.anchor}")
    print(f"Precision: {explanation.precision}")
    print(f"Coverage: {explanation.coverage}")

# Pick a close shot (low distance)
low_distance_threshold = 3  # Define a threshold for "close shot" in original units
low_distance_mask = df['distance'] < low_distance_threshold

if np.any(low_distance_mask):
    close_shot = df[low_distance_mask].iloc[0]
    explain_instance(close_shot, explainer, scaler, features)




Original Feature Values:
coordinate_x: 0.00
coordinate_y: 0.00
distance: -0.00
shot_type_encoded: 8.00
shot_group_encoded: 3.00
home_score: 0.00
away_score: 0.00
qtr: 1.00
quarter_seconds_remaining: 551.00
game_seconds_remaining: 2351.00

Anchor Explanation:
Anchor: ['-1.22 < shot_group_encoded <= -0.05']
Precision: 0.9694779116465864
Coverage: 0.5393


# We can see that this istance is being covered by the same anchor

## Let's try to see a shot that is not being covered by this anchor, like any free throw shot:

In [69]:
# Pick a free throw shot
free_throw_mask = df['shot_group'] == 'Free Throw'

if np.any(free_throw_mask):
    free_throw_shot = df[free_throw_mask].iloc[0]
    explain_instance(free_throw_shot, explainer, scaler, features)


Original Feature Values:
coordinate_x: 0.00
coordinate_y: 15.00
distance: 29.15
shot_type_encoded: 54.00
shot_group_encoded: 6.00
home_score: 4.00
away_score: 8.00
qtr: 1.00
quarter_seconds_remaining: 478.00
game_seconds_remaining: 2278.00

Anchor Explanation:
Anchor: ['shot_group_encoded > -0.05', 'shot_type_encoded > 0.43']
Precision: 0.9924812030075187
Coverage: 0.207




In [72]:
feature_index = features.index('shot_type_encoded')

scaled_min = 0.43

# Get the mean and std from the scaler
mu = scaler.mean_[feature_index]
sigma = scaler.scale_[feature_index]

# Compute the original thresholds
real_min = scaled_min * sigma + mu
print(f"Real Value Range for 'shot_group_encoded': {real_min:.2f}")

Real Value Range for 'shot_group_encoded': 37.01


#### Let's see what this values are:

In [73]:
# Filter for shot_type_encoded > 37
filtered_df = df[df['shot_type_encoded'] > 37]

# Add counts for shot_type and shot_type_encoded with a clean index
unique_shot_types_filtered = (
    filtered_df.groupby(['shot_type', 'shot_type_encoded'])
    .size()
    .reset_index(name='count')
    .sort_values('shot_type_encoded')
    .set_index(['shot_type', 'shot_type_encoded'])
)

print("Unique Shot Types with shot_type_encoded > 37 and Counts:")
print(unique_shot_types_filtered)

Unique Shot Types with shot_type_encoded > 37 and Counts:
                                                  count
shot_type                      shot_type_encoded       
Hook Driving Bank              38                    10
Running Pullup Jump Shot       39                   508
Tip Shot                       40                   149
Alley Oop Dunk Shot            41                     1
Running Dunk Shot              42                     1
Free Throw - 3 of 3            43                    95
Free Throw - 2 of 3            44                    95
Free Throw - 1 of 3            45                    95
Free Throw - Flagrant 2 of 3   46                    15
Free Throw - Flagrant 2 of 2   47                    28
Free Throw - Flagrant 1 of 3   48                    15
Free Throw - Flagrant 1 of 2   49                    28
Free Throw - Clear Path 2 of 2 50                    14
Free Throw - Clear Path 1 of 2 51                    14
Free Throw - 2 of 2            52             

### This anchor covers all the "other", tip shots, and Free throws.
### All these shots are covered with an anchor with 99.2% precision, covering over 20% of instances.