# Implementing LIME using JAX

Notebook for experimenting with implementing LIME using JAX. To keep things simple, we will try to implement `LimeTabularExplainer` for the Wisconsin breast cancer dataset.

In [53]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
from jax.ops import index, index_add, index_update

import scipy
import numpy as onp

import sklearn
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.linear_model import Ridge, Lasso

from lime.lime_tabular import LimeTabularExplainer

key = random.PRNGKey(0)

## Prepare the dataset and train a model

In [54]:
data = datasets.load_breast_cancer()
X, y = data['data'], data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
clf.score(X_test, y_test)

0.9736842105263158

# Attempt to obtain an explanation for a given data sample

In [55]:
data_row = X_test[0].reshape((1, -1))
print(data_row.shape)
print(clf.predict_proba(data_row))

(1, 30)
[[0.01 0.99]]


In [56]:
data['feature_names']

array(['mean radius', 'mean texture', 'mean perimeter', 'mean area',
       'mean smoothness', 'mean compactness', 'mean concavity',
       'mean concave points', 'mean symmetry', 'mean fractal dimension',
       'radius error', 'texture error', 'perimeter error', 'area error',
       'smoothness error', 'compactness error', 'concavity error',
       'concave points error', 'symmetry error',
       'fractal dimension error', 'worst radius', 'worst texture',
       'worst perimeter', 'worst area', 'worst smoothness',
       'worst compactness', 'worst concavity', 'worst concave points',
       'worst symmetry', 'worst fractal dimension'], dtype='<U23')

For the breast cancer dataset, all features are numerical

In [57]:
a = np.arange(0, 11).astype(onp.float32)
bins = np.percentile(a, [25, 50, 75])
bins
print(onp.digitize(a, bins))

[0 0 0 1 1 2 2 2 3 3 3]


In [58]:
def discretize(arr, qs):
    bins = np.percentile(a, qs)
    return onp.digitize(a, bins)

In [59]:
discretize(a, [25, 50, 75])

array([0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3])

In [60]:
bins

DeviceArray([2.5, 5. , 7.5], dtype=float32)

## Discretize the dataset

In [61]:
%timeit all_bins = np.percentile(X_test, [25, 50, 75], axis=0).T
all_bins = np.percentile(X_test, [25, 50, 75], axis=0).T

2.04 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [62]:
print(all_bins.shape)
print(X_test.shape)

(30, 3)
(114, 30)


In [63]:
%timeit discretized = [onp.digitize(a, bins) for (a, bins) in zip(X_test.T, all_bins)]
discretized = [onp.digitize(a, bins) for (a, bins) in zip(X_test.T, all_bins)]

279 µs ± 52.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [64]:
np.array(discretized).T.shape

(114, 30)

In [65]:
def discretize(X, qs=[25, 50, 75], all_bins=None):
    if all_bins is None:
        all_bins = onp.percentile(X, qs, axis=0).T
    return (np.array([onp.digitize(a, bins) for (a, bins) in zip(X.T, all_bins)]).T, all_bins)

def discretize_jax(X, qs=[25, 50, 75], all_bins=None):
    if all_bins is None:
        all_bins = np.percentile(X, qs, axis=0).T
    return (np.array([onp.digitize(a, bins) for (a, bins) in zip(X.T, all_bins)]).T, all_bins)

X_test_disc, all_bins = discretize(X_test)
X_test_disc

DeviceArray([[1, 3, 1, ..., 2, 1, 2],
             [2, 0, 1, ..., 0, 0, 0],
             [0, 0, 0, ..., 1, 2, 3],
             ...,
             [0, 0, 1, ..., 2, 1, 3],
             [2, 0, 2, ..., 1, 1, 0],
             [3, 1, 3, ..., 3, 2, 3]], dtype=int32)

In [66]:
oe = OneHotEncoder()
X_test_onehot = oe.fit_transform(X_test_disc)

In [67]:
X_test_onehot.shape

(114, 120)

In [68]:
X_test_onehot[0].toarray()

array([[0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0.,
        0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
        0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0.,
        0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
        0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 1., 0., 0., 0., 0., 1., 0.]])

In [69]:
%timeit X_synthetic = np.tile(X_test_onehot[0].toarray().reshape((1, -1)), (1000, 1))
X_synthetic = np.tile(X_test_onehot[0].toarray().reshape((1, -1)), (1000, 1))

58.9 ms ± 8.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [70]:
X_synthetic.shape

(1000, 120)

# Create synthetic dataset

* Based on the implementation of LIME, it looks like we first standardscale the data, add noise from N(0, 1), then rescale back to the original domain
* `with_mean=False` means that we are sampling around the given data instance
* This requires getting the mean and std. dev. of the data

In [71]:
sc = StandardScaler(with_mean=False)
sc.fit(X_train)

StandardScaler(copy=True, with_mean=False, with_std=True)

In [72]:
sc.mean_

array([1.41373011e+01, 1.92501319e+01, 9.20553407e+01, 6.55707033e+02,
       9.66194286e-02, 1.05317868e-01, 8.86511730e-02, 4.88025670e-02,
       1.81858681e-01, 6.29431429e-02, 4.08352527e-01, 1.22789495e+00,
       2.87935604e+00, 4.03608396e+01, 7.07142857e-03, 2.55471033e-02,
       3.15398914e-02, 1.16425363e-02, 2.07472923e-02, 3.83184901e-03,
       1.63006923e+01, 2.57372527e+01, 1.07502308e+02, 8.83013407e+02,
       1.33236022e-01, 2.58458879e-01, 2.73775407e-01, 1.14872464e-01,
       2.92811209e-01, 8.44838462e-02])

In [73]:
len(sc.mean_)

30

In [21]:
sc.scale_

array([3.58120907e+00, 4.43143076e+00, 2.47414754e+01, 3.57395126e+02,
       1.37498557e-02, 5.37530404e-02, 8.10522800e-02, 3.95191708e-02,
       2.78833034e-02, 6.96823466e-03, 2.64846748e-01, 5.60209985e-01,
       1.91800694e+00, 4.32565169e+01, 3.04915865e-03, 1.87691670e-02,
       3.19594908e-02, 6.38187670e-03, 8.21410123e-03, 2.77416134e-03,
       4.98288759e+00, 6.25536294e+00, 3.46495143e+01, 5.92290265e+02,
       2.24146556e-02, 1.63066710e-01, 2.15216203e-01, 6.74248592e-02,
       6.16088078e-02, 1.83893991e-02])

In [22]:
X_test[:1]

array([[1.550e+01, 2.108e+01, 1.029e+02, 8.031e+02, 1.120e-01, 1.571e-01,
        1.522e-01, 8.481e-02, 2.085e-01, 6.864e-02, 1.370e+00, 1.213e+00,
        9.424e+00, 1.765e+02, 8.198e-03, 3.889e-02, 4.493e-02, 2.139e-02,
        2.018e-02, 5.815e-03, 2.317e+01, 2.765e+01, 1.571e+02, 1.748e+03,
        1.517e-01, 4.002e-01, 4.211e-01, 2.134e-01, 3.003e-01, 1.048e-01]])

In [23]:
data_row = X_test[:1]
data_row = sc.transform(data_row)
data_row

array([[4.3281472 , 4.75692866, 4.15900824, 2.24709276, 8.14554004,
        2.92262538, 1.87780035, 2.14604705, 7.4775932 , 9.85041454,
        5.1728028 , 2.16525951, 4.91343374, 4.08031004, 2.68861052,
        2.07201524, 1.40584217, 3.35167867, 2.45675083, 2.09612899,
        4.64991425, 4.42020715, 4.53397409, 2.95125567, 6.76789341,
        2.4542103 , 1.95663706, 3.16500475, 4.87430305, 5.69893554]])

In [24]:
X_synthetic = np.tile(data_row, (1000, 1))
X_synthetic.shape

(1000, 30)

In [25]:
X_synthetic = X_synthetic + random.normal(key, (1000, 30))
X_synthetic

DeviceArray([[ 3.03124   ,  4.307985  ,  4.8191566 , ...,  1.8983663 ,
               5.2370872 ,  7.0021133 ],
             [ 4.3552337 ,  4.9940343 ,  3.9476604 , ..., -0.27066302,
               4.676184  ,  6.327023  ],
             [ 4.5586944 ,  4.7614737 ,  4.5926437 , ...,  2.5451753 ,
               2.3752594 ,  5.524213  ],
             ...,
             [ 3.5816717 ,  5.6939244 ,  3.6891572 , ...,  3.3720498 ,
               3.9762788 ,  6.462904  ],
             [ 3.3687832 ,  5.215034  ,  6.302809  , ...,  4.5297    ,
               5.701255  ,  4.1892476 ],
             [ 3.977677  ,  5.385982  ,  5.650092  , ...,  3.5632966 ,
               5.409166  ,  5.876232  ]], dtype=float32)

In [26]:
# Back to original domain
X_synthetic = index_update(X_synthetic, index[0, :], data_row.ravel())
X_synthetic_orig = sc.inverse_transform(X_synthetic)
X_synthetic_orig

array([[ 1.55000010e+01,  2.10799999e+01,  1.02899994e+02, ...,
         2.13399991e-01,  3.00300002e-01,  1.04800001e-01],
       [ 1.55970020e+01,  2.21307163e+01,  9.76709442e+01, ...,
        -1.82494167e-02,  2.88094133e-01,  1.16350152e-01],
       [ 1.63256378e+01,  2.11001415e+01,  1.13628784e+02, ...,
         1.71608090e-01,  1.46336898e-01,  1.01586953e-01],
       ...,
       [ 1.28267155e+01,  2.52322311e+01,  9.12751923e+01, ...,
         2.27359980e-01,  2.44973794e-01,  1.18848920e-01],
       [ 1.20643167e+01,  2.31100616e+01,  1.55940781e+02, ...,
         3.05414379e-01,  3.51247519e-01,  7.70377442e-02],
       [ 1.42448931e+01,  2.38676071e+01,  1.39791611e+02, ...,
         2.40254775e-01,  3.33252251e-01,  1.08060375e-01]], dtype=float32)

In [27]:
model_pred = clf.predict_proba(X_synthetic_orig)
model_pred

array([[1.  , 0.  ],
       [0.88, 0.12],
       [0.93, 0.07],
       ...,
       [0.94, 0.06],
       [0.88, 0.12],
       [0.89, 0.11]])

In [28]:
print(onp.unique(model_pred[:,0]))
print(onp.unique(model_pred[:,1]))

[0.47 0.52 0.57 0.58 0.59 0.6  0.61 0.62 0.63 0.64 0.65 0.66 0.67 0.68
 0.69 0.7  0.71 0.72 0.73 0.74 0.75 0.76 0.77 0.78 0.79 0.8  0.81 0.82
 0.83 0.84 0.85 0.86 0.87 0.88 0.89 0.9  0.91 0.92 0.93 0.94 0.95 0.96
 0.97 0.98 0.99 1.  ]
[0.   0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.1  0.11 0.12 0.13
 0.14 0.15 0.16 0.17 0.18 0.19 0.2  0.21 0.22 0.23 0.24 0.25 0.26 0.27
 0.28 0.29 0.3  0.31 0.32 0.33 0.34 0.35 0.36 0.37 0.38 0.39 0.4  0.41
 0.42 0.43 0.48 0.53]


In [29]:
X_synthetic_disc, all_bins = discretize(X_synthetic_orig, [25, 50, 75], all_bins)
print(X_synthetic_disc.shape)
X_synthetic_disc

(1000, 30)


DeviceArray([[3, 3, 3, ..., 3, 2, 3],
             [3, 3, 3, ..., 0, 2, 3],
             [3, 3, 3, ..., 3, 0, 3],
             ...,
             [1, 3, 2, ..., 3, 0, 3],
             [1, 3, 3, ..., 3, 3, 1],
             [2, 3, 3, ..., 3, 3, 3]], dtype=int32)

In [30]:
X_synthetic_onehot = oe.transform(X_synthetic_disc)
X_synthetic_onehot.shape

(1000, 120)

# Solve

* Get pairwise distances between original data and synthetic neighborhood
* Weight using kernel function
* Solve with ridge regression

In [31]:
distances = scipy.spatial.distance.cdist(X_synthetic[:1], X_synthetic)
distances = distances.reshape(-1, 1)
distances.shape

(1000, 1)

In [32]:
def kernel_fn(distances, kernel_width=onp.sqrt(X_test.shape[1])):
    return onp.sqrt(onp.exp(-(distances ** 2) / kernel_width ** 2))

def kernel_fn_jax(distances, kernel_width=np.sqrt(X_test.shape[1])):
    return np.sqrt(np.exp(-(distances ** 2) / kernel_width ** 2))

weights = kernel_fn(distances).ravel()
weights.shape

(1000,)

In [33]:
solver = Ridge(alpha=1, fit_intercept=True)
solver.fit(X_synthetic_onehot, model_pred[:,0], sample_weight=weights)

Ridge(alpha=1, copy_X=True, fit_intercept=True, max_iter=None, normalize=False,
      random_state=None, solver='auto', tol=0.001)

In [34]:
solver.score(X_synthetic_onehot, model_pred[:, 0], sample_weight=distances.ravel())

0.8382104246909032

In [35]:
solver.predict(X_synthetic_onehot[0].reshape((1, -1)))

array([1.00518069])

In [36]:
solver.coef_

array([-0.01465833, -0.01010554, -0.01213016, -0.00432731, -0.03096207,
       -0.01130233, -0.00082199,  0.00186505, -0.03257237, -0.01411133,
       -0.01285664,  0.01831899, -0.03420101, -0.01399956, -0.01165435,
        0.01863359, -0.00679038, -0.01228653, -0.01411142, -0.008033  ,
       -0.02562364, -0.0100659 , -0.00448154, -0.00105026, -0.02745177,
       -0.03919753, -0.0011945 ,  0.02662247, -0.03499116, -0.02521414,
       -0.00792308,  0.02690705, -0.00532627, -0.01080183, -0.01521744,
       -0.00987581, -0.00251683, -0.00862203, -0.01535109, -0.01473138,
        0.        ,  0.        , -0.01083942, -0.03038192, -0.02430588,
       -0.00676988, -0.00536736, -0.00477822,  0.        ,  0.        ,
        0.        , -0.04122134, -0.01178978,  0.        , -0.0009223 ,
       -0.02850926, -0.01297373, -0.01119055, -0.0072997 , -0.00975736,
       -0.01234457, -0.01268606, -0.00752261, -0.0086681 , -0.03239014,
       -0.00663131,  0.0008185 , -0.00301839, -0.01079389, -0.00

In [37]:
X_synthetic_onehot[0].toarray()

array([[0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0.,
        0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1.,
        0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 1., 0., 0., 0., 0., 1.]])

In [38]:
importances = solver.coef_[X_synthetic_onehot[0].toarray().ravel() == 1]
importances

array([-0.00432731,  0.00186505,  0.01831899,  0.01863359, -0.008033  ,
       -0.00105026,  0.02662247,  0.02690705, -0.00987581, -0.01473138,
       -0.03038192, -0.00536736, -0.04122134, -0.02850926, -0.00975736,
       -0.0086681 , -0.00301839, -0.01433257, -0.00933138, -0.00649052,
        0.04293765,  0.00023033,  0.07359534,  0.05030532, -0.00457005,
        0.00369254,  0.03259681,  0.06610105, -0.00522634, -0.00807995])

In [39]:
sorted(list(zip(data['feature_names'], importances)), key=lambda x: x[1], reverse=True)[:10]

[('worst perimeter', 0.07359534433430019),
 ('worst concave points', 0.06610105228917687),
 ('worst area', 0.05030532099251261),
 ('worst radius', 0.042937650980115746),
 ('worst concavity', 0.032596814586122694),
 ('mean concave points', 0.026907048416998355),
 ('mean concavity', 0.02662246559710487),
 ('mean area', 0.018633585063711496),
 ('mean perimeter', 0.018318989973178268),
 ('worst compactness', 0.0036925373993616337)]

In [40]:
explainer = LimeTabularExplainer(training_data=X_train, feature_names=data['feature_names'])
explainer

<lime.lime_tabular.LimeTabularExplainer at 0x7f1c221dddd0>

In [41]:
%timeit explainer.explain_instance(data_row=X_test[0], predict_fn=clf.predict_proba, labels=(0,))

exp = explainer.explain_instance(
    data_row=X_test[0],
    predict_fn=clf.predict_proba,
    labels=(0,)
)

1.49 s ± 126 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [42]:
exp.as_list(0)

[('worst concave points > 0.17', 0.13279641497505934),
 ('worst area > 1160.50', 0.1301029295944422),
 ('worst perimeter > 127.90', 0.11233780684855998),
 ('worst radius > 19.42', 0.0913769426073239),
 ('mean concavity > 0.14', 0.05182585077119412),
 ('mean concave points > 0.08', 0.04837277736026629),
 ('area error > 48.70', 0.04710720057726893),
 ('worst concavity > 0.40', 0.04301079412609855),
 ('perimeter error > 3.46', 0.026907726421716367),
 ('radius error > 0.50', 0.020634025278632246)]

# Now make it end-to-end

In [51]:
def explain_instance(training_data, data_instance, clf, qs=[25, 50, 75], num_samples=5000, num_features=10):
    # Get training data statistics
    all_bins = onp.percentile(training_data, qs, axis=0).T
    
    # Scale the data
    sc = StandardScaler(with_mean=False)
    sc.fit(training_data)
    data_scaled = sc.transform(data_instance.reshape((1, -1)))
    
    # Create synthetic neighborhood
    X_synthetic = onp.tile(data_scaled, (num_samples, 1))
    X_synthetic = X_synthetic + onp.random.normal(size=(num_samples, training_data.shape[1]))
    X_synthetic[0] = data_scaled.ravel()
    X_synthetic_orig = sc.inverse_transform(X_synthetic)
    X_synthetic_disc, all_bins = discretize(X_synthetic_orig, qs, all_bins)

    # Get model predictions (i.e. groundtruth)
    model_pred = clf.predict_proba(X_synthetic_orig)

    # Solve
    distances = scipy.spatial.distance.cdist(X_synthetic[:1], X_synthetic)
    distances = distances.reshape(-1, 1)
    weights = kernel_fn(distances, kernel_width=training_data.shape[1]).ravel()
    solver = Ridge(alpha=1, fit_intercept=True)
    oe = OneHotEncoder()
    X_synthetic_onehot = oe.fit_transform(X_synthetic_disc)    
    solver.fit(X_synthetic_onehot, model_pred[:,0], sample_weight=weights)
    
    # Explain
    importances = solver.coef_[X_synthetic_onehot[0].toarray().ravel() == 1]
    explanations = sorted(list(zip(data['feature_names'], importances)), 
                          key=lambda x: x[1], reverse=True)[:num_features]
    return explanations

In [52]:
%timeit explain_instance(X_train, X_test[0], clf)

105 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### JAX version of the explainer

In [45]:
def explain_instance_jax(training_data, data_instance, clf, 
                         qs=[25, 50, 75], num_samples=5000, num_features=10):
    # Get training data statistics
    all_bins = np.percentile(training_data, qs, axis=0).T
    
    # Scale the data
    sc = StandardScaler(with_mean=False)
    sc.fit(training_data)
    data_scaled = sc.transform(data_instance.reshape((1, -1)))
    
    # Create synthetic neighborhood
    X_synthetic = np.tile(data_scaled, (num_samples, 1))
    X_synthetic = X_synthetic + random.normal(key, (num_samples, training_data.shape[1]))
    X_synthetic = index_update(X_synthetic, index[0, :], data_scaled.ravel())
    X_synthetic_orig = sc.inverse_transform(X_synthetic)
    X_synthetic_disc, all_bins = discretize_jax(X_synthetic_orig, qs, all_bins)
    oe = OneHotEncoder()
    X_synthetic_onehot = oe.fit_transform(X_synthetic_disc)
    
    # Get model predictions (i.e. groundtruth)
    model_pred = clf.predict_proba(X_synthetic_orig)

    # Solve
    distances = scipy.spatial.distance.cdist(X_synthetic[:1], X_synthetic)
    distances = distances.reshape(-1, 1)
    weights = kernel_fn_jax(distances, kernel_width=training_data.shape[1]).ravel()
    solver = Ridge(alpha=1, fit_intercept=True)
    solver.fit(X_synthetic_onehot, model_pred[:,0], sample_weight=weights)
    
    # Explain
    importances = solver.coef_[X_synthetic_onehot[0].toarray().ravel() == 1]
    explanations = sorted(list(zip(data['feature_names'], importances)), 
                          key=lambda x: x[1], reverse=True)[:num_features]
    return explanations

In [46]:
%timeit explain_instance_jax(X_train, X_test[0], clf)

373 ms ± 52.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
