# 05Tools: Understanding Model Performance and Fairness with the Language Interpretability Tool (LIT)

The [Language Interpretability Tool (LIT)](https://pair-code.github.io/lit/) helps you understand a model's behavior across a wide range of inputs. In this notebook, the 05 series models are evaluated with LIT. While it might sound like LIT is specific to language models, it actually works well with tabular models!

This notebook shows how to connect LIT to a model deployed on a Vertex AI Endpoint for predictions and how to load a dataset from BigQuery.

### Prerequisites:
-  At least 1 of the notebooks in this series [05, 05a-05i]

### Conceptual Flow & Workflow
<p align="center">
  <img alt="Conceptual Flow" src="../architectures/slides/05tools_LIT_arch.png" width="45%">
&nbsp; &nbsp; &nbsp; &nbsp;
  <img alt="Workflow" src="../architectures/slides/05tools_LIT_console.png" width="45%">
</p>

---
## Setup

### Package Installs (if needed)

In [29]:
try:
    import lit_nlp
except ImportError:
    print('You need to pip install lit-nlp')
    !pip install lit-nlp -q

### Environment

inputs:

In [30]:
project = !gcloud config get-value project
PROJECT_ID = project[0]
PROJECT_ID

'statmike-mlops-349915'

In [31]:
REGION = 'us-central1'
SERIES = '05'

# source data
BQ_PROJECT = PROJECT_ID
BQ_DATASET = 'fraud'
BQ_TABLE = 'fraud_prepped'

# Model Training
VAR_TARGET = 'Class'
VAR_OMIT = 'transaction_id' # add more variables to the string with space delimiters

packages:

In [32]:
from google.cloud import aiplatform
from google.cloud import bigquery

from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types
from lit_nlp.api import model as lit_model
from lit_nlp.components import minimal_targeted_counterfactuals
from lit_nlp import notebook

import numpy as np

clients:

In [33]:
aiplatform.init(project=PROJECT_ID, location=REGION)
bq = bigquery.Client(project=PROJECT_ID)

parameters:

In [34]:
BUCKET = PROJECT_ID

---
## Get Vertex AI Endpoint And Deployed Model

In [35]:
endpoints = aiplatform.Endpoint.list(filter = f"labels.series={SERIES}")
endpoint = endpoints[0]

In [36]:
endpoint.display_name

'05'

In [37]:
deployed_model = endpoint.list_models()[0]  # first model deployed to the endpoint
model = aiplatform.Model(
    model_name = f'{deployed_model.model}@{deployed_model.model_version_id}'
)

In [38]:
model.display_name

'05_05h'

In [39]:
model.versioned_resource_name

'projects/1026793852137/locations/us-central1/models/model_05_05h@1'
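The model cell above pins a specific model version by appending `@<version id>` to the model's resource name, which is why `versioned_resource_name` ends in `@1`. Below is a minimal stdlib sketch of composing and splitting that string. The helper names `versioned_name` and `split_versioned_name` are hypothetical, and the parser assumes the `projects/<p>/locations/<l>/models/<m>` layout shown in the output above.

```python
def versioned_name(model_resource: str, version_id: str) -> str:
    """Compose '<model resource name>@<version id>' as in the model cell."""
    return f"{model_resource}@{version_id}"

def split_versioned_name(name: str) -> dict:
    """Split a versioned model resource name into project, location, model, version."""
    resource, _, version = name.partition("@")
    parts = resource.split("/")
    # assumed layout: projects/<p>/locations/<l>/models/<m> -> pair keys with values
    fields = dict(zip(parts[0::2], parts[1::2]))
    return {
        "project": fields["projects"],
        "location": fields["locations"],
        "model": fields["models"],
        "version": version or None,  # None when no '@<version>' suffix is present
    }

name = versioned_name(
    "projects/1026793852137/locations/us-central1/models/model_05_05h", "1"
)
print(name)
print(split_versioned_name(name))
```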

In [40]:
model.uri

'gs://statmike-mlops-349915/05/05h/models/20220927230247/6/model'

## Get Data for Model Exploration
Retrieve the test and validation splits for this series from BigQuery:

In [41]:
test = bq.query(query = f"SELECT * EXCEPT({VAR_TARGET}), {VAR_TARGET} FROM {BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE} WHERE splits='TEST' ORDER BY {VAR_TARGET} DESC").to_dataframe()
test = test[test.columns[~test.columns.isin(VAR_OMIT.split()+['splits'])]]
test.head()

| | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 85285 | -7.030308 | 3.421991 | -9.525072 | 5.270891 | -4.02463 | -2.865682 | -6.989195 | 3.791551 | -4.62273 | ... | 1.103398 | -0.541855 | 0.036943 | -0.355519 | 0.353634 | 1.042458 | 1.359516 | -0.272188 | 0.0 | 1 |
| 1 | 56887 | -0.075483 | 1.812355 | -2.566981 | 4.127549 | -1.628532 | -0.805895 | -3.390135 | 1.019353 | -2.451251 | ... | 0.794372 | 0.270471 | -0.143624 | 0.013566 | 0.634203 | 0.213693 | 0.773625 | 0.387434 | 5.0 | 1 |
| 2 | 43369 | -3.365319 | 2.426503 | -3.752227 | 0.276017 | -2.30587 | -1.961578 | -3.029283 | -1.674462 | 0.183961 | ... | 2.070008 | -0.512626 | -0.248502 | 0.12655 | 0.104166 | -1.055997 | -1.200165 | -1.012066 | 88.0 | 1 |
| 3 | 143354 | 1.118331 | 2.074439 | -3.837518 | 5.44806 | 0.071816 | -1.020509 | -1.808574 | 0.521744 | -2.032638 | ... | 0.289861 | -0.172718 | -0.02191 | -0.37656 | 0.192817 | 0.114107 | 0.500996 | 0.259533 | 1.0 | 1 |
| 4 | 93888 | -10.040631 | 6.139183 | -12.972972 | 7.740555 | -8.684705 | -3.837429 | -11.907702 | 5.833273 | -5.731054 | ... | 2.823431 | 1.153005 | -0.567343 | 0.843012 | 0.549938 | 0.113892 | -0.307375 | 0.061631 | 1.0 | 1 |
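The query selects every column except the target and then re-selects the target, which moves `Class` to the last column; the pandas filter then drops the `VAR_OMIT` columns and the `splits` column. Below is a minimal stdlib sketch of both steps, assuming a plain list of column names in place of the DataFrame. `split_query`, `kept_columns`, and the `my-project` ID are hypothetical, and backtick-quoting the table path is a defensive choice rather than a requirement (the unquoted query above ran successfully).

```python
VAR_TARGET = "Class"
VAR_OMIT = "transaction_id"  # space-delimited, as in the Environment cell

def split_query(project, dataset, table, split, target=VAR_TARGET):
    """Every column except the target, then the target last, for one split."""
    return (
        f"SELECT * EXCEPT({target}), {target} "
        f"FROM `{project}.{dataset}.{table}` "
        f"WHERE splits='{split}' ORDER BY {target} DESC"
    )

def kept_columns(columns, omit=VAR_OMIT):
    """Same rule as test.columns[~test.columns.isin(VAR_OMIT.split() + ['splits'])]."""
    drop = set(omit.split()) | {"splits"}
    return [c for c in columns if c not in drop]

# Hypothetical column order as returned by the query (target already last)
columns = ["transaction_id", "Time", "V1", "Amount", "splits", "Class"]
print(split_query("my-project", "fraud", "fraud_prepped", "TEST"))
print(kept_columns(columns))  # ['Time', 'V1', 'Amount', 'Class']
```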


In [42]:
test.shape

(28502, 31)

In [43]:
val = bq.query(query = f"SELECT * EXCEPT({VAR_TARGET}), {VAR_TARGET} FROM {BQ_PROJECT}.{BQ_DATASET}.{BQ_TABLE} WHERE splits='VALIDATE' ORDER BY {VAR_TARGET} DESC").to_dataframe()
val = val[val.columns[~val.columns.isin(VAR_OMIT.split()+['splits'])]]
val.head()

| | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 32686 | 0.287953 | 1.728735 | -1.652173 | 3.813544 | -1.090927 | -0.984745 | -2.202318 | 0.555088 | -2.033892 | ... | 0.262202 | -0.633528 | 0.092891 | 0.187613 | 0.368708 | -0.132474 | 0.576561 | 0.309843 | 0.0 | 1 |
| 1 | 53658 | -1.739341 | 1.344521 | -0.534379 | 3.195291 | -0.416196 | -1.261961 | -2.340991 | 0.713004 | -1.416265 | ... | 0.38318 | -0.213952 | -0.33664 | 0.237076 | 0.246003 | -0.044228 | 0.510729 | 0.220952 | 0.0 | 1 |
| 2 | 146998 | -2.06424 | 2.629739 | -0.748406 | 0.694992 | 0.418178 | 1.39252 | -1.697801 | -6.333065 | 1.724184 | ... | 6.215514 | -1.276909 | 0.459861 | -1.051685 | 0.209178 | -0.319859 | 0.015434 | -0.050117 | 8.0 | 1 |
| 3 | 11131 | -1.426623 | 4.141986 | -9.804103 | 6.666273 | -4.749527 | -2.073129 | -10.089931 | 2.791345 | -3.249516 | ... | 1.865679 | 0.407809 | 0.605809 | -0.769348 | -1.746337 | 0.50204 | 1.977258 | 0.711607 | 1.0 | 1 |
| 4 | 8878 | -2.661802 | 5.856393 | -7.653616 | 6.379742 | -0.060712 | -3.13155 | -3.10357 | 1.778492 | -3.831154 | ... | 0.734775 | -0.435901 | -0.384766 | -0.286016 | 1.007934 | 0.413196 | 0.280284 | 0.303937 | 1.0 | 1 |


## Setting Up LIT
At a minimum, LIT requires a Dataset and a Model specification that follow its [Type System](https://github.com/PAIR-code/lit/wiki/api.md#type-system):
- Dataset
    - A custom class that describes the examples with the Type System and stores them as flat dictionaries
- Model
    - A custom class whose methods return the input spec, the output spec, and predictions following the Type System

Common Inputs:

In [44]:
VOCABS = {f'{VAR_TARGET}': ['Not Fraud', 'Fraud']}
VAR_SPECS = test.dtypes.apply(lambda x: x.name).to_dict()

### LIT Dataset
The `FraudDataset` class below automates the specification by relying on the fact that every field except the target (stored in `VAR_TARGET` above) is numeric: the target becomes a `CategoryLabel` and every other field becomes a `Scalar`. The class also uses the common inputs defined in the previous cell.
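The spec automation can be sketched without LIT installed. `Scalar` and `CategoryLabel` below are hypothetical stand-ins for `lit_types.Scalar` and `lit_types.CategoryLabel`, and the dtype strings are placeholders: as in `FraudDataset.spec`, only the column names in `VAR_SPECS` matter.

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for lit_types.Scalar and lit_types.CategoryLabel
@dataclass
class Scalar:
    pass

@dataclass
class CategoryLabel:
    vocab: list = field(default_factory=list)

VAR_TARGET = "Class"
VOCABS = {VAR_TARGET: ["Not Fraud", "Fraud"]}

def build_spec(var_specs):
    """Target column -> CategoryLabel with the target vocab; every other column -> Scalar."""
    return {
        name: CategoryLabel(vocab=VOCABS[VAR_TARGET]) if name == VAR_TARGET else Scalar()
        for name in var_specs
    }

# Shaped like test.dtypes.apply(lambda x: x.name).to_dict(); dtype strings are placeholders
var_specs = {"Time": "Int64", "V1": "float64", "Amount": "float64", "Class": "Int64"}
spec = build_spec(var_specs)
for name, field_type in spec.items():
    print(name, type(field_type).__name__)
```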

Define Class for Dataset:

In [45]:
class FraudDataset(lit_dataset.Dataset):

    def __init__(self, ds):
        # one flat dict per DataFrame row
        records = ds.to_dict(orient='records')
        self._examples = []
        for rec in records:
            # replace the 0/1 target value with its label from VOCABS
            rec[VAR_TARGET] = VOCABS[VAR_TARGET][rec[VAR_TARGET]]
            self._examples.append(rec)

    def spec(self):
        # target -> CategoryLabel; every other (numeric) column -> Scalar
        specs = VAR_SPECS.copy()
        for s in specs:
            if s == VAR_TARGET: specs[s] = lit_types.CategoryLabel(vocab = VOCABS[VAR_TARGET])
            else: specs[s] = lit_types.Scalar()
        return specs
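`FraudDataset` stores the target as a string label by indexing into `VOCABS` with the 0/1 class value, so `0` becomes `'Not Fraud'` and `1` becomes `'Fraud'`. Below is a minimal sketch of that conversion on plain dicts shaped like `ds.to_dict(orient='records')` output. The helper `to_examples` and the second record are hypothetical, and unlike the class above, the helper copies each record so the input list is left unchanged.

```python
VAR_TARGET = "Class"
VOCABS = {VAR_TARGET: ["Not Fraud", "Fraud"]}

def to_examples(records):
    """Return copies of the records with the 0/1 target replaced by its label."""
    examples = []
    for rec in records:
        ex = dict(rec)  # shallow copy: the input records stay untouched
        ex[VAR_TARGET] = VOCABS[VAR_TARGET][int(ex[VAR_TARGET])]
        examples.append(ex)
    return examples

records = [
    {"Time": 85285, "Amount": 0.0, "Class": 1},  # values from the first test row above
    {"Time": 1000, "Amount": 25.0, "Class": 0},  # hypothetical non-fraud row
]
examples = to_examples(records)
print([ex[VAR_TARGET] for ex in examples])  # ['Fraud', 'Not Fraud']
```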

Test Class:

In [46]:
test_ds = FraudDataset(test)
val_ds = FraudDataset(val)

In [47]:
len(val_ds.examples), len(test_ds.examples)

(28244, 28502)

In [48]:
#test_ds.examples[0], test_ds.spec()

In [49]:
#val_ds.examples[0], val_ds.spec()

### LIT Model
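The `FraudModel` class uses the Vertex AI Endpoint as a prediction service: each minibatch of examples, with the target removed, is sent to the endpoint with the `endpoint.predict` method, and the returned per-class scores are reported as `MulticlassPreds` over the same vocabulary as the target.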

In [50]:
class FraudModel(lit_model.Model):
    
    def __init__(self, endpoint):
        self.model = endpoint

    def max_minibatch_size(self):
        return 2000        
        
    def predict_minibatch(self, inputs):
        # drop the target and convert each example to a protobuf Value for the endpoint
        instances = [json_format.ParseDict({key:value for key, value in example.items() if key != VAR_TARGET}, Value()) for example in inputs]
        # call the Vertex AI Endpoint wrapped by this model
        predictions = self.model.predict(instances = instances).predictions

        # one output dict per input example, in input order
        return [{'predictions': pred} for pred in predictions]
        
    def input_spec(self):
        specs = VAR_SPECS.copy()
        specs.pop(VAR_TARGET, None) # remove the target being predicted
        for s in specs:
            specs[s] = lit_types.Scalar()
        return specs
    
    def output_spec(self):
        return {
            'predictions': lit_types.MulticlassPreds(
                parent = f'{VAR_TARGET}', vocab = VOCABS[VAR_TARGET]
            )
        }
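The minibatch flow of `FraudModel` can be simulated without an endpoint. This sketch assumes inputs are split into consecutive chunks of at most `max_minibatch_size` examples (LIT's internal batching may differ in detail), strips the target from each example as `predict_minibatch` does, and replaces `endpoint.predict` with a hypothetical `fake_predict` rule. The hypothetical `top_label` helper shows how a probability vector lines up with the `vocab` of `MulticlassPreds`: position `i` corresponds to `VOCAB[i]`.

```python
VAR_TARGET = "Class"
VOCAB = ["Not Fraud", "Fraud"]  # order must match the probability vector

def minibatches(items, size):
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        yield items[start:start + size]

def to_instance(example):
    """Drop the label, as predict_minibatch does before calling the endpoint."""
    return {k: v for k, v in example.items() if k != VAR_TARGET}

def predict_all(examples, predict_fn, max_minibatch_size=2000):
    """Send each minibatch to predict_fn and wrap each result like FraudModel does."""
    outputs = []
    for batch in minibatches(examples, max_minibatch_size):
        preds = predict_fn([to_instance(ex) for ex in batch])
        if len(preds) != len(batch):
            raise ValueError("expected one prediction per instance")
        outputs.extend({"predictions": p} for p in preds)
    return outputs

def top_label(probs):
    """Map a probability vector to its vocab label (index i -> VOCAB[i])."""
    return VOCAB[max(range(len(probs)), key=probs.__getitem__)]

# Hypothetical stand-in for endpoint.predict(instances=...).predictions
batch_sizes = []
def fake_predict(instances):
    batch_sizes.append(len(instances))
    return [[0.1, 0.9] if inst["Amount"] < 1 else [0.8, 0.2] for inst in instances]

examples = [{"Amount": a, "Class": "Fraud"} for a in [0.0, 5.0, 0.5, 88.0, 1.0]]
out = predict_all(examples, fake_predict, max_minibatch_size=2)
print(batch_sizes)  # [2, 2, 1]
print([top_label(o["predictions"]) for o in out])
```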

In [51]:
widget = notebook.LitWidget(
    models = {'classification': FraudModel(endpoint)},
    datasets = {'test': FraudDataset(test[0:4000]), 'val': FraudDataset(val[0:4000])},
    generators = {'Minimal Targeted Counterfactuals': minimal_targeted_counterfactuals.TabularMTC()},
    height = 800
)
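The `Minimal Targeted Counterfactuals` generator looks, as its name suggests, for small edits to an example that change the model's prediction. The sketch below is a simplified, hypothetical illustration of that idea and not LIT's `TabularMTC` algorithm. It changes one feature at a time, tries values seen in other rows starting with the one closest to the original, and records the first value that flips a toy classifier (`predict_label`, also hypothetical).

```python
def predict_label(example):
    """Hypothetical toy classifier standing in for the deployed model."""
    return "Fraud" if example["V1"] < -3.0 and example["Amount"] < 10 else "Not Fraud"

def single_feature_counterfactuals(example, reference_rows, features, classify):
    """For each feature, find the closest observed value that flips the label."""
    original = classify(example)
    flips = []
    for feat in features:
        candidates = {row[feat] for row in reference_rows} - {example[feat]}
        # try the smallest changes first
        for value in sorted(candidates, key=lambda v: abs(v - example[feat])):
            candidate = dict(example, **{feat: value})  # change exactly one feature
            if classify(candidate) != original:
                flips.append((feat, value))
                break
    return flips

example = {"V1": -7.0, "Amount": 0.0}  # hypothetical row, predicted "Fraud"
reference_rows = [                     # hypothetical rows supplying candidate values
    {"V1": -4.0, "Amount": 5.0},
    {"V1": -2.0, "Amount": 88.0},
    {"V1": 1.0, "Amount": 0.0},
]
print(single_feature_counterfactuals(example, reference_rows, ["V1", "Amount"], predict_label))
# [('V1', -2.0), ('Amount', 88.0)]
```

Drawing candidates from observed values keeps each counterfactual within the range of the data. With `FraudModel`, every candidate checked this way would cost a prediction request to the endpoint.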

### LIT Widget

In [None]:
widget.render()

In [53]:
widget.stop()

---
## Example Screenshot

<img src="../architectures/notebooks/05/lit.png">