In [1]:
import re
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

import shap
from bpemb import BPEmb

from TextCNN.modules.data.dataset import load_goodreads_dataset, load_IMDB_dataset
from TextCNN.modules.pipeline_config.utils import set_split, pad_collate
from TextCNN.modules.model.model import TextCNN

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs end to
# end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# English BPEmb subword model (300-dim vectors, 10k vocabulary) requested with
# an extra padding embedding via add_pad_emb=True.
BPEMB_EN = BPEmb(lang='en', dim=300, vs=10000, add_pad_emb=True)

# Overwrite both size attributes with the real number of rows in the embedding
# table, so they account for whatever add_pad_emb appended.
BPEMB_EN.vocab_size = BPEMB_EN.vs = len(BPEMB_EN.emb.index_to_key)

# The last row of the embedding table is used as the padding id.
PAD_IDX = BPEMB_EN.vs - 1

In [3]:
# Load the saved checkpoint, mapping its tensors onto `device`. Later cells
# read `c['batch_size_training']` and `c['model_state_dict']` from it.
# NOTE(review): torch.load unpickles the file, so only point it at checkpoints
# that come from a trusted source.
c = torch.load('CNN_checkpoints/checkpoint_epoch_5.pt', map_location=device)

In [4]:
# Collate function: pad_collate with its pad value pinned to the padding id.
collator = partial(
    pad_collate,
    pad_value=PAD_IDX,
)

# Reuse the batch size recorded in the checkpoint.
batch_size = c['batch_size_training']

In [5]:
# Read the reviews table from a Feather file; later cells use its
# `review_text` and `is_spoiler` columns.
reviews_df = pd.read_feather('../Datasets/spoiler_reviews_v3.ftr')

In [6]:
# Split the reviews with the project helper, using the checkpoint's batch size
# and the padding collator; shuffling of the training part is switched off.
# NOTE(review): presumably test_size=0.2 holds out 20% of the rows -- confirm
# against set_split.
data = set_split(
    reviews_df,
    BPEMB_EN,
    test_size=0.2,
    batch_size=batch_size,
    train_shuffle=False,
    collator=collator,
)

In [7]:
# Rebuild the TextCNN on top of the BPEmb object and move it to `device`.
# NOTE(review): hid_size=128 presumably has to match the value used when the
# checkpoint was trained -- confirm.
model = TextCNN(BPEMB_EN, hid_size=128).to(device)
# Restore the weights stored in the checkpoint.
model.load_state_dict(c['model_state_dict'])
# Switch to evaluation mode; the trailing `;` keeps the cell from echoing the
# returned value.
model.eval();

## Explain predictions with SHAP

In [8]:
def custom_tokenizer(s, return_offsets_mapping=True):
    """ Split ``s`` at every single non-word character (transformers-like API).

    Each character matched by ``\\W`` acts as a separator, so two adjacent
    separators produce an empty token between them. A trailing piece is only
    emitted when it is non-empty.

    NOTE(review): this definition is shadowed by the whitespace-based
    ``custom_tokenizer`` defined in the following cell.

    Args:
        s: Text to tokenize.
        return_offsets_mapping: When truthy, include the character span of
            every token under ``"offset_mapping"``.

    Returns:
        dict with ``"input_ids"`` (list of token strings) and, if requested,
        ``"offset_mapping"`` (list of ``(start, end)`` tuples).
    """
    # Spans of all separator characters, in order of appearance.
    separators = [match.span(0) for match in re.finditer(r"\W", s)]

    # A piece starts at 0 or right after a separator, and stops at the next
    # separator or at the end of the text.
    piece_starts = [0] + [sep_end for _, sep_end in separators]
    piece_stops = [sep_start for sep_start, _ in separators] + [len(s)]
    spans = list(zip(piece_starts, piece_stops))

    # Drop the final piece when nothing is left after the last separator
    # (this also covers the empty string).
    if spans[-1][0] == len(s):
        spans.pop()

    result = {"input_ids": [s[begin:stop] for begin, stop in spans]}
    if return_offsets_mapping:
        result["offset_mapping"] = spans
    return result

In [9]:
def custom_tokenizer(s, return_offsets_mapping=True):
    """ Whitespace tokenizer conforming to a subset of the transformers API.

    Tokens are the maximal runs of non-whitespace characters in ``s`` (the
    same pieces ``s.split()`` yields). Offsets are read from the actual match
    positions, so they stay correct when the text starts with whitespace or
    when tokens are separated by more than one whitespace character; the
    previous arithmetic assumed exactly one separator character and drifted
    on such inputs.

    Args:
        s: Text to tokenize.
        return_offsets_mapping: When truthy, include the character span of
            every token under ``"offset_mapping"``.

    Returns:
        dict with ``"input_ids"`` (list of token strings) and, if requested,
        ``"offset_mapping"`` (list of ``(start, end)`` tuples such that
        ``s[start:end]`` equals the corresponding token).
    """
    # One match per token; each match carries both the text and its real span.
    matches = list(re.finditer(r"\S+", s))

    out = {"input_ids": [m.group(0) for m in matches]}
    if return_offsets_mapping:
        out["offset_mapping"] = [m.span(0) for m in matches]
    return out

In [10]:
def tokenize_with_pad(texts, max_len=512):
    """Encode texts to BPEmb ids and pad them into one int64 batch on ``device``.

    Every text is encoded with ``BPEMB_EN.encode_ids``, cut to at most
    ``max_len`` ids, and right-padded with ``PAD_IDX`` up to the longest
    remaining sequence in the batch.

    Args:
        texts: Non-empty iterable of strings.
        max_len: Upper bound on the sequence length; longer encodings keep
            only their first ``max_len`` ids. Defaults to 512, the value that
            used to be hard-coded.

    Returns:
        int64 tensor of shape ``(number of texts, min(longest encoding,
        max_len))`` placed on ``device``.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    # Truncate explicitly instead of relying on negative padding amounts.
    encoded = [list(BPEMB_EN.encode_ids(text))[:max_len] for text in texts]
    if not encoded:
        raise ValueError("tokenize_with_pad() needs at least one text")

    width = max(len(ids) for ids in encoded)

    # NOTE(review): the earlier version added 1 to an odd width ("extend to
    # even max_len") and then subtracted it again when computing the pad
    # amounts, so batches were never actually widened. That effective
    # behaviour is kept here; confirm whether TextCNN needs an even width.
    rows = [ids + [PAD_IDX] * (width - len(ids)) for ids in encoded]

    # Explicit dtype so that an empty encoding cannot turn the batch float.
    return torch.tensor(rows, dtype=torch.int64).to(device)

In [11]:
def processing(text):
    """Prediction function handed to SHAP: raw texts in, sigmoid scores out.

    Args:
        text: Batch (list or array) of raw strings; it is forwarded unchanged
            to ``tokenize_with_pad``.

    Returns:
        NumPy array holding ``torch.sigmoid`` of the model outputs for the
        batch.
    """
    enc = tokenize_with_pad(text)

    # Inference only: skip autograd bookkeeping, since the explainer evaluates
    # this function repeatedly on masked variants of the input.
    with torch.no_grad():
        scores = torch.sigmoid(model(enc))

    return scores.cpu().numpy()

In [12]:
# Text masker driven by custom_tokenizer. That name is defined twice above;
# when the cells run in order, the later (whitespace-splitting) definition is
# the one bound here.
masker = shap.maskers.Text(custom_tokenizer)

In [13]:
# SHAP explainer around the `processing` prediction function, using the text
# masker to perturb inputs.
explainer = shap.Explainer(processing, masker)

In [14]:
# Column with the raw review texts.
texts = reviews_df['review_text']
# First review in the table.
# NOTE(review): `sent` is not referenced by the remaining cells visible here.
sent = texts.iloc[0]

In [19]:
# First review whose `is_spoiler` flag equals 1.
sent_spoiler = reviews_df.loc[reviews_df['is_spoiler'] == 1, 'review_text'].iloc[0]

In [20]:
# Explain the spoiler review, passed as a one-element batch.
shap_values = explainer([sent_spoiler])

In [21]:
# Sigmoid score of the model for the same review, shown as the cell output.
processing([sent_spoiler])

array([[0.5230182]], dtype=float32)

In [22]:
# Render the computed SHAP values with SHAP's text plot.
shap.plots.text(shap_values)