<a href="https://colab.research.google.com/github/xanasa14/MLImplementations/blob/master/Eli5Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from sklearn.datasets import fetch_20newsgroups

# The four newsgroups used as target classes.
categories = [
    'alt.atheism',
    'soc.religion.christian',
    'comp.graphics',
    'sci.med',
]

# The train and test splits are fetched with identical options; only
# `subset` differs, so both calls are produced by one parameterised expression.
twenty_train, twenty_test = (
    fetch_20newsgroups(
        subset=split,
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=('headers', 'footers'),
    )
    for split in ('train', 'test')
)

Downloading 20news dataset. This may take a few minutes.
Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)


In [3]:
!pip install eli5

Collecting eli5
[?25l  Downloading https://files.pythonhosted.org/packages/97/2f/c85c7d8f8548e460829971785347e14e45fa5c6617da374711dec8cb38cc/eli5-0.10.1-py2.py3-none-any.whl (105kB)
[K     |███                             | 10kB 17.1MB/s eta 0:00:01[K     |██████▏                         | 20kB 20.1MB/s eta 0:00:01[K     |█████████▎                      | 30kB 9.8MB/s eta 0:00:01[K     |████████████▍                   | 40kB 8.5MB/s eta 0:00:01[K     |███████████████▌                | 51kB 4.3MB/s eta 0:00:01[K     |██████████████████▋             | 61kB 4.5MB/s eta 0:00:01[K     |█████████████████████▊          | 71kB 4.9MB/s eta 0:00:01[K     |████████████████████████▊       | 81kB 5.3MB/s eta 0:00:01[K     |███████████████████████████▉    | 92kB 5.3MB/s eta 0:00:01[K     |███████████████████████████████ | 102kB 4.3MB/s eta 0:00:01[K     |████████████████████████████████| 112kB 4.3MB/s 
Installing collected packages: eli5
Successfully installed eli5-0.10.1


In [4]:
from sklearn.base import BaseEstimator, TransformerMixin
from keras.models import Model, Input
from keras.layers import Dense, LSTM, Dropout, Embedding, SpatialDropout1D, Bidirectional, concatenate
from keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from sklearn.metrics import accuracy_score
from eli5.lime import TextExplainer
import regex as re
import numpy as np


class KerasTextClassifier(BaseEstimator, TransformerMixin):
    '''Wrapper class for keras text classification models that takes raw text as input.

    The wrapper owns a Keras ``Tokenizer`` and a compiled Keras ``Model`` and
    exposes ``fit`` / ``predict_proba`` / ``predict`` / ``score`` methods that
    accept lists of raw strings.

    :params:
    max_words: largest word index kept in the vocabulary after fitting.
    input_length: length every token sequence is padded/truncated to.
    emb_dim: output dimension of the embedding layer.
    n_classes: number of units in the final softmax layer.
    epochs: number of epochs passed to ``model.fit``.
    batch_size: batch size passed to ``model.fit``.
    '''

    def __init__(self, max_words=30000, input_length=100, emb_dim=20, n_classes=4, epochs=5, batch_size=32):
        self.max_words = max_words
        self.input_length = input_length
        self.emb_dim = emb_dim
        self.n_classes = n_classes
        self.epochs = epochs
        # Stored under the same name as the constructor argument:
        # BaseEstimator.get_params() (and therefore clone() and repr()) looks
        # parameters up by their __init__ argument name.
        self.batch_size = batch_size
        self.model = self._get_model()
        self.tokenizer = Tokenizer(num_words=self.max_words+1,
                                   lower=True, split=' ', oov_token="UNK")

    @property
    def bs(self):
        '''Backward-compatible alias for ``batch_size``.'''
        return self.batch_size

    @bs.setter
    def bs(self, value):
        self.batch_size = value

    def _get_model(self):
        '''Build and compile the network: embedding, bidirectional LSTM, pooled dense head.'''
        input_text = Input((self.input_length,))
        # input_dim is max_words + 2 so that index max_words + 1 (assigned to
        # the OOV token in fit) is still a valid embedding row.
        text_embedding = Embedding(input_dim=self.max_words + 2, output_dim=self.emb_dim,
                                   input_length=self.input_length, mask_zero=False)(input_text)
        text_embedding = SpatialDropout1D(0.5)(text_embedding)
        bilstm = Bidirectional(LSTM(units=32, return_sequences=True, recurrent_dropout=0.5))(text_embedding)
        # Summarise the per-timestep LSTM outputs with both average and max pooling.
        x = concatenate([GlobalAveragePooling1D()(bilstm), GlobalMaxPooling1D()(bilstm)])
        x = Dropout(0.7)(x)
        x = Dense(512, activation="relu")(x)
        x = Dropout(0.6)(x)
        x = Dense(512, activation="relu")(x)
        x = Dropout(0.5)(x)
        out = Dense(units=self.n_classes, activation="softmax")(x)
        model = Model(input_text, out)
        # The sparse loss means labels are passed as integer class ids.
        model.compile(optimizer="adam",
                      loss="sparse_categorical_crossentropy",
                      metrics=["accuracy"])
        return model

    def _get_sequences(self, texts):
        '''Convert texts to integer sequences padded with 0 to ``input_length``.'''
        seqs = self.tokenizer.texts_to_sequences(texts)
        return pad_sequences(seqs, maxlen=self.input_length, value=0)

    def _preprocess(self, texts):
        '''Replace every single digit character in each text with the literal "DIGIT".'''
        return [re.sub(r"\d", "DIGIT", x) for x in texts]

    def fit(self, X, y):
        '''
        Fit the vocabulary and the model.

        :params:
        X: list of texts.
        y: labels.

        :returns:
        self, following the scikit-learn estimator convention so that calls
        can be chained.
        '''
        # Preprocess once and reuse for both vocabulary fitting and encoding.
        texts = self._preprocess(X)
        self.tokenizer.fit_on_texts(texts)
        # Keep only entries whose index is within max_words, then place the
        # OOV token at max_words + 1, the last row of the embedding matrix.
        self.tokenizer.word_index = {e: i for e,i in self.tokenizer.word_index.items() if i <= self.max_words}
        self.tokenizer.word_index[self.tokenizer.oov_token] = self.max_words + 1
        seqs = self._get_sequences(texts)
        self.model.fit(seqs, y, batch_size=self.batch_size, epochs=self.epochs, validation_split=0.1)
        return self

    def predict_proba(self, X, y=None):
        '''Return the model output (softmax activations) for a list of raw texts. ``y`` is ignored.'''
        seqs = self._get_sequences(self._preprocess(X))
        return self.model.predict(seqs)

    def predict(self, X, y=None):
        '''Return the index of the highest-scoring class for each text. ``y`` is ignored.'''
        return np.argmax(self.predict_proba(X), axis=1)

    def score(self, X, y):
        '''Return the accuracy of ``predict(X)`` against the labels ``y``.'''
        y_pred = self.predict(X)
        return accuracy_score(y, y_pred)



In [5]:
# Override three constructor defaults: vocabulary size, sequence length and
# number of training epochs.
text_model = KerasTextClassifier(
    max_words=20000,
    input_length=200,
    epochs=20,
)


In [6]:
# Fit the tokenizer vocabulary and train the network on the training split.
text_model.fit(
    X=twenty_train.data,
    y=twenty_train.target,
)


Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [7]:
# Accuracy on the held-out test split.
text_model.score(
    X=twenty_test.data,
    y=twenty_test.target,
)


0.7443408788282291

In [8]:
# Pick one test document and explain the classifier's prediction for it.
doc = twenty_test.data[2]

# TextExplainer only needs a callable that maps raw texts to class
# probabilities, which is what the wrapper's predict_proba provides.
te = TextExplainer(random_state=42)
te.fit(doc=doc, predict_proba=text_model.predict_proba)

# Label the explanation with the newsgroup names instead of class indices.
te.show_prediction(target_names=twenty_train.target_names)

Contribution?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0
Contribution?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
Contribution?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2
Contribution?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
+1.262,writes,,
+0.899,of,,
+0.843,vegetarian,,
+0.752,absolute,,
+0.664,are,,
+0.613,tend,,
+0.589,clarification,,
+0.556,therefore,,
+0.513,absolute knowledge,,
+0.510,remember,,

Contribution?,Feature
1.262,writes
0.899,of
0.843,vegetarian
0.752,absolute
0.664,are
0.613,tend
0.589,clarification
0.556,therefore
0.513,absolute knowledge
0.51,remember

Contribution?,Feature
0.575,are not
0.564,have a
0.475,are vegetarian
0.449,one of
0.421,ac uk
0.229,okstate edu
0.168,but my
0.154,sting centipedes
0.146,ed ac
0.131,painful sting

Contribution?,Feature
1.055,a
0.997,t
0.985,have
0.961,g
0.932,sting
0.828,but
0.595,medical
0.575,painful
0.556,vms
0.532,on

Contribution?,Feature
0.848,i would
0.623,one of
0.571,have a
0.567,ocom okstate
0.503,ed ac
0.45,are vegetarian
0.432,ac uk
0.429,okstate edu
0.416,vms ocom
0.38,are not
