# 05tools_2b: Understanding Model Performance and Fairness with the Language Interpretability Tool (LIT)

The [Language Interpretability Tool (LIT)](https://pair-code.github.io/lit/) helps you understand a model's behavior across a wide range of inputs. In this notebook, the 05 series models are evaluated with LIT. While the name might suggest that LIT is specific to language models, it actually works well with other kinds of models too, including the tabular fraud model used here!

This notebook shows how to connect LIT to the model in multiple ways and how to use it to evaluate the model.


---
## Installing LIT
Using LIT in the notebook environment for the first time requires installing the Python package `lit-nlp`:
- From a Terminal: `pip install lit-nlp -U -q`
- From a Cell of User-Managed Workbench: `!pip install lit-nlp -U -q`
- From a Cell of a Managed Workbench: `!pip install --user lit-nlp -U -q`
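
To check from a cell whether `lit-nlp` is already installed before running `pip`, the standard library's `importlib.metadata` can report the installed version. This is a minimal sketch; the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import version, PackageNotFoundError

def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

# None means one of the pip install commands above is still needed
print(installed_version('lit-nlp'))
```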

---
## Environment Setup

inputs:

In [1]:
project = !gcloud config get-value project
PROJECT_ID = project[0]
PROJECT_ID

'statmike-mlops-349915'

In [2]:
REGION = 'us-central1'
DATANAME = 'fraud'
NOTEBOOK = '05tools_2b'
SERIES = '05'

# Model Training
VAR_TARGET = 'Class'
VAR_OMIT = 'transaction_id' # add more variables to the string with space delimiters
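
`VAR_OMIT` is a single space-delimited string. The data cells further down split it into a list and add the `splits` column before dropping those columns from the query results. The same logic applied to a plain list of column names, as a sketch where the second omitted column (`customer_id`) and the column list are made up for illustration:

```python
# hypothetical example: two columns to omit, separated by a space
var_omit = 'transaction_id customer_id'
columns = ['Time', 'V1', 'Amount', 'transaction_id', 'customer_id', 'splits', 'Class']

# str.split() with no argument splits on any run of whitespace
omit = var_omit.split() + ['splits']
kept = [c for c in columns if c not in omit]
print(kept)  # ['Time', 'V1', 'Amount', 'Class']
```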

packages:

In [3]:
from google.cloud import aiplatform
from google.cloud import bigquery

from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types
from lit_nlp.api import model as lit_model
from lit_nlp import notebook

import numpy as np

clients:

In [4]:
aiplatform.init(project=PROJECT_ID, location=REGION)
bq = bigquery.Client(project=PROJECT_ID)

parameters:

In [5]:
BUCKET = PROJECT_ID
DIR = f"temp/{NOTEBOOK}"

environment:

In [6]:
!rm -rf {DIR}
!mkdir -p {DIR}
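
The `!rm -rf` / `!mkdir -p` pair above clears and recreates the local working directory through the shell. For reference, a pure-Python equivalent using `shutil.rmtree` and `pathlib.Path.mkdir` is sketched below; `reset_dir` is a hypothetical helper name.

```python
import shutil
from pathlib import Path

def reset_dir(path):
    """Delete `path` if it exists, then recreate it (like `rm -rf` followed by `mkdir -p`)."""
    shutil.rmtree(path, ignore_errors=True)   # no error if the directory is missing
    Path(path).mkdir(parents=True, exist_ok=True)
    return Path(path)
```

In this notebook it would be called as `reset_dir(DIR)`.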

---
## Get Vertex AI Endpoint And Deployed Model

In [7]:
endpoints = aiplatform.Endpoint.list(filter = f"display_name={SERIES}_{DATANAME}")
endpoint = endpoints[0]

In [8]:
endpoint.display_name

'05_fraud'

In [9]:
deployed_model = endpoint.list_models()[0]  # first model deployed to the endpoint
model = aiplatform.Model(
    model_name = f'{deployed_model.model}@{deployed_model.model_version_id}'
)

In [10]:
model.display_name

'05i_fraud'

In [11]:
model.versioned_resource_name

'projects/1026793852137/locations/us-central1/models/model_05i_fraud@1'
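
The model name used above appends `@` and the version ID to the model resource name, and `versioned_resource_name` returns a string in the same form. Splitting it back apart needs only string methods; a sketch with a hypothetical helper `split_versioned_name`:

```python
def split_versioned_name(name):
    """Split 'projects/.../models/ID@VERSION' into (resource name, version ID)."""
    resource, sep, version_id = name.rpartition('@')
    if not sep:  # no '@' present: the name carries no explicit version
        return name, None
    return resource, version_id

resource, version_id = split_versioned_name(
    'projects/1026793852137/locations/us-central1/models/model_05i_fraud@1'
)
print(resource)    # projects/1026793852137/locations/us-central1/models/model_05i_fraud
print(version_id)  # 1
```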

In [12]:
model.uri

'gs://statmike-mlops-349915/fraud/models/05i/20220728003419/18/model'

## Get Data for Model Exploration
Retrieve the test and validation data for this series:

In [13]:
test = bq.query(query = f"SELECT * EXCEPT({VAR_TARGET}), {VAR_TARGET} FROM {DATANAME}.{DATANAME}_prepped WHERE splits='TEST' ORDER BY {VAR_TARGET} DESC").to_dataframe()
test = test[test.columns[~test.columns.isin(VAR_OMIT.split()+['splits'])]]
test.head()

|  | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 148074 | -2.219219 | 0.727831 | -5.45823 | 5.92485 | 3.932464 | -3.085984 | -1.67787 | 0.865075 | -3.17726 | ... | 0.417472 | -0.817343 | -0.028752 | 0.025723 | -0.825835 | -0.013089 | 0.413291 | -0.131387 | 0.0 | 1 |
| 1 | 129668 | 0.753356 | 2.284988 | -5.164492 | 3.831112 | -0.073622 | -1.316596 | -1.855495 | 0.831079 | -1.567514 | ... | 0.382007 | 0.033958 | 0.187697 | 0.358433 | -0.488934 | -0.258802 | 0.296145 | -0.047174 | 2.0 | 1 |
| 2 | 56887 | -0.075483 | 1.812355 | -2.566981 | 4.127549 | -1.628532 | -0.805895 | -3.390135 | 1.019353 | -2.451251 | ... | 0.794372 | 0.270471 | -0.143624 | 0.013566 | 0.634203 | 0.213693 | 0.773625 | 0.387434 | 5.0 | 1 |
| 3 | 146998 | -2.06424 | 2.629739 | -0.748406 | 0.694992 | 0.418178 | 1.39252 | -1.697801 | -6.333065 | 1.724184 | ... | 6.215514 | -1.276909 | 0.459861 | -1.051685 | 0.209178 | -0.319859 | 0.015434 | -0.050117 | 8.0 | 1 |
| 4 | 78725 | -4.312479 | 1.886476 | -2.338634 | -0.475243 | -1.185444 | -2.112079 | -2.122793 | 0.272565 | 0.290273 | ... | 0.550541 | -0.06787 | -1.114692 | 0.269069 | -0.020572 | -0.963489 | -0.918888 | 0.001454 | 60.0 | 1 |


In [14]:
test.shape

(28522, 31)
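
The query uses BigQuery's `SELECT * EXCEPT(...)` to move the target column to the last position, and `ORDER BY ... DESC` so that fraud rows (`Class = 1`) come first, which is why `head()` above shows only `Class` 1. A sketch of a hypothetical helper `split_query` that builds the same string for any value of the `splits` column:

```python
def split_query(split, dataname='fraud', target='Class'):
    """Build the query used above for one value of the `splits` column."""
    return (
        f"SELECT * EXCEPT({target}), {target} "
        f"FROM {dataname}.{dataname}_prepped "
        f"WHERE splits='{split}' ORDER BY {target} DESC"
    )

print(split_query('TEST'))
```

With it, the validation cell below could be written as `bq.query(query = split_query('VALIDATE')).to_dataframe()`. Interpolating `split` directly into SQL is only reasonable for fixed values like these, not for user-supplied input.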

In [15]:
val = bq.query(query = f"SELECT * EXCEPT({VAR_TARGET}), {VAR_TARGET} FROM {DATANAME}.{DATANAME}_prepped WHERE splits='VALIDATE' ORDER BY {VAR_TARGET} DESC").to_dataframe()
val = val[val.columns[~val.columns.isin(VAR_OMIT.split()+['splits'])]]
val.head()

|  | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 85285 | -7.030308 | 3.421991 | -9.525072 | 5.270891 | -4.02463 | -2.865682 | -6.989195 | 3.791551 | -4.62273 | ... | 1.103398 | -0.541855 | 0.036943 | -0.355519 | 0.353634 | 1.042458 | 1.359516 | -0.272188 | 0.0 | 1 |
| 1 | 142961 | 0.457845 | 1.373769 | -0.488926 | 2.805351 | 1.777386 | 0.100492 | 1.295016 | -0.135857 | -1.695822 | ... | 0.105593 | 0.371014 | 0.051105 | 0.401524 | -0.724766 | -0.202881 | 0.092124 | 0.094956 | 0.0 | 1 |
| 2 | 149640 | 0.754316 | 2.379822 | -5.137274 | 3.818392 | 0.043203 | -1.285451 | -1.766684 | 0.756711 | -1.765722 | ... | 0.397058 | 0.141165 | 0.171985 | 0.394274 | -0.444642 | -0.263189 | 0.304703 | -0.044362 | 2.0 | 1 |
| 3 | 92092 | -1.108478 | 3.448953 | -6.216972 | 3.021052 | -0.529901 | -2.551375 | -2.001743 | 1.092432 | -0.836098 | ... | 0.825951 | 1.14417 | 0.208559 | -0.295497 | -0.690232 | -0.364749 | 0.229327 | 0.20883 | 18.0 | 1 |
| 4 | 94362 | -26.457745 | 16.497472 | -30.177317 | 8.904157 | -17.8926 | -1.227904 | -31.197329 | -11.43892 | -9.462573 | ... | -8.755698 | 3.460893 | 0.896538 | 0.254836 | -0.738097 | -0.966564 | -7.263482 | -1.324884 | 1.0 | 1 |


## Setting Up LIT
At a minimum, LIT requires a Dataset and a Model specification that follow the required [Type System](https://github.com/PAIR-code/lit/wiki/api.md#type-system):
- Dataset
    - A custom class that uses the Type System to describe the fields and format the examples as flat dictionaries
- Model
    - A custom class made up of Python functions that return the input spec, the output spec, and predictions following the Type System
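
Because the examples are flat dictionaries described by a spec, a simple sanity check is to confirm that every example carries exactly the spec's field names. This is a plain-Python sketch that does not use LIT itself, assuming each example should have exactly the spec's fields; `mismatched_examples` and the toy data are hypothetical.

```python
def mismatched_examples(examples, spec):
    """Return the indices of examples whose keys are not exactly the spec's field names."""
    fields = set(spec)
    return [i for i, example in enumerate(examples) if set(example) != fields]

# toy data: strings stand in for LIT types, and the second example is missing 'Amount'
toy_spec = {'Time': 'Scalar', 'Amount': 'Scalar', 'Class': 'CategoryLabel'}
toy_examples = [
    {'Time': 148074, 'Amount': 0.0, 'Class': 'Fraud'},
    {'Time': 129668, 'Class': 'Fraud'},
]
print(mismatched_examples(toy_examples, toy_spec))  # [1]
```

With the dataset class defined below, `mismatched_examples(test_ds.examples, test_ds.spec())` returning an empty list would mean every example matches its spec.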

Common Inputs:

In [34]:
VOCABS = {f'{VAR_TARGET}': ['Not Fraud', 'Fraud']}
VAR_SPECS = test.dtypes.apply(lambda x: x.name).to_dict()

### LIT Dataset
The `FraudDataset` class below automates the specification by relying on the fact that every field except the target (stored in `VAR_TARGET` above) is numeric. The class also uses the common inputs defined in the previous cell.
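
The two steps the class performs, mapping the integer target to a label from `VOCABS` and choosing a type per field by name, can be mirrored in plain Python. In this sketch, strings stand in for the `lit_types` objects and the column subset is illustrative; note that the dtype values are never inspected, only whether the field is the target.

```python
vocab = ['Not Fraud', 'Fraud']  # list position = integer class value
target = 'Class'
# dtype names as in VAR_SPECS (illustrative subset of the real columns)
toy_var_specs = {'Time': 'int64', 'V1': 'float64', 'Amount': 'float64', 'Class': 'int64'}

# strings standing in for lit_types: the target is categorical, every other field numeric
spec = {name: ('CategoryLabel' if name == target else 'Scalar') for name in toy_var_specs}

record = {'Time': 148074, 'V1': -2.2192, 'Amount': 0.0, 'Class': 1}
record[target] = vocab[record[target]]  # 1 -> 'Fraud', 0 -> 'Not Fraud'
print(spec)
print(record[target])  # Fraud
```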

Define Class for Dataset:

In [35]:
class FraudDataset(lit_dataset.Dataset):

    def __init__(self, ds):
        # convert the DataFrame to a list of flat dictionaries, one per row
        records = ds.to_dict(orient='records')
        self._examples = []
        for rec in records:
            # replace the integer target (0/1) with its label from VOCABS
            rec[VAR_TARGET] = VOCABS[VAR_TARGET][rec[VAR_TARGET]]
            self._examples.append(rec)

    def spec(self):
        # the target is a categorical label; every other field is numeric
        specs = VAR_SPECS.copy()
        for s in specs:
            if s == VAR_TARGET: specs[s] = lit_types.CategoryLabel(vocab = VOCABS[VAR_TARGET])
            else: specs[s] = lit_types.Scalar()
        return specs

Test Class:

In [36]:
test_ds = FraudDataset(test)
val_ds = FraudDataset(val)

In [37]:
len(val_ds.examples), len(test_ds.examples)

(28461, 28522)

In [38]:
test_ds.examples[0], test_ds.spec()

({'Time': 148074,
  'V1': -2.21921860215056,
  'V2': 0.7278314111063922,
  'V3': -5.45822994652182,
  'V4': 5.92484984705884,
  'V5': 3.9324638237634395,
  'V6': -3.0859842366267003,
  'V7': -1.67786998770016,
  'V8': 0.865074610405235,
  'V9': -3.1772602889458597,
  'V10': -3.4192073840566404,
  'V11': 3.6931739422441203,
  'V12': -3.97843975507806,
  'V13': -1.71859087457346,
  'V14': -8.636297393652589,
  'V15': -0.24296482145526502,
  'V16': 1.17488417316765,
  'V17': 2.13460635695284,
  'V18': 2.59436483300614,
  'V19': -1.25758897993879,
  'V20': 0.9647718037347099,
  'V21': 0.41747174595057,
  'V22': -0.8173433840569749,
  'V23': -0.0287524020141088,
  'V24': 0.0257225108657227,
  'V25': -0.8258353432218559,
  'V26': -0.0130890304987416,
  'V27': 0.413291188715315,
  'V28': -0.131387346404896,
  'Amount': 0.0,
  'Class': 'Fraud'},
 {'Time': Scalar(required=True, annotated=False, min_val=0, max_val=1, default=0, step=0.01),
  'V1': Scalar(required=True, annotated=False, min_val=0

In [39]:
val_ds.examples[0], val_ds.spec()

({'Time': 85285,
  'V1': -7.03030814445441,
  'V2': 3.4219909046755297,
  'V3': -9.52507177254752,
  'V4': 5.27089100906596,
  'V5': -4.02463027558805,
  'V6': -2.86568161775739,
  'V7': -6.989194734394459,
  'V8': 3.7915509375591294,
  'V9': -4.62273033596451,
  'V10': -8.40966487562735,
  'V11': 6.30904400603177,
  'V12': -8.57676143258937,
  'V13': 0.24674671692986203,
  'V14': -11.534046018150802,
  'V15': -0.36426513875870004,
  'V16': -5.45249465771382,
  'V17': -11.8875700201872,
  'V18': -3.5635848100701097,
  'V19': 0.8760187681566278,
  'V20': 0.545698040621445,
  'V21': 1.10339774484256,
  'V22': -0.541854751589521,
  'V23': 0.0369432219896495,
  'V24': -0.355519004066217,
  'V25': 0.35363438209700004,
  'V26': 1.04245799282131,
  'V27': 1.35951563156376,
  'V28': -0.272188101257294,
  'Amount': 0.0,
  'Class': 'Fraud'},
 {'Time': Scalar(required=True, annotated=False, min_val=0, max_val=1, default=0, step=0.01),
  'V1': Scalar(required=True, annotated=False, min_val=0, max_

### LIT Model
The `FraudModel` class below uses the Vertex AI endpoint as a prediction service through the `endpoint.predict` method.

In [108]:
class FraudModel(lit_model.Model):

    def __init__(self, endpoint):
        # the Vertex AI Endpoint object that serves predictions
        self.model = endpoint

    def max_minibatch_size(self):
        # maximum number of examples LIT passes to predict_minibatch in one call
        return 2000

    def predict_minibatch(self, inputs):
        # drop the target label and convert each example to a protobuf Value for the request
        instances = [json_format.ParseDict({key:value for key, value in example.items() if key != VAR_TARGET}, Value()) for example in inputs]
        # use the endpoint stored on this instance rather than the global `endpoint` variable
        predictions = self.model.predict(instances = instances).predictions
        # return one output dictionary per input, keyed to match output_spec
        return [{'predictions': pred} for pred in predictions]

    def input_spec(self):
        specs = VAR_SPECS.copy()
        specs.pop(VAR_TARGET, None) # remove the target being predicted
        for s in specs:
            specs[s] = lit_types.Scalar()
        return specs

    def output_spec(self):
        # class probabilities, linked to the VAR_TARGET label field of the dataset
        return {
            'predictions': lit_types.MulticlassPreds(
                parent = VAR_TARGET, vocab = VOCABS[VAR_TARGET]
            )
        }
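
LIT calls `predict_minibatch` on batches no larger than `max_minibatch_size()`, so each batch becomes one `endpoint.predict` request. The sketch below mimics that batching in plain Python to show how many requests the test split produces at a size of 2000; `minibatches` is a hypothetical helper, not LIT's internal function.

```python
def minibatches(examples, size):
    """Yield consecutive slices of `examples` holding at most `size` items each."""
    for start in range(0, len(examples), size):
        yield examples[start:start + size]

# 28,522 test rows, as reported by test.shape above
sizes = [len(batch) for batch in minibatches(list(range(28522)), 2000)]
print(len(sizes), sizes[-1])  # 15 522
```

A smaller batch size means more, smaller requests, which can help if individual prediction requests to the endpoint turn out to be too large.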

In [109]:
widget = notebook.LitWidget(
    models = {'classification': FraudModel(endpoint)},
    datasets = {'test': FraudDataset(test), 'val': FraudDataset(val)},
    height = 800
)

In [None]:
widget.render()

In [None]:
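# hypothetical variant of FraudModel that sends smaller batches to the endpoint
class FraudModelSmallBatch(FraudModel):
    def max_minibatch_size(self):
        return 1000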

Try 1: Dictionaries

In [101]:
newobs_dicts = test.to_dict(orient='records')
newobs_dicts[0]

{'Time': 148074,
 'V1': -2.21921860215056,
 'V2': 0.7278314111063922,
 'V3': -5.45822994652182,
 'V4': 5.92484984705884,
 'V5': 3.9324638237634395,
 'V6': -3.0859842366267003,
 'V7': -1.67786998770016,
 'V8': 0.865074610405235,
 'V9': -3.1772602889458597,
 'V10': -3.4192073840566404,
 'V11': 3.6931739422441203,
 'V12': -3.97843975507806,
 'V13': -1.71859087457346,
 'V14': -8.636297393652589,
 'V15': -0.24296482145526502,
 'V16': 1.17488417316765,
 'V17': 2.13460635695284,
 'V18': 2.59436483300614,
 'V19': -1.25758897993879,
 'V20': 0.9647718037347099,
 'V21': 0.41747174595057,
 'V22': -0.8173433840569749,
 'V23': -0.0287524020141088,
 'V24': 0.0257225108657227,
 'V25': -0.8258353432218559,
 'V26': -0.0130890304987416,
 'V27': 0.413291188715315,
 'V28': -0.131387346404896,
 'Amount': 0.0,
 'Class': 1}

In [102]:
def remote_predictor_dicts(inputs):
    if type(inputs) is dict: inputs = [inputs]
    instances = [json_format.ParseDict({key:value for key, value in example.items() if key != VAR_TARGET}, Value()) for example in inputs]
    predictions = endpoint.predict(instances=instances).predictions
    return predictions

In [103]:
temp = remote_predictor_dicts(newobs_dicts[0:2000])

In [104]:
newobs_dicts[0]

{'Time': 148074,
 'V1': -2.21921860215056,
 'V2': 0.7278314111063922,
 'V3': -5.45822994652182,
 'V4': 5.92484984705884,
 'V5': 3.9324638237634395,
 'V6': -3.0859842366267003,
 'V7': -1.67786998770016,
 'V8': 0.865074610405235,
 'V9': -3.1772602889458597,
 'V10': -3.4192073840566404,
 'V11': 3.6931739422441203,
 'V12': -3.97843975507806,
 'V13': -1.71859087457346,
 'V14': -8.636297393652589,
 'V15': -0.24296482145526502,
 'V16': 1.17488417316765,
 'V17': 2.13460635695284,
 'V18': 2.59436483300614,
 'V19': -1.25758897993879,
 'V20': 0.9647718037347099,
 'V21': 0.41747174595057,
 'V22': -0.8173433840569749,
 'V23': -0.0287524020141088,
 'V24': 0.0257225108657227,
 'V25': -0.8258353432218559,
 'V26': -0.0130890304987416,
 'V27': 0.413291188715315,
 'V28': -0.131387346404896,
 'Amount': 0.0,
 'Class': 1}

In [105]:
len(temp)

2000

In [106]:
temp[0]

[0.00166472851, 0.998335302]
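
Each prediction is a list of class probabilities. Assuming the probabilities are in the same order as `VOCABS[VAR_TARGET]` (the same assumption the `MulticlassPreds` output spec makes), the predicted label is the entry with the largest probability. A plain-Python sketch using the `temp[0]` values above:

```python
vocab = ['Not Fraud', 'Fraud']  # assumed to match the order of the probabilities
probs = [0.00166472851, 0.998335302]  # temp[0] from the cell above

# index of the largest probability selects the label
label = vocab[max(range(len(probs)), key=probs.__getitem__)]
print(label)  # Fraud
```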

In [50]:
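def remote_predictor_dicts(obs):
    if type(obs) is dict: obs = [obs]
    predictions = []
    for ob in obs:
        # note: pop modifies the caller's dictionary in place, so VAR_TARGET is removed from newobs_dicts
        ob.pop(VAR_TARGET, None)
        # one endpoint request per observation
        instances = [json_format.ParseDict(ob, Value())]
        predictions.append(endpoint.predict(instances=instances).predictions[0])
    return predictions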

In [51]:
remote_predictor_dicts(newobs_dicts[0:2])

[[0.00166473, 0.998335302], [0.0168009363, 0.98319906]]

In [52]:
newobs_dicts[0]

{'Time': 148074,
 'V1': -2.21921860215056,
 'V2': 0.7278314111063922,
 'V3': -5.45822994652182,
 'V4': 5.92484984705884,
 'V5': 3.9324638237634395,
 'V6': -3.0859842366267003,
 'V7': -1.67786998770016,
 'V8': 0.865074610405235,
 'V9': -3.1772602889458597,
 'V10': -3.4192073840566404,
 'V11': 3.6931739422441203,
 'V12': -3.97843975507806,
 'V13': -1.71859087457346,
 'V14': -8.636297393652589,
 'V15': -0.24296482145526502,
 'V16': 1.17488417316765,
 'V17': 2.13460635695284,
 'V18': 2.59436483300614,
 'V19': -1.25758897993879,
 'V20': 0.9647718037347099,
 'V21': 0.41747174595057,
 'V22': -0.8173433840569749,
 'V23': -0.0287524020141088,
 'V24': 0.0257225108657227,
 'V25': -0.8258353432218559,
 'V26': -0.0130890304987416,
 'V27': 0.413291188715315,
 'V28': -0.131387346404896,
 'Amount': 0.0}
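
As the output shows, `ob.pop(VAR_TARGET, None)` inside this version of `remote_predictor_dicts` removed `Class` from the caller's `newobs_dicts[0]`. Building a filtered copy, as the batch version does with a dictionary comprehension, avoids that side effect. A sketch with a hypothetical helper `without_key`:

```python
def without_key(record, key):
    """Return a shallow copy of `record` without `key`, leaving `record` unchanged."""
    return {k: v for k, v in record.items() if k != key}

ob = {'Time': 148074, 'Amount': 0.0, 'Class': 1}
features = without_key(ob, 'Class')
print(features)  # {'Time': 148074, 'Amount': 0.0}
print(ob)        # {'Time': 148074, 'Amount': 0.0, 'Class': 1}
```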

In [45]:
test_ds.examples[0]

{'Time': 148074,
 'V1': -2.21921860215056,
 'V2': 0.7278314111063922,
 'V3': -5.45822994652182,
 'V4': 5.92484984705884,
 'V5': 3.9324638237634395,
 'V6': -3.0859842366267003,
 'V7': -1.67786998770016,
 'V8': 0.865074610405235,
 'V9': -3.1772602889458597,
 'V10': -3.4192073840566404,
 'V11': 3.6931739422441203,
 'V12': -3.97843975507806,
 'V13': -1.71859087457346,
 'V14': -8.636297393652589,
 'V15': -0.24296482145526502,
 'V16': 1.17488417316765,
 'V17': 2.13460635695284,
 'V18': 2.59436483300614,
 'V19': -1.25758897993879,
 'V20': 0.9647718037347099,
 'V21': 0.41747174595057,
 'V22': -0.8173433840569749,
 'V23': -0.0287524020141088,
 'V24': 0.0257225108657227,
 'V25': -0.8258353432218559,
 'V26': -0.0130890304987416,
 'V27': 0.413291188715315,
 'V28': -0.131387346404896,
 'Amount': 0.0,
 'Class': 'Fraud'}

In [48]:
VAR_SPECS

{'Time': 'int64',
 'V1': 'float64',
 'V2': 'float64',
 'V3': 'float64',
 'V4': 'float64',
 'V5': 'float64',
 'V6': 'float64',
 'V7': 'float64',
 'V8': 'float64',
 'V9': 'float64',
 'V10': 'float64',
 'V11': 'float64',
 'V12': 'float64',
 'V13': 'float64',
 'V14': 'float64',
 'V15': 'float64',
 'V16': 'float64',
 'V17': 'float64',
 'V18': 'float64',
 'V19': 'float64',
 'V20': 'float64',
 'V21': 'float64',
 'V22': 'float64',
 'V23': 'float64',
 'V24': 'float64',
 'V25': 'float64',
 'V26': 'float64',
 'V27': 'float64',
 'V28': 'float64',
 'Amount': 'float64',
 'Class': 'int64'}

In [47]:
specs = VAR_SPECS.copy()
specs.pop(VAR_TARGET, None)

'int64'
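
The cell returns `'int64'` because `dict.pop` returns the value it removes, and the default `None` means a missing key returns `None` instead of raising `KeyError`. Working on `copy()` is what lets `input_spec` drop the target while `VAR_SPECS` keeps it for the dataset's `spec`. A small illustration with made-up names:

```python
example_var_specs = {'Time': 'int64', 'Amount': 'float64', 'Class': 'int64'}
input_specs = example_var_specs.copy()    # shallow copy; the original dict is not modified below
removed = input_specs.pop('Class', None)  # pop returns the removed value
print(removed)                      # int64
print(input_specs)                  # {'Time': 'int64', 'Amount': 'float64'}
print('Class' in example_var_specs)  # True
```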