In [1]:
import torch
import numpy as np
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Local checkpoint directory (relative to the notebook's working directory).
# Both the tokenizer and the model are loaded from this one location so the
# two can never be pointed at different checkpoints by accident.
MODEL_DIR = "./sarcasm_distilbert"

tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_DIR)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)

In [5]:
# Switch the model to evaluation mode before running inference; in PyTorch this
# makes the Dropout layers listed in the repr below inactive.
model.eval()

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [6]:
def predict(headline, batch_size=32):
    """Return class probabilities for one or more headlines.

    This is the classifier function handed to LIME, which calls it with a
    list containing every perturbed variant of the text being explained
    (thousands of strings by default). Running them through the model in
    one forward pass can exhaust memory, so the texts are processed in
    mini-batches.

    Args:
        headline: A single string, or an iterable of strings (list, tuple,
            NumPy array, ...).
        batch_size: Maximum number of texts per forward pass. Must be >= 1.

    Returns:
        A NumPy array of shape (n_texts, n_classes) holding the softmax
        probabilities, one row per input text in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    # Normalise the input to a plain list of str: the tokenizer does not
    # accept NumPy arrays, and a bare string is treated as a batch of one.
    if isinstance(headline, str):
        texts = [headline]
    else:
        texts = [str(text) for text in headline]

    if not texts:
        # Nothing to score; keep the (n_texts, n_classes) contract.
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    # Inputs must live on the same device as the model weights.
    device = next(model.parameters()).device

    prob_chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[start:start + batch_size],
                truncation=True,
                padding=True,
                max_length=128,
                return_tensors="pt"
            ).to(device)
            logits = model(**enc).logits
            # Move to CPU right away so results do not pile up on the device.
            prob_chunks.append(torch.softmax(logits, dim=1).cpu())

    return torch.cat(prob_chunks, dim=0).numpy()

In [7]:
from lime.lime_text import LimeTextExplainer

In [8]:
# Display names for the classifier's output columns.
# NOTE(review): assumes label id 0 = not sarcastic and 1 = sarcastic in the
# loaded checkpoint - confirm against the training label mapping.
class_names = ["Not Sarcastic", "Sarcastic"]

# LIME perturbs the input text at random; fixing random_state makes the
# explanation reproducible across kernel restarts.
explainer = LimeTextExplainer(class_names=class_names, random_state=42)

In [15]:
headline = "Oh great, another Monday morning meeting."

# Ask LIME for the (up to) 8 tokens that influence the prediction the most.
exp = explainer.explain_instance(
    headline,
    predict,
    num_features=8
)

html = exp.as_html() # was getting an error with inline html, so used an html file

# Write explicitly as UTF-8: the generated page can contain non-ASCII
# characters, and the platform default encoding (e.g. cp1252 on Windows)
# may fail with UnicodeEncodeError.
with open("lime_explanation.html", "w", encoding="utf-8") as f:
    f.write(html)