# Model explainability

Neural and transformer-based approaches have recently made remarkable advances, but it is difficult to understand how these models make decisions. This matters most in sensitive areas where models are used to drive critical decisions.

This notebook covers how to gain a level of understanding of the outputs of complex natural language models.

# Install dependencies

Install `txtai` and all dependencies.

In [None]:
%%capture
!pip install git+https://github.com/neuml/txtai#egg=txtai[pipeline] shap

# Semantic Search

The first example we'll cover is semantic search. Semantic search applications have an understanding of natural language and identify results that have the same meaning, not necessarily the same keywords. While this produces higher quality results, one advantage of keyword search is that it's easy to understand why a result was selected: the keyword is there.

Let's see if we can gain a better understanding of semantic search output. 

In [None]:
from txtai.embeddings import Embeddings

data = ["US tops 5 million confirmed virus cases",
        "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
        "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
        "The National Park Service warns against sacrificing slower friends in a bear attack",
        "Maine man wins $1M from $25 lottery ticket",
        "Make huge profits without work, earn up to $100,000 a day"]

# Create embeddings index with content enabled. The default behavior is to only store indexed vectors.
embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": True})

# Create an index for the list of text
embeddings.index([(uid, text, None) for uid, text in enumerate(data)])

# Run a search and score how much each token of the top result contributes
embeddings.explain("feel good story", limit=1)

[{'id': '4',
  'score': 0.08329004049301147,
  'text': 'Maine man wins $1M from $25 lottery ticket',
  'tokens': [('Maine', 0.003297939896583557),
   ('man', -0.03039500117301941),
   ('wins', 0.03406312316656113),
   ('$1M', -0.03121592104434967),
   ('from', -0.02270638197660446),
   ('$25', 0.012891143560409546),
   ('lottery', -0.015372440218925476),
   ('ticket', 0.007445111870765686)]}]

The `explain` method above ran an embeddings query like `search` but also analyzed each token to determine term importance. Looking at the results, `wins` has the highest score, so it appears to be the most important term. Let's visualize it.

In [None]:
from IPython.display import HTML

def plot(query):
  result = embeddings.explain(query, limit=1)[0]

  output = f"<b>{query}</b><br/>"
  spans = []
  for token, score in result["tokens"]:
    color = None
    if score >= 0.1:
      color = "#fdd835"
    elif score >= 0.075:
      color = "#ffeb3b"
    elif score >= 0.05:
      color = "#ffee58"
    elif score >= 0.02:
      color = "#fff59d"

    spans.append((token, score, color))

  if result["score"] >= 0.05 and not [color for _, _, color in spans if color]:
    mscore = max([score for _, score, _ in spans])
    spans = [(token, score, "#fff59d" if score == mscore else color) for token, score, color in spans]

  for token, _, color in spans:
    if color:
      output += f"<span style='background-color: {color}'>{token}</span> "
    else:
      output += f"{token} "

  return output

HTML(plot("feel good story"))
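
For intuition about where per-token scores like these can come from, one common technique is occlusion: remove one token at a time and measure how much the similarity score drops. The sketch below is illustrative only. It assumes a toy `similarity` function (cosine over word counts) in place of a real embeddings model, and the names `similarity` and `occlusion_scores` are hypothetical, not part of txtai. It is not a description of how `explain` is implemented.

```python
import math
from collections import Counter

def similarity(query, text):
  # Toy stand-in for an embeddings similarity (assumption):
  # cosine similarity over lowercase word counts
  a, b = Counter(query.lower().split()), Counter(text.lower().split())
  dot = sum(a[word] * b[word] for word in a)
  norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
  return dot / norm if norm else 0.0

def occlusion_scores(query, text, scorer=similarity):
  # Score each token as the drop in similarity when that token is removed
  tokens = text.split()
  base = scorer(query, text)

  scores = []
  for i, token in enumerate(tokens):
    reduced = " ".join(tokens[:i] + tokens[i + 1:])
    scores.append((token, base - scorer(query, reduced)))

  return scores

scores = occlusion_scores("lottery wins", "Maine man wins $1M from $25 lottery ticket")
for token, score in scores:
  print(f"{token}: {score:.4f}")
```

With this toy scorer, tokens that match the query get positive scores and the remaining tokens get slightly negative ones, since removing them makes the text look more like the query. A real model would be plugged in through the `scorer` argument.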

Let's try some more queries!

In [None]:
output = ""
for query in ["feel good story", "climate change", "public health story", "war", "wildlife", "asia", "lucky", "dishonest junk"]:
  output += plot(query) + "<br/><br/>"

HTML(output)

There is also a batch method that can run bulk explanations more efficiently.

In [None]:
queries = ["feel good story", "climate change", "public health story", "war", "wildlife", "asia", "lucky", "dishonest junk"]
results = embeddings.batchexplain(queries, limit=1)

for result in results:
  print(result)

[{'id': '4', 'text': 'Maine man wins $1M from $25 lottery ticket', 'score': 0.08329004049301147, 'tokens': [('Maine', 0.003297939896583557), ('man', -0.03039500117301941), ('wins', 0.03406312316656113), ('$1M', -0.03121592104434967), ('from', -0.02270638197660446), ('$25', 0.012891143560409546), ('lottery', -0.015372440218925476), ('ticket', 0.007445111870765686)]}]
[{'id': '1', 'text': "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg", 'score': 0.24478264153003693, 'tokens': [("Canada's", -0.026454076170921326), ('last', 0.017057165503501892), ('fully', 0.007285907864570618), ('intact', -0.005608782172203064), ('ice', 0.009459629654884338), ('shelf', -0.029393181204795837), ('has', 0.0253918319940567), ('suddenly', 0.021642476320266724), ('collapsed,', -0.030680224299430847), ('forming', 0.01910528540611267), ('a', -0.00890059769153595), ('Manhattan-sized', -0.023612067103385925), ('iceberg', -0.009710296988487244)]}]
[{'id': '0', 'text':

Of course, this method is supported through YAML-based applications and the API.

In [None]:
from txtai.app import Application

app = Application("""
writable: true
embeddings:
  path: sentence-transformers/nli-mpnet-base-v2
  content: true
""")

app.add([{"id": uid, "text": text} for uid, text in enumerate(data)])
app.index()

app.explain("feel good story", limit=1)

[{'id': '4',
  'score': 0.08329004049301147,
  'text': 'Maine man wins $1M from $25 lottery ticket',
  'tokens': [('Maine', 0.003297939896583557),
   ('man', -0.03039500117301941),
   ('wins', 0.03406312316656113),
   ('$1M', -0.03121592104434967),
   ('from', -0.02270638197660446),
   ('$25', 0.012891143560409546),
   ('lottery', -0.015372440218925476),
   ('ticket', 0.007445111870765686)]}]
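
Whichever entry point produces them, results in the shape shown above are plain dictionaries, so picking out the most important terms is a simple sort. The helper below is a small sketch; `top_tokens` is a hypothetical name, and the sample result is copied from the output above rather than computed.

```python
def top_tokens(result, n=1):
  # Return the n highest-scoring (token, score) pairs from one explain result
  return sorted(result["tokens"], key=lambda pair: pair[1], reverse=True)[:n]

# Result copied from the explain output above
result = {
  "id": "4",
  "score": 0.08329004049301147,
  "text": "Maine man wins $1M from $25 lottery ticket",
  "tokens": [("Maine", 0.003297939896583557),
             ("man", -0.03039500117301941),
             ("wins", 0.03406312316656113),
             ("$1M", -0.03121592104434967),
             ("from", -0.02270638197660446),
             ("$25", 0.012891143560409546),
             ("lottery", -0.015372440218925476),
             ("ticket", 0.007445111870765686)]
}

for token, score in top_tokens(result, 2):
  print(token, score)
```

Sorting by score rather than by absolute value keeps the focus on tokens that pull the result toward the query; sorting by `abs(pair[1])` would instead surface the tokens with the largest effect in either direction.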

# Pipeline models

txtai pipelines are wrappers around Hugging Face pipelines with logic to easily integrate with txtai's workflow framework. Given that, we can use the [SHAP](https://github.com/slundberg/shap) library to explain predictions.

Let's try a sentiment analysis example.

In [None]:
import shap

from txtai.pipeline import Labels

data = ["Dodgers lose again, give up 3 HRs in a loss to the Giants",
        "Massive dunk!!! they are now up by 15 with 2 minutes to go"]

labels = Labels(dynamic=False)

# explain the model on two sample inputs
explainer = shap.Explainer(labels.pipeline) 
shap_values = explainer(data)

In [None]:
shap.plots.text(shap_values[0, :, "NEGATIVE"])

In [None]:
shap.plots.text(shap_values[1, :, "NEGATIVE"])
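
SHAP builds on Shapley values from cooperative game theory, which split a prediction among input features by averaging each feature's marginal contribution over all subsets of the other features. The sketch below computes exact Shapley values for a tiny, made-up scoring function. Both `shapley_values` and `toy_value` are hypothetical names used for illustration; the SHAP library relies on approximations for real models, so this is not how `shap.Explainer` works internally.

```python
from itertools import combinations
from math import factorial

def shapley_values(features, value):
  # Exact Shapley values; assumes feature names are unique.
  # value takes a set of present features and returns a number.
  n = len(features)
  result = {}
  for feature in features:
    others = [x for x in features if x != feature]
    total = 0.0
    for size in range(n):
      # Weight for subsets of this size in the Shapley formula
      weight = factorial(size) * factorial(n - size - 1) / factorial(n)
      for subset in combinations(others, size):
        present = set(subset)
        total += weight * (value(present | {feature}) - value(present))
    result[feature] = total
  return result

def toy_value(words):
  # Made-up "negative sentiment" score with one interaction term (assumption)
  score = 0.0
  if "lose" in words:
    score += 2.0
  if "again" in words:
    score += 0.5
  if "lose" in words and "again" in words:
    score += 1.0
  return score

values = shapley_values(["Dodgers", "lose", "again"], toy_value)
print(values)
```

The attributions always add up to the difference between the score with all features present and the score with none, which is the property that makes the red and blue contributions in SHAP plots sum to the model output. Here the interaction bonus is split evenly between `lose` and `again`, and `Dodgers` receives nothing because it never changes the score.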

The [SHAP documentation](https://shap.readthedocs.io/en/latest/text_examples.html) provides a great list of additional examples for text generation, summarization, translation and question-answering.

The SHAP library is pretty 🔥🔥🔥 Check it out for more!

# Wrapping up

This notebook briefly introduced model explainability. There is a lot of work in this area, so expect a number of different methods to become available. Model explainability helps users gain a level of trust in model predictions. It also helps debug why a model is making a decision, which can potentially guide how to fine-tune a model to make better predictions.

Keep an eye on this important area over the coming months!
