# TEP-NET Model Explainability

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Embedding, MultiHeadAttention, Dense, GlobalAveragePooling1D, Dropout, Concatenate, Layer, Flatten
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import AUC, Precision, Recall
from tensorflow.keras.models import load_model
import keras
from sklearn.metrics import confusion_matrix
import numpy as np
import pandas as pd
import os
import h5py
import sys
import collections
import matplotlib.pyplot as plt
import shap
from lime.lime_tabular import LimeTabularExplainer

In [None]:
# Prefix prepended verbatim to data and figure file names (f'{path}{name}'), so a
# non-empty value must end with a path separator; '' means the working directory.
path = ''

def load_hdf5_in_chunks(h5_path, chunk_size=5000):
    """Yield the HDF5 file at ``h5_path`` as consecutive DataFrame chunks.

    The file is read through its 'TCR' and 'epitope' datasets and every dataset in
    its 'meta' group (assumed row-aligned with 'TCR' — confirm). Each yielded
    DataFrame holds up to ``chunk_size`` rows: one column per 'meta' dataset plus
    'TCR' and 'epitope' columns holding one array per row. The row index of every
    chunk starts again at 0.
    """
    with h5py.File(h5_path, 'r') as h5:
        n_rows = h5['TCR'].shape[0]
        meta_names = list(h5['meta'].keys())

        for start in range(0, n_rows, chunk_size):
            stop = min(start + chunk_size, n_rows)
            chunk = pd.DataFrame(
                {name: h5[f'meta/{name}'][start:stop] for name in meta_names}
            )
            # Store each embedding matrix as a single cell value.
            chunk['TCR'] = list(h5['TCR'][start:stop])
            chunk['epitope'] = list(h5['epitope'][start:stop])
            yield chunk

def load_data(filename):
    """Load ``f'{path}{filename}'`` (HDF5) into a single DataFrame.

    Rows are read in 5000-row chunks via ``load_hdf5_in_chunks`` and concatenated
    once at the end. The result has a unique 0..N-1 index. An empty file yields an
    empty DataFrame.
    """
    chunks = list(load_hdf5_in_chunks(f'{path}{filename}', chunk_size=5000))
    if not chunks:
        # pd.concat([]) raises; an empty file gives an empty frame instead.
        return pd.DataFrame()
    # A single concat avoids re-copying all earlier rows on every chunk (quadratic).
    # ignore_index avoids the per-chunk 0..4999 indices repeating across the frame.
    return pd.concat(chunks, sort=False, ignore_index=True)

In [None]:
def preprocess_data(df, feature_columns=None):
    """Split a loaded dataset into model inputs and labels.

    Args:
        df: DataFrame with 'TCR' and 'epitope' columns (one array per row), the
            physicochemical feature columns and a 'binding' label column.
        feature_columns: names of the feature columns; defaults to the module-level
            ``FEATURE_COLUMNS`` (looked up at call time).

    Returns:
        (X_tcr, X_epitope, X_features, y_labels). X_tcr and X_epitope are NumPy
        arrays stacked row-wise. X_features and y_labels are float32 tf tensors.
    """
    if feature_columns is None:
        feature_columns = FEATURE_COLUMNS

    # Read-only access, so no defensive df.copy() (that duplicated the whole frame,
    # embedding columns included).
    X_features = df[feature_columns].values
    y_labels = df['binding'].values

    # Embeddings stay NumPy; only the features and labels become tf tensors.
    X_tcr = np.stack(df['TCR'].values)
    X_epitope = np.stack(df['epitope'].values)
    X_features = tf.convert_to_tensor(X_features, dtype=tf.float32)
    y_labels = tf.convert_to_tensor(y_labels, dtype=tf.float32)

    return X_tcr, X_epitope, X_features, y_labels

def preprocess_tpp(df):
    """Same preprocessing as ``preprocess_data``, to which it delegates.

    Kept as a separate name for existing callers. Its body was a verbatim copy
    of ``preprocess_data``, so the two can no longer drift apart.
    """
    return preprocess_data(df)

In [None]:
# Number of test samples used by the LIME / saliency / SHAP analyses below.
NUM_SAMPLES = 50000
# Tags that appear in the data file names (presumably embedding width and reduction method).
EMBEDDING = 64
EMBEDDING_TYPE = 'pca'
# Physicochemical feature columns fed to the model's feature input
# (presumably in the order used at training time — confirm).
FEATURE_COLUMNS = [
    'TCR_KF7',
    'TCR_KF1',
    'TCR_hydrophobicity',
    'TCR_aromaticity',
    'TCR_isoelectric_point',
    'TCR_instability_index',
    'epitope_KF7',
    'epitope_KF1',
    'epitope_hydrophobicity',
    'epitope_aromaticity',
    'epitope_isoelectric_point',
    'epitope_instability_index',
]

In [None]:
# Test split, named after the embedding settings above.
test_file = f'test_ProtBERT_{EMBEDDING}_{EMBEDDING_TYPE}.h5'
df_test = load_data(test_file)
X_test_tcr, X_test_epitope, X_test_features, y_test_labels = preprocess_data(df_test)

In [None]:
# Separate binding and non-binding samples
df_negative = df_test[df_test['binding'] == 0]
df_positive = df_test[df_test['binding'] == 1]

X_positive_tcr, X_positive_epitope, X_positive_features, y_positive_labels = preprocess_data(df_positive)
X_negative_tcr, X_negative_epitope, X_negative_features, y_negative_labels = preprocess_data(df_negative)

In [None]:
# The same file and preprocessing were already run in an earlier cell, and nothing
# since then modifies df_test or its tensors. Re-reading the whole HDF5 file is only
# needed if that cell was skipped.
if 'df_test' not in globals():
    df_test = load_data(f'test_ProtBERT_{EMBEDDING}_{EMBEDDING_TYPE}.h5')
    X_test_tcr, X_test_epitope, X_test_features, y_test_labels = preprocess_data(df_test)

## Import Model

In [None]:
from tensorflow.keras.saving import register_keras_serializable

@register_keras_serializable()
def f1_score_metric(y_true, y_pred, threshold=0.5):
    """F1 score of thresholded predictions, computed over one batch.

    Passed in ``custom_objects`` when loading the saved model (presumably it was
    compiled with this metric). When Keras uses a plain function as a metric, it
    averages the per-batch values, which differs from the dataset-level F1.

    Args:
        y_true: binary labels, one per sample (any numeric dtype).
        y_pred: predicted scores, one per sample (e.g. shape (batch, 1)).
        threshold: scores strictly above this count as positive predictions.

    Returns:
        Scalar float32 tensor. The epsilon terms keep it finite when a batch has
        no positive labels or predictions.
    """
    eps = tf.keras.backend.epsilon()
    # Cast labels (int/float64 labels would otherwise fail the float32 products).
    # Flatten both tensors: (batch,) labels times (batch, 1) predictions would
    # otherwise broadcast to a (batch, batch) matrix and count every pair.
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred = tf.reshape(tf.cast(y_pred > threshold, tf.float32), [-1])

    tp = tf.reduce_sum(y_true * y_pred)
    fp = tf.reduce_sum((1 - y_true) * y_pred)
    fn = tf.reduce_sum(y_true * (1 - y_pred))

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * (precision * recall) / (precision + recall + eps)

@register_keras_serializable()
class ExpandDimsLayer(Layer):
    """Keras layer that inserts a length-1 axis at position ``axis``."""

    def __init__(self, axis=1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.expand_dims(inputs, self.axis)

    def get_config(self):
        # Base-layer config plus ``axis`` so the layer round-trips through saving.
        return {**super().get_config(), "axis": self.axis}

@register_keras_serializable()
class PiecewiseLinearEncoding(Layer):
    """Piecewise-linear encoding of each scalar feature over fixed bin edges.

    For edges b_0, ..., b_K, a value x becomes K numbers
    clip((x - b_k) / (b_{k+1} - b_k), 0, 1), for k = 0..K-1. Edges are assumed
    strictly increasing; a zero-width bin would divide by zero.
    """

    def __init__(self, bins, **kwargs):
        super().__init__(**kwargs)
        self.bins = tf.convert_to_tensor(bins, dtype=tf.float32)
        self.num_bins = len(bins) - 1

    def call(self, inputs):
        lower_edges = self.bins[:-1]
        widths = self.bins[1:] - lower_edges
        # [batch_size, num_features] -> [batch_size, num_features, 1] so each value
        # broadcasts against all bins.
        scaled = (tf.expand_dims(inputs, axis=-1) - lower_edges) / widths
        return tf.clip_by_value(scaled, 0.0, 1.0)

    def get_config(self):
        # Edges are serialised as a plain list of (float32-rounded) floats.
        return {**super().get_config(), "bins": self.bins.numpy().tolist()}

@register_keras_serializable()
class PeriodicEmbeddings(Layer):
    """Sin/cos embedding of each scalar feature at ``num_frequencies`` frequencies.

    For input ``[batch_size, num_features]`` the output is
    ``[batch_size, num_features, 2 * num_frequencies]``: the sines of
    ``2*pi*x*freqs``, followed by the cosines.

    NOTE(review): ``freqs`` is a raw ``tf.Variable`` created in ``__init__``, not
    via ``add_weight``. Under Keras 3 such a variable may not be tracked as a layer
    weight. In that case it is neither trained nor saved, and it gets fresh random
    values in [0.1, 1.0) on every load, making predictions differ between sessions.
    Confirm against the saved model. Switching to ``add_weight`` changes the
    layer's weight list and may break loading the existing ``.keras`` file.
    """
    def __init__(self, num_frequencies=16, **kwargs):
        super(PeriodicEmbeddings, self).__init__(**kwargs)
        self.num_frequencies = num_frequencies
        # Random initial frequencies drawn uniformly from [0.1, 1.0).
        self.freqs = tf.Variable(
            initial_value=tf.random.uniform(
                shape=(num_frequencies,), minval=0.1, maxval=1.0
            ),
            trainable=True,
        )

    def call(self, inputs):
        """Return concat(sin(2*pi*x*freqs), cos(2*pi*x*freqs)) along the last axis."""
        # Shape of inputs: [batch_size, num_features]
        inputs_expanded = tf.expand_dims(inputs, axis=-1)  # [batch_size, num_features, 1]
        # Broadcasting against freqs gives [batch_size, num_features, num_frequencies]
        # per term; the concat doubles the last axis.
        periodic_features = tf.concat(
            [
                tf.sin(2 * np.pi * inputs_expanded * self.freqs),
                tf.cos(2 * np.pi * inputs_expanded * self.freqs),
            ],
            axis=-1,
        )

        return periodic_features

    def get_config(self):
        """Base-layer config plus ``num_frequencies`` (``freqs`` itself is not in the config)."""
        config = super(PeriodicEmbeddings, self).get_config()
        config.update({
            "num_frequencies": self.num_frequencies
        })
        return config

In [None]:
path_model = ''
# The file name follows the embedding settings used to load the test data, so the
# model and the embeddings cannot silently disagree. With EMBEDDING = 64 and
# EMBEDDING_TYPE = 'pca' this is the original "V3_model_1-5_64_pca.keras".
# safe_mode=False permits deserialising arbitrary Python code stored in the
# archive (e.g. Lambda layers): only load model files from a trusted source.
model = keras.saving.load_model(
    f"{path_model}V3_model_1-5_{EMBEDDING}_{EMBEDDING_TYPE}.keras",
    custom_objects={
        "ExpandDimsLayer": ExpandDimsLayer,
        "PeriodicEmbeddings": PeriodicEmbeddings,
        "PiecewiseLinearEncoding": PiecewiseLinearEncoding,
        "f1_score_metric": f1_score_metric,
    },
    safe_mode=False,
)

# LIME

In [None]:
from lime.lime_tabular import LimeTabularExplainer
import numpy as np

feature_names = FEATURE_COLUMNS
class_names = ['non-binding', 'binding']

# LIME operates on NumPy arrays rather than tf tensors.
X_test_features_np, y_test_np = (t.numpy() for t in (X_test_features, y_test_labels))

# The test features act as LIME's reference data (feature statistics and the
# discretisation enabled by discretize_continuous=True).
# Note: the SHAP cells further down rebind ``explainer``.
explainer = LimeTabularExplainer(
    training_data=X_test_features_np,
    feature_names=feature_names,
    class_names=class_names,
    mode='classification',
    discretize_continuous=True,
)

In [None]:
# Mean TCR / epitope embeddings used by lime_predict_fn, keyed by input name. Each
# entry stores (source_array, mean). Holding the source array alive means its id
# cannot be reused, so the identity check below reliably detects when
# X_test_tcr / X_test_epitope are rebound.
_lime_mean_cache = {}


def _cached_sample_mean(name, arr):
    """Return ``np.mean(arr, axis=0, keepdims=True)``, recomputed only when ``arr`` is a new object."""
    entry = _lime_mean_cache.get(name)
    if entry is None or entry[0] is not arr:
        entry = (arr, np.mean(arr, axis=0, keepdims=True))
        _lime_mean_cache[name] = entry
    return entry[1]


def lime_predict_fn(physicochemical_input_np, tcr=None, epitope=None):
    """LIME prediction function over the physicochemical features.

    Only the physicochemical features vary between rows. The TCR and epitope inputs
    are held fixed for every row. By default they are the mean TCR / epitope
    embedding of the test set, so explanations describe the model's response at that
    average pair, not at the explained sample's own sequences. To explain at a
    specific pair, pass ``tcr`` / ``epitope`` (one sample's embedding, without a
    batch axis), e.g.
    ``lambda x: lime_predict_fn(x, X_test_tcr[i], X_test_epitope[i])``.

    Args:
        physicochemical_input_np: 2-D array of feature rows produced by LIME.
        tcr, epitope: optional fixed embeddings; default to the test-set means.

    Returns:
        Array of shape (n_rows, 2) with columns ``1 - p`` and ``p``, where ``p``
        is the model output (presumably P(binding)).
    """
    n_rows = len(physicochemical_input_np)
    # LIME calls this once per explanation; the test-set means are cached rather
    # than recomputed over the whole test set on every call.
    if tcr is None:
        tcr_row = _cached_sample_mean("tcr", X_test_tcr)
    else:
        tcr_row = np.asarray(tcr)[None, ...]
    if epitope is None:
        epitope_row = _cached_sample_mean("epitope", X_test_epitope)
    else:
        epitope_row = np.asarray(epitope)[None, ...]

    preds = model.predict({
        "TCR_Input": np.repeat(tcr_row, n_rows, axis=0),
        "Epitope_Input": np.repeat(epitope_row, n_rows, axis=0),
        "Physicochemical_Features": tf.convert_to_tensor(physicochemical_input_np, dtype=tf.float32)
    }, verbose=0)

    # Force a (n_rows, 1) column so hstack yields exactly two columns even if the
    # model returns a flat (n_rows,) vector.
    preds = np.asarray(preds).reshape(-1, 1)
    return np.hstack([1 - preds, preds])

## Check one sample

In [None]:
sample_idx = 44
row = X_test_features_np[sample_idx]

# Top-10 LIME conditions for one test sample's physicochemical features.
exp = explainer.explain_instance(data_row=row, predict_fn=lime_predict_fn, num_features=10)

exp.show_in_notebook(show_table=True, show_all=False)

## Check multiple samples

In [None]:
import collections
import pandas as pd
import matplotlib.pyplot as plt

def aggregate_lime_explanations(
    X_features_np,
    num_samples,
    num_features=10,
    random_seed=42,
    lime_explainer=None,
    predict_fn=None,
):
    """Average LIME weights per feature condition over randomly chosen rows.

    Args:
        X_features_np: 2-D array of physicochemical features, one row per sample.
        num_samples: number of rows to explain; clamped to ``len(X_features_np)``
            so a value larger than the dataset does not raise.
        num_features: ``num_features`` passed to every ``explain_instance`` call.
        random_seed: seed for NumPy's global RNG.
        lime_explainer: LIME explainer to use. Defaults to the global
            ``explainer``, looked up at call time.
        predict_fn: prediction function handed to LIME. Defaults to
            ``lime_predict_fn``.

    Returns:
        DataFrame with columns "Feature (Condition)" and "Mean LIME Weight",
        sorted by weight in descending order. Each condition string from
        ``exp.as_list()`` is averaged only over the explanations in which it
        appeared.
    """
    if lime_explainer is None:
        # Resolved lazily because the SHAP cells later rebind the global ``explainer``.
        lime_explainer = explainer
    if predict_fn is None:
        predict_fn = lime_predict_fn

    # The global seed is deliberate. Besides the row choice, it presumably also
    # fixes LIME's perturbation sampling when the explainer was built without
    # random_state.
    np.random.seed(random_seed)
    n_pick = min(num_samples, len(X_features_np))
    selected_indices = np.random.choice(len(X_features_np), size=n_pick, replace=False)

    weight_sums = collections.defaultdict(float)
    counts = collections.defaultdict(int)

    for idx in selected_indices:
        exp = lime_explainer.explain_instance(
            data_row=X_features_np[idx],
            predict_fn=predict_fn,
            num_features=num_features
        )
        for condition, weight in exp.as_list():
            weight_sums[condition] += weight
            counts[condition] += 1

    rows = [(condition, weight_sums[condition] / counts[condition]) for condition in weight_sums]

    df_lime = pd.DataFrame(
        rows,
        columns=["Feature (Condition)", "Mean LIME Weight"]
    ).sort_values(by="Mean LIME Weight", ascending=False)

    return df_lime


# One LIME explanation per sampled test row.
df_lime = aggregate_lime_explanations(X_features_np=X_test_features_np, num_samples=NUM_SAMPLES)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.barh(df_lime["Feature (Condition)"], df_lime["Mean LIME Weight"])
ax.invert_yaxis()  # df_lime is sorted descending, so the largest weight goes on top
ax.set_xlabel("Mean LIME Weight")
fig.tight_layout()
fig.savefig(f'{path}aggregated_lime_feature_importance_{NUM_SAMPLES}.png', dpi=96, bbox_inches='tight')
plt.show()

### Check binding vs. non-binding samples

In [None]:
# Define y_val_np here: it was read before any earlier cell had defined it, so a
# fresh top-to-bottom run raised NameError.
y_val_np = y_test_labels.numpy() if tf.is_tensor(y_test_labels) else np.asarray(y_test_labels)

# First test sample of each true class.
pos_idx = np.where(y_val_np == 1)[0][0]
neg_idx = np.where(y_val_np == 0)[0][0]

data_pos = X_test_features[pos_idx].numpy() if tf.is_tensor(X_test_features[pos_idx]) else X_test_features[pos_idx]
data_neg = X_test_features[neg_idx].numpy() if tf.is_tensor(X_test_features[neg_idx]) else X_test_features[neg_idx]

exp_pos = explainer.explain_instance(
    data_row=data_pos,
    predict_fn=lime_predict_fn,
    num_features=10
)

exp_neg = explainer.explain_instance(
    data_row=data_neg,
    predict_fn=lime_predict_fn,
    num_features=10
)

In [None]:
import matplotlib.pyplot as plt

def _to_numpy(value):
    """Return ``value.numpy()`` for tf tensors and ``value`` unchanged otherwise."""
    return value.numpy() if tf.is_tensor(value) else value


y_val_np = _to_numpy(y_test_labels)

# First test sample of each true class.
pos_idx = np.where(y_val_np == 1)[0][0]
neg_idx = np.where(y_val_np == 0)[0][0]

data_pos = _to_numpy(X_test_features[pos_idx])
data_neg = _to_numpy(X_test_features[neg_idx])

exp_pos, exp_neg = (
    explainer.explain_instance(data_row=row, predict_fn=lime_predict_fn, num_features=10)
    for row in (data_pos, data_neg)
)

exp_pos.show_in_notebook(show_table=True)
exp_neg.show_in_notebook(show_table=True)

# Both figures plot the weights for label 1 ('binding').
fig1 = exp_pos.as_pyplot_figure(label=1)
fig1.suptitle("Binding Sample", fontsize=14)

fig2 = exp_neg.as_pyplot_figure(label=1)
fig2.suptitle("Non-binding Sample", fontsize=14)

plt.show()

### Gradient × Input (Saliency Map-like) for TCR & Epitope Embeddings
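Compute the gradient of the model output with respect to the TCR and epitope input embeddings and multiply it by the input (gradient × input). Summing the absolute values over the embedding dimension gives one importance score per sequence position, showing which positions influenced the prediction most.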

In [None]:
import tensorflow as tf
import numpy as np

def compute_input_gradients(model, tcr_input, epitope_input, feature_input):
    """Gradient × input saliency for a single sample.

    Adds a batch axis to each input and differentiates the model output with
    respect to the TCR and epitope inputs. Returns ``|grad * input|`` summed over
    the last axis.

    Returns:
        (saliency_tcr, saliency_epitope): one score per position (assuming each
        input is [positions, embedding_dim] — TODO confirm).
    """
    tcr_t, epi_t, feat_t = [
        tf.convert_to_tensor(x[None, ...]) for x in (tcr_input, epitope_input, feature_input)
    ]

    with tf.GradientTape() as tape:
        tape.watch([tcr_t, epi_t])
        prediction = model({
            "TCR_Input": tcr_t,
            "Epitope_Input": epi_t,
            "Physicochemical_Features": feat_t
        })

    grads = tape.gradient(prediction, [tcr_t, epi_t])

    # Drop the batch axis again with [0].
    saliency_tcr, saliency_epitope = (
        tf.reduce_sum(tf.abs(g * x), axis=-1).numpy()[0]
        for g, x in zip(grads, (tcr_t, epi_t))
    )
    return saliency_tcr, saliency_epitope


In [None]:
import matplotlib.pyplot as plt

instance_idx = 0

sal_tcr, sal_epi = compute_input_gradients(
    model,
    tcr_input=X_test_tcr[instance_idx],
    epitope_input=X_test_epitope[instance_idx],
    feature_input=X_test_features[instance_idx]
)

# Size the x-axis from the saliency vectors themselves. The hard-coded
# range(26) / range(24) made plt.bar fail whenever the padded sequence lengths
# differ from those values.
plt.figure(figsize=(12, 3))
plt.bar(range(len(sal_tcr)), sal_tcr)
plt.title("TCR Embedding Saliency (Gradient × Input)")
plt.xlabel("TCR Position")
plt.ylabel("Importance")
plt.show()

plt.figure(figsize=(12, 3))
plt.bar(range(len(sal_epi)), sal_epi)
plt.title("Epitope Embedding Saliency (Gradient × Input)")
plt.xlabel("Epitope Position")
plt.ylabel("Importance")
plt.show()

### Saliency Averaged over Multiple Samples

In [None]:
def compute_batch_saliency(model, tcr_batch, epitope_batch, feature_batch):
    """Gradient × input saliency for a batch of samples.

    Returns:
        (saliency_tcr, saliency_epi): NumPy arrays of ``|grad * input|`` with the
        last (embedding) axis summed away, i.e. one row of per-position scores
        per sample.
    """
    tcr_t, epi_t, feat_t = (
        tf.convert_to_tensor(arr, dtype=tf.float32)
        for arr in (tcr_batch, epitope_batch, feature_batch)
    )
    model_inputs = {
        "TCR_Input": tcr_t,
        "Epitope_Input": epi_t,
        "Physicochemical_Features": feat_t
    }

    with tf.GradientTape() as tape:
        tape.watch([tcr_t, epi_t])
        outputs = model(model_inputs)  # shape: (batch_size, 1)

    grad_tcr, grad_epi = tape.gradient(outputs, [tcr_t, epi_t])

    def _per_position(grad, x):
        return tf.reduce_sum(tf.abs(grad * x), axis=-1).numpy()

    return _per_position(grad_tcr, tcr_t), _per_position(grad_epi, epi_t)

def compute_saliency_over_dataset(model, X_tcr, X_epi, X_feat, batch_size=128):
    """Mean gradient × input saliency per position over a whole dataset.

    Saliency is computed batch by batch with ``compute_batch_saliency``, then
    averaged over all samples.

    Returns:
        (mean_saliency_tcr, mean_saliency_epi). For an empty dataset these are
        zero vectors of length ``X_tcr.shape[1]`` / ``X_epi.shape[1]``, consistent
        with compute_saliency_by_true_label, which also returns zeros for an
        empty class.
    """
    num_samples = X_tcr.shape[0]
    if num_samples == 0:
        # np.vstack([]) would raise. This happens e.g. when split_by_prediction
        # finds no sample of one predicted class.
        return np.zeros(X_tcr.shape[1]), np.zeros(X_epi.shape[1])

    tcr_parts = []
    epi_parts = []
    for start in range(0, num_samples, batch_size):
        stop = start + batch_size
        sal_tcr, sal_epi = compute_batch_saliency(
            model, X_tcr[start:stop], X_epi[start:stop], X_feat[start:stop]
        )
        tcr_parts.append(sal_tcr)
        epi_parts.append(sal_epi)

    mean_saliency_tcr = np.mean(np.vstack(tcr_parts), axis=0)
    mean_saliency_epi = np.mean(np.vstack(epi_parts), axis=0)
    return mean_saliency_tcr, mean_saliency_epi


n_samples = NUM_SAMPLES
X_tcr_batch = X_test_tcr[:n_samples]
X_epi_batch = X_test_epitope[:n_samples]
X_feat_batch = X_test_features[:n_samples]

mean_sal_tcr, mean_sal_epi = compute_saliency_over_dataset(
    model, X_tcr_batch, X_epi_batch, X_feat_batch, batch_size=128
)


def _plot_mean_saliency(mean_saliency, xlabel, out_file):
    """Bar-plot one mean-saliency profile, save it to ``out_file`` and show it."""
    plt.figure(figsize=(12, 3))
    plt.bar(range(len(mean_saliency)), mean_saliency)
    plt.xlabel(xlabel)
    plt.ylabel("Mean Importance")
    plt.tight_layout()
    plt.savefig(out_file, dpi=96, bbox_inches='tight')
    plt.show()


_plot_mean_saliency(mean_sal_tcr, "TCR Position", f'{path}average_tcr_embedding_saliency_{NUM_SAMPLES}.png')
_plot_mean_saliency(mean_sal_epi, "Epitope Position", f'{path}average_epitope_embedding_saliency_{NUM_SAMPLES}.png')

### Separate saliency by predicted label (binding vs. non-binding)

In [None]:
def split_by_prediction(model, X_tcr, X_epitope, X_features, threshold=0.5, batch_size=128):
    """Partition samples by the model's predicted class.

    Args:
        model: model whose ``predict`` takes a dict with keys "TCR_Input",
            "Epitope_Input" and "Physicochemical_Features".
        X_tcr, X_epitope, X_features: row-aligned inputs (NumPy arrays or tf tensors).
        threshold: predictions ``>= threshold`` count as binding.
        batch_size: batch size used by ``model.predict``.

    Returns:
        Six NumPy arrays: TCR, epitope and feature rows predicted as binding,
        followed by the same three for rows predicted as non-binding.
    """
    # Convert once and index directly instead of re-stacking per-batch copies.
    X_tcr_all = np.asarray(X_tcr)
    X_epi_all = np.asarray(X_epitope)
    X_feat_all = np.asarray(X_features)

    if X_tcr_all.shape[0] == 0:
        # Nothing to predict; concatenating zero prediction batches would raise.
        return (
            X_tcr_all, X_epi_all, X_feat_all,
            X_tcr_all, X_epi_all, X_feat_all
        )

    # One predict call; Keras does the batching.
    preds = np.asarray(model.predict({
        "TCR_Input": X_tcr_all,
        "Epitope_Input": X_epi_all,
        "Physicochemical_Features": X_feat_all
    }, batch_size=batch_size, verbose=0)).reshape(-1)

    is_binding = preds >= threshold
    is_nonbinding = ~is_binding

    return (
        X_tcr_all[is_binding], X_epi_all[is_binding], X_feat_all[is_binding],
        X_tcr_all[is_nonbinding], X_epi_all[is_nonbinding], X_feat_all[is_nonbinding]
    )

In [None]:
n_samples = NUM_SAMPLES

X_batch_tcr, X_batch_epi, X_batch_feat = (
    arr[:n_samples] for arr in (X_test_tcr, X_test_epitope, X_test_features)
)

# Partition by the model's own prediction (default threshold), then average the
# saliency within each group.
(X_tcr_bind, X_epi_bind, X_feat_bind,
 X_tcr_non, X_epi_non, X_feat_non) = split_by_prediction(
    model, X_batch_tcr, X_batch_epi, X_batch_feat, batch_size=128
)

mean_sal_tcr_bind, mean_sal_epi_bind = compute_saliency_over_dataset(
    model, X_tcr_bind, X_epi_bind, X_feat_bind, batch_size=128
)
mean_sal_tcr_non, mean_sal_epi_non = compute_saliency_over_dataset(
    model, X_tcr_non, X_epi_non, X_feat_non, batch_size=128
)

def plot_saliency_comparison(sal_class0, sal_class1, title, labels=None,
                             class_labels=('Non-binding', 'Binding')):
    """Grouped bar chart of two per-position saliency profiles; saved and shown.

    Args:
        sal_class0, sal_class1: equal-length saliency vectors, drawn as the left and
            right bar of each position.
        title: used only to build the output file name
            ``f'{path}{title}_{NUM_SAMPLES}.png'``; no title is drawn on the plot.
        labels: x-tick labels. Defaults to the position indices 0..len-1, so
            callers need not hard-code the sequence length.
        class_labels: legend entries for (sal_class0, sal_class1).
    """
    x = np.arange(len(sal_class0))
    if labels is None:
        labels = x
    width = 0.35

    plt.figure(figsize=(12, 4))
    plt.bar(x - width/2, sal_class0, width, label=class_labels[0], color='steelblue')
    plt.bar(x + width/2, sal_class1, width, label=class_labels[1], color='darkorange')
    plt.xlabel("Position")
    plt.ylabel("Mean Saliency")
    plt.xticks(ticks=x, labels=labels)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'{path}{title}_{NUM_SAMPLES}.png', dpi=96, bbox_inches='tight')
    plt.show()

# Tick labels follow the saliency length instead of a hard-coded 26 / 24 positions.
plot_saliency_comparison(mean_sal_tcr_non, mean_sal_tcr_bind, "TCR Saliency by Predicted Class",
                         labels=range(len(mean_sal_tcr_non)))

plot_saliency_comparison(mean_sal_epi_non, mean_sal_epi_bind, "Epitope Saliency by Predicted Class",
                         labels=range(len(mean_sal_epi_non)))

### Separate saliency by true label (instead of predicted label)

In [None]:
def compute_saliency_by_true_label(model, X_tcr, X_epi, X_feat, y_true, batch_size=128):
    """Mean gradient × input saliency per position, split by the ground-truth label.

    Samples labelled 1 and 0 are accumulated separately; other label values are
    skipped. Only the first ``len(y_true)`` samples are visited.

    Returns:
        (tcr_bind, epi_bind, tcr_non, epi_non) mean saliency vectors. A class with
        no samples yields zeros of length ``X_tcr.shape[1]`` / ``X_epi.shape[1]``.
    """
    labels = np.array(y_true).reshape(-1).astype(int)
    # class value -> ([TCR saliency chunks], [epitope saliency chunks])
    collected = {1: ([], []), 0: ([], [])}

    for start in range(0, len(labels), batch_size):
        stop = start + batch_size
        tcr_chunk = np.array(X_tcr[start:stop])
        epi_chunk = np.array(X_epi[start:stop])
        feat_chunk = np.array(X_feat[start:stop])
        label_chunk = labels[start:stop]

        # Binding first, then non-binding, for each batch.
        for cls in (1, 0):
            mask = label_chunk == cls
            if not np.any(mask):
                continue
            sal_tcr, sal_epi = compute_batch_saliency(
                model, tcr_chunk[mask], epi_chunk[mask], feat_chunk[mask]
            )
            collected[cls][0].append(sal_tcr)
            collected[cls][1].append(sal_epi)

    def _mean_or_zeros(parts, length):
        return np.mean(np.vstack(parts), axis=0) if parts else np.zeros(length)

    tcr_len = X_tcr.shape[1]
    epi_len = X_epi.shape[1]

    return (
        _mean_or_zeros(collected[1][0], tcr_len),
        _mean_or_zeros(collected[1][1], epi_len),
        _mean_or_zeros(collected[0][0], tcr_len),
        _mean_or_zeros(collected[0][1], epi_len),
    )

In [None]:
n_samples = NUM_SAMPLES
X_tcr_batch, X_epi_batch, X_feat_batch, y_true_batch = (
    arr[:n_samples]
    for arr in (X_test_tcr, X_test_epitope, X_test_features, y_test_labels)
)

(mean_sal_tcr_bind, mean_sal_epi_bind,
 mean_sal_tcr_non, mean_sal_epi_non) = compute_saliency_by_true_label(
    model, X_tcr_batch, X_epi_batch, X_feat_batch, y_true_batch, batch_size=128
)

def plot_saliency_comparison(sal_class0, sal_class1, title, labels=None,
                             class_labels=('True Non-binding', 'True Binding')):
    """Grouped bar chart of two per-position saliency profiles; saved and shown.

    Args:
        sal_class0, sal_class1: equal-length saliency vectors, drawn as the left and
            right bar of each position.
        title: used only to build the output file name
            ``f'{path}{title}_{NUM_SAMPLES}.png'``; no title is drawn on the plot.
        labels: x-tick labels. Defaults to the position indices 0..len-1, so
            callers need not hard-code the sequence length.
        class_labels: legend entries for (sal_class0, sal_class1).
    """
    x = np.arange(len(sal_class0))
    if labels is None:
        labels = x
    width = 0.35

    plt.figure(figsize=(12, 4))
    plt.bar(x - width/2, sal_class0, width, label=class_labels[0], color='steelblue')
    plt.bar(x + width/2, sal_class1, width, label=class_labels[1], color='darkorange')
    plt.xlabel("Position")
    plt.ylabel("Mean Saliency")
    plt.xticks(ticks=x, labels=labels)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'{path}{title}_{NUM_SAMPLES}.png', dpi=96, bbox_inches='tight')
    plt.show()

# Tick labels follow the saliency length instead of a hard-coded 26 / 24 positions.
plot_saliency_comparison(mean_sal_tcr_non, mean_sal_tcr_bind, "TCR Saliency by True Label",
                         labels=range(len(mean_sal_tcr_non)))
plot_saliency_comparison(mean_sal_epi_non, mean_sal_epi_bind, "Epitope Saliency by True Label",
                         labels=range(len(mean_sal_epi_non)))

#### SHAP on TCR & Epitope Embeddings

In [None]:
import shap
import numpy as np
import tensorflow as tf

# Samples to explain.
X_tcr_test, X_epi_test, X_feat_test = (
    arr[:NUM_SAMPLES] for arr in (X_test_tcr, X_test_epitope, X_test_features)
)

# Background (reference) set: up to NUM_SAMPLES // 2 rows taken only from the
# true-binding subset (X_positive_* come from df_positive), so attributions are
# measured against a binder-only baseline.
n_background = NUM_SAMPLES // 2
X_tcr_bg, X_epi_bg, X_feat_bg = (
    arr[:n_background] for arr in (X_positive_tcr, X_positive_epitope, X_positive_features)
)

In [None]:
# Note: this rebinds ``explainer``, which previously held the LIME explainer.
explainer = shap.GradientExplainer(
    model,
    data=[tf.convert_to_tensor(arr) for arr in (X_tcr_bg, X_epi_bg, X_feat_bg)],
)

In [None]:
# Hand SHAP NumPy arrays; tf tensor slices (e.g. from X_test_features) are converted.
X_tcr_test, X_epi_test, X_feat_test = (
    arr.numpy() if isinstance(arr, tf.Tensor) else arr
    for arr in (X_tcr_test, X_epi_test, X_feat_test)
)

In [None]:
# Explain every selected test sample in a single call.
shap_inputs = [X_tcr_test, X_epi_test, X_feat_test]
shap_values = explainer.shap_values(shap_inputs)

In [None]:
import matplotlib.pyplot as plt

def mean_abs_shap_by_position(shap_input):
    """Mean |SHAP value| per sequence position.

    Expects (samples, positions, ...) values, optionally with a trailing length-1
    output axis (presumably appended by some SHAP versions for single-output
    models). Averages |value| over every axis except the position axis (axis 1).
    """
    values = np.asarray(shap_input)
    # Drop only a trailing single-output axis. np.squeeze would also drop the
    # sample axis for a single sample, or a size-1 embedding axis.
    if values.ndim == 4 and values.shape[-1] == 1:
        values = values[..., 0]
    if values.ndim < 2:
        raise ValueError(f"expected at least (samples, positions), got shape {values.shape}")
    other_axes = tuple(axis for axis in range(values.ndim) if axis != 1)
    return np.abs(values).mean(axis=other_axes)


def plot_shap_position_importance(shap_input, title):
    """Bar-plot mean |SHAP| per position, then save and show the figure.

    ``title`` is used only for the output file ``f'{path}{title}_{NUM_SAMPLES}.png'``.
    """
    mean_abs_importance = mean_abs_shap_by_position(shap_input)

    plt.figure(figsize=(12, 3))
    plt.bar(range(len(mean_abs_importance)), mean_abs_importance)
    plt.xlabel("Position")
    plt.ylabel("Mean |SHAP value|")
    plt.tight_layout()
    plt.savefig(f'{path}{title}_{NUM_SAMPLES}.png', dpi=96, bbox_inches='tight')
    plt.show()

for shap_part, figure_name in (
    (shap_values[0], "TCR Embedding Importance (SHAP)"),
    (shap_values[1], "Epitope Embedding Importance (SHAP)"),
):
    plot_shap_position_importance(shap_part, figure_name)

In [None]:
NUM_SAMPLES = 50000

X_tcr_test, X_epi_test, X_feat_test = (
    arr[:NUM_SAMPLES] for arr in (X_test_tcr, X_test_epitope, X_test_features)
)

# Binder-only background of up to NUM_SAMPLES // 2 rows.
n_background = NUM_SAMPLES // 2
X_tcr_bg, X_epi_bg, X_feat_bg = (
    arr[:n_background] for arr in (X_positive_tcr, X_positive_epitope, X_positive_features)
)

explainer = shap.GradientExplainer(
    model,
    data=[tf.convert_to_tensor(arr) for arr in (X_tcr_bg, X_epi_bg, X_feat_bg)],
)

def compute_shap_values_in_batches(explainer, X_tcr, X_epi, X_feat, batch_size=128):
    """Run ``explainer.shap_values`` over the inputs in slices and stitch the results.

    Returns:
        (shap_tcr, shap_epi, shap_feat): the per-input SHAP arrays of all slices,
        concatenated along axis 0.
    """
    per_input = ([], [], [])

    for start in range(0, len(X_tcr), batch_size):
        window = slice(start, start + batch_size)
        batch_values = explainer.shap_values([X_tcr[window], X_epi[window], X_feat[window]])
        for k in range(3):
            per_input[k].append(batch_values[k])

    return tuple(np.concatenate(parts, axis=0) for parts in per_input)

# Hand SHAP NumPy arrays; tf tensor slices are converted.
X_tcr_test, X_epi_test, X_feat_test = (
    arr.numpy() if isinstance(arr, tf.Tensor) else arr
    for arr in (X_tcr_test, X_epi_test, X_feat_test)
)

shap_tcr, shap_epi, shap_feat = compute_shap_values_in_batches(
    explainer, X_tcr_test, X_epi_test, X_feat_test, batch_size=64
)

In [None]:
def mean_abs_shap_by_position(shap_input):
    """Mean |SHAP value| per sequence position.

    Expects (samples, positions, ...) values, optionally with a trailing length-1
    output axis (presumably appended by some SHAP versions for single-output
    models). Averages |value| over every axis except the position axis (axis 1).
    """
    values = np.asarray(shap_input)
    # Drop only a trailing single-output axis. np.squeeze would also drop the
    # sample axis for a single sample, or a size-1 embedding axis.
    if values.ndim == 4 and values.shape[-1] == 1:
        values = values[..., 0]
    if values.ndim < 2:
        raise ValueError(f"expected at least (samples, positions), got shape {values.shape}")
    other_axes = tuple(axis for axis in range(values.ndim) if axis != 1)
    return np.abs(values).mean(axis=other_axes)


def plot_shap_position_importance(shap_input, title):
    """Bar-plot mean |SHAP| per position, then save and show the figure.

    ``title`` is used only for the output file ``f'{path}{title}_{NUM_SAMPLES}.png'``.
    """
    mean_abs_importance = mean_abs_shap_by_position(shap_input)

    plt.figure(figsize=(12, 3))
    plt.bar(range(len(mean_abs_importance)), mean_abs_importance)
    plt.xlabel("Position")
    plt.ylabel("Mean |SHAP value|")
    plt.tight_layout()
    plt.savefig(f'{path}{title}_{NUM_SAMPLES}.png', dpi=96, bbox_inches='tight')
    plt.show()
    
for shap_part, figure_name in (
    (shap_tcr, "TCR Embedding Importance (SHAP)"),
    (shap_epi, "Epitope Embedding Importance (SHAP)"),
):
    plot_shap_position_importance(shap_part, figure_name)