# Running LIT with an AutoML NL model

- The <a href="https://zenodo.org/record/2562939#.YGLt2-jYrik">dataset</a> used for the AutoML NL model is a gold-standard corpus for training systems that classify, or triage, documents (PubMed abstracts) as relevant or not relevant for TF-TG (transcription factor and target gene) relations.
- To test LIT, you may download <a href="https://github.com/pepecura/ml-on-gcp-explainable-ai-xai/blob/main/lit-for-vertex-automl/text-small-sample.csv">the small sample file on GitHub</a> and upload it to this notebook environment.

Prerequisites: <br>
Build an AutoML NL model on Vertex AI and deploy it to an endpoint. Set the PROJECT_ID, ENDPOINT_ID, and LOCATION_ID parameters in this notebook to match the deployed AutoML model. <br>
The data model for detecting TF-TG relations will include two variables, "text" and "label" (0/1).
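
The two-column layout described above can be checked before the file is handed to LIT. The following is a minimal sketch using only the standard library; `validate_rows` is a hypothetical helper, and the sample rows are invented for illustration rather than taken from the corpus.

```python
import csv
import io

REQUIRED_COLUMNS = ("text", "label")  # column names used in this notebook
ALLOWED_LABELS = {"0", "1"}           # labels are handled as strings, as in the LIT spec


def validate_rows(csv_file):
    """Hypothetical helper: read a CSV and check the text/label layout."""
    reader = csv.DictReader(csv_file)
    fieldnames = reader.fieldnames or []
    missing = [c for c in REQUIRED_COLUMNS if c not in fieldnames]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    rows = []
    # Data rows start on line 2 because line 1 is the header.
    for line_no, row in enumerate(reader, start=2):
        if row["label"] not in ALLOWED_LABELS:
            raise ValueError(f"line {line_no}: unexpected label {row['label']!r}")
        rows.append({"text": row["text"], "label": row["label"]})
    return rows


# Illustrative rows only; they are not taken from the real dataset.
sample_csv = io.StringIO(
    "text,label\n"
    '"Protein A activates the expression of gene B.",1\n'
    '"The weather was unusually warm this spring.",0\n'
)
sample_rows = validate_rows(sample_csv)
```

Running a check like this on the uploaded file first makes a malformed header or a stray label value surface as a clear error instead of a failure inside the widget.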

In [1]:
# Install LIT
!pip install --quiet lit-nlp

In [1]:
# Import packages.
from lit_nlp.api import dataset
from lit_nlp.api import model
from lit_nlp.api import types as lit_types
import requests
import json
import pandas as pd

# Setup URL and headers for prediction request.
PROJECT_ID  = "XXX"          # set @param e.g. "google.com:automl-xai"
ENDPOINT_ID = "XXX"          # set @param e.g. "1112353923588423680"
LOCATION_ID = "XXX"          # set @param e.g. "us-central1"
url = f'https://{LOCATION_ID}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}/locations/{LOCATION_ID}/endpoints/{ENDPOINT_ID}:predict'
token = !gcloud auth print-access-token
headers = {'content-type': 'application/json', 'Authorization': f'Bearer {token[0]}'}
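
Because the cell above ships with "XXX" placeholders, a forgotten parameter only shows up later as a failed request. A minimal sketch of a guard is shown here; `build_request_config` and its `api_path` argument are hypothetical names, and the default `"ui"` simply mirrors the URL used in this notebook (the public REST reference documents `v1` in that position, so it is left configurable).

```python
PLACEHOLDER = "XXX"  # placeholder value used in the notebook cell


def build_request_config(project_id, endpoint_id, location_id, token, api_path="ui"):
    """Hypothetical helper: build the predict URL and headers, rejecting unset values."""
    params = (
        ("PROJECT_ID", project_id),
        ("ENDPOINT_ID", endpoint_id),
        ("LOCATION_ID", location_id),
    )
    for name, value in params:
        if not value or value == PLACEHOLDER:
            raise ValueError(f"{name} is not set")
    url = (
        f"https://{location_id}-aiplatform.googleapis.com/{api_path}"
        f"/projects/{project_id}/locations/{location_id}"
        f"/endpoints/{endpoint_id}:predict"
    )
    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {token}",
    }
    return url, headers


# Example values are made up; the token is not a real credential.
example_url, example_headers = build_request_config(
    "my-project", "1234567890", "us-central1", "fake-token")
```

In the notebook, the token would still come from `gcloud auth print-access-token`, as in the cell above.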

In [7]:
PROJECT_ID = "google.com:automl-xai"
ENDPOINT_ID = "1112353923588423680"
LOCATION_ID = "us-central1"
url = f'https://{LOCATION_ID}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}/locations/{LOCATION_ID}/endpoints/{ENDPOINT_ID}:predict'
token = !gcloud auth print-access-token
headers = {'content-type': 'application/json', 'Authorization': f'Bearer {token[0]}'}

In [8]:
# Read the data into a dataframe and define the data specification for LIT.
class ReadData(dataset.Dataset):
  def __init__(self, path: str):
    with open(path) as fd:
      df = pd.read_csv(fd, header=0)
    self._examples = [{
        "text": row["text"],
        "label": str(row["label"]),
    } for _, row in df.iterrows()]

  def spec(self) -> lit_types.Spec:
    return {
        "text": lit_types.TextSegment(),
        "label": lit_types.CategoryLabel(vocab=["0", "1"]),
    }

# Get online predictions.
class RunModel(model.Model):
  def input_spec(self) -> lit_types.Spec:
    return {
        "text": lit_types.TextSegment(),
        "label": lit_types.CategoryLabel(vocab=["0", "1"], required=False),
    }

  def output_spec(self) -> lit_types.Spec:
    return {
        "preds": lit_types.MulticlassPreds(vocab=["0", "1"], parent="label", null_idx=0),
    }

  def predict_minibatch(self, examples):
    # Request online predictions one example at a time.
    def get_pred(ex):
      # Serialize with json.dumps so quotes in the text are escaped correctly.
      payload = json.dumps(
          {"instances": {"mimeType": "text/plain", "content": ex["text"]}})
      r = requests.post(url, data=payload, headers=headers)
      r.raise_for_status()  # Fail loudly on authentication or endpoint errors.
      return r.json()['predictions'][0]['confidences']
    return [{"preds": get_pred(ex)} for ex in examples]
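
`RunModel` passes the `confidences` list straight to LIT, which reads it in the order of the vocab `["0", "1"]`. If the endpoint returns the classes in a different order, the scores would be attributed to the wrong label. The sketch below reorders them by class name. It assumes each prediction also carries a `displayNames` list parallel to `confidences`; that field is not used anywhere in this notebook, so confirm it against a real response before relying on it. `confidences_in_vocab_order` is a hypothetical helper.

```python
def confidences_in_vocab_order(prediction, vocab):
    """Hypothetical helper: reorder confidences to match the LIT vocab order.

    Assumes `prediction` has parallel "displayNames" and "confidences" lists.
    Classes missing from the response get a score of 0.0.
    """
    by_name = dict(zip(prediction["displayNames"], prediction["confidences"]))
    return [float(by_name.get(label, 0.0)) for label in vocab]


# Hand-written response fragment for illustration; not real endpoint output.
fake_prediction = {
    "displayNames": ["1", "0"],
    "confidences": [0.9, 0.1],
}
scores = confidences_in_vocab_order(fake_prediction, ["0", "1"])
```

With such a helper, `get_pred` could return `confidences_in_vocab_order(r.json()['predictions'][0], ["0", "1"])` instead of the raw list.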

In [11]:
# Create the LIT widget with the model and dataset to analyze.
from lit_nlp import notebook

datasets = {'data': ReadData('text-small-sample.csv')}
models   = {'auto_nl': RunModel()}
widget   = notebook.LitWidget(models, datasets, height=800)
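
Before rendering the widget, it can help to confirm that the model wrapper and the dataset agree on labels. The following is a minimal sketch of an accuracy check over plain lists of dicts in the shapes used above (`{"text", "label"}` for examples, `{"preds"}` for model outputs). `accuracy` is a hypothetical helper, and the inputs are invented for illustration.

```python
def accuracy(examples, preds, vocab):
    """Hypothetical helper: fraction of examples whose top-scoring class matches the label."""
    if len(examples) != len(preds):
        raise ValueError("examples and preds must have the same length")
    if not examples:
        return 0.0
    correct = 0
    for ex, pred in zip(examples, preds):
        scores = pred["preds"]
        # Index of the highest score; ties resolve to the first index.
        best = max(range(len(scores)), key=scores.__getitem__)
        if vocab[best] == ex["label"]:
            correct += 1
    return correct / len(examples)


# Invented inputs: the first example is predicted correctly, the second is not.
demo_examples = [{"text": "a", "label": "1"}, {"text": "b", "label": "0"}]
demo_preds = [{"preds": [0.2, 0.8]}, {"preds": [0.3, 0.7]}]
demo_accuracy = accuracy(demo_examples, demo_preds, ["0", "1"])
```

In the notebook, `preds` could come from calling `predict_minibatch` on the `RunModel` instance with the example list built in `ReadData`; note that this sends one request per example to the endpoint.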

In [12]:
# Render the widget
widget.render()

127.0.0.1 - - [24/Mar/2022 16:22:03] "GET / HTTP/1.1" 200 1406
127.0.0.1 - - [24/Mar/2022 16:22:04] "GET /main.js HTTP/1.1" 200 1809942
127.0.0.1 - - [24/Mar/2022 16:22:12] "POST /get_info HTTP/1.1" 200 15752
127.0.0.1 - - [24/Mar/2022 16:22:12] "GET /static/favicon.png HTTP/1.1" 200 13257
127.0.0.1 - - [24/Mar/2022 16:22:12] "POST /get_dataset?dataset_name=data HTTP/1.1" 200 1893
127.0.0.1 - - [24/Mar/2022 16:22:12] "POST /get_preds?model=auto_nl&dataset_name=data&requested_types=MulticlassPreds HTTP/1.1" 200 551
127.0.0.1 - - [24/Mar/2022 16:22:12] "POST /get_interpretations?model=auto_nl&dataset_name=data&interpreter=metrics HTTP/1.1" 200 252
127.0.0.1 - - [24/Mar/2022 16:22:12] "POST /get_preds?model=auto_nl&dataset_name=data&requested_types=MulticlassPreds HTTP/1.1" 200 551
127.0.0.1 - - [24/Mar/2022 16:22:12] "POST /get_interpretations?model=auto_nl&dataset_name=data&interpreter=metrics HTTP/1.1" 200 252
127.0.0.1 - - [24/Mar/2022 16:22:12] "POST /get_preds?model=auto_nl&dataset_