In [1]:
WRITE_TOKEN = "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"
READ_TOKEN = "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"
PATH = 'path/to/model'

# TruLens Observability

In [37]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Wrap all of the necessary components.
class SimpleEnglishClassifier:
    model_name = PATH.split('/')[-1]

    device = 'cuda:0'

    tokenizer = DistilBertTokenizer.from_pretrained(PATH, use_safetensors=True )
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    model = DistilBertForSequenceClassification.from_pretrained(PATH, use_safetensors=True).to(device)

    labels = ["simple", "complex"]

    SIMPLE = labels.index('simple')
    COMPLEX = labels.index('complex')

task = SimpleEnglishClassifier()

Model Wrapper

In [41]:
from IPython.display import display
import torch

In [42]:
# %pip install trulens

In [43]:
from trulens.nn.models import get_model_wrapper
from trulens.nn.quantities import ClassQoI
from trulens.nn.attribution import IntegratedGradients
from trulens.nn.attribution import Cut, OutputCut
from trulens.utils.typing import ModelInputs

task.wrapper = get_model_wrapper(task.model, device=task.device)

INFO:trulens:lib level=1
INFO:trulens:root level=30
INFO:trulens:Detected pytorch backend for <class 'transformers.models.distilbert.modeling_distilbert.DistilBertForSequenceClassification'>.
INFO:trulens:Changing backend from None to Backend.PYTORCH.
INFO:trulens:If this seems incorrect, you can force the correct backend by passing the `backend` parameter directly into your get_model_wrapper call.


Attributions

In [44]:
task.wrapper.print_layer_names()

'distilbert_embeddings_word_embeddings':	Embedding(30522, 768, padding_idx=0)
'distilbert_embeddings_position_embeddings':	Embedding(512, 768)
'distilbert_embeddings_LayerNorm':	LayerNorm((768,), eps=1e-12, elementwise_affine=True)
'distilbert_embeddings_dropout':	Dropout(p=0.1, inplace=False)
'distilbert_transformer_layer_0_attention_dropout':	Dropout(p=0.1, inplace=False)
'distilbert_transformer_layer_0_attention_q_lin':	Linear(in_features=768, out_features=768, bias=True)
'distilbert_transformer_layer_0_attention_k_lin':	Linear(in_features=768, out_features=768, bias=True)
'distilbert_transformer_layer_0_attention_v_lin':	Linear(in_features=768, out_features=768, bias=True)
'distilbert_transformer_layer_0_attention_out_lin':	Linear(in_features=768, out_features=768, bias=True)
'distilbert_transformer_layer_0_sa_layer_norm':	LayerNorm((768,), eps=1e-12, elementwise_affine=True)
'distilbert_transformer_layer_0_ffn_dropout':	Dropout(p=0.1, inplace=False)
'distilbert_transformer_layer_0

Parameters

In [74]:
infl_max = IntegratedGradients(
    model = task.wrapper,
    doi_cut=Cut('distilbert_embeddings_word_embeddings'),
    qoi_cut=OutputCut(accessor=lambda o: o['logits'])
)

# Alternatively we can look at a particular class:
infl_complex = IntegratedGradients(
    model = task.wrapper,
    doi_cut=Cut('distilbert_embeddings_word_embeddings'),
    qoi=ClassQoI(task.COMPLEX),
    qoi_cut=OutputCut(accessor=lambda o: o['logits'])
)

infl_simple = IntegratedGradients(
    model = task.wrapper,
    doi_cut=Cut('distilbert_embeddings_word_embeddings'),
    qoi=ClassQoI(task.SIMPLE),
    qoi_cut=OutputCut(accessor=lambda o: o['logits'])
)

In [66]:
# attributions uses the same call as model evaluation
attrs = infl_max.attributions(**inputs)

for token_ids, token_attr in zip(inputs['input_ids'], attrs):
    for token_id, token_attr in zip(token_ids, token_attr):
        # Not that each `word_attr` has a magnitude for each of the embedding
        # dimensions, of which there are many. We aggregate them for easier
        # interpretation and display.
        attr = token_attr.sum()

        word = task.tokenizer.decode(token_id)

        print(f"{word}({attr:0.3f})", end=' ')

    print()

[ C L S ](0.001) h o w e v e r(-0.213) ,(-0.161) r i n g(-0.120) # # e(-0.043) n o t e s(-0.195) t h a t(-0.079) w h i l e(-0.115) t h i s(-0.076) e t y m o l o g y(0.075) i s(-0.022) s e m a n t i c(-0.016) # # a l l y(0.000) p l a u s i b l e(-0.051) ,(-0.097) a(-0.056) w o r d(-0.009) m e a n i n g(-0.011) "(-0.020) b r o w n(0.009) "(-0.038) o f(0.004) t h i s(-0.071) f o r m(0.036) c a n n o t(-0.053) b e(-0.048) f o u n d(-0.079) i n(-0.015) p r o t o(0.018) -(-0.054) i n d o(0.024) -(-0.080) e u r o p e a n(-0.024) .(0.013) h e(-0.043) s u g g e s t s(-0.076) i n s t e a d(-0.085) t h a t(-0.011) "(-0.037) b e a r(0.004) "(-0.022) i s(-0.016) f r o m(-0.009) t h e(-0.038) p r o t o(-0.007) -(-0.013) i n d o(0.008) -(-0.001) e u r o p e a n(-0.018) w o r d(-0.008) *(-0.018) g(-0.020) # # ʰ(-0.008) # # w e r(-0.020) -(0.037) ~(-0.002) *(-0.009) g(-0.011) # # ʰ(-0.013) # # w e r(0.006) "(-0.006) w i l d(0.008) a n i m a l(0.002) "(-0.003) .(-0.010) [ S E P ](0.003) 
[ C L S ](-0.00

Visualization

In [75]:
from trulens.visualizations import NLP

V = NLP(
    wrapper=task.wrapper,
    labels=task.labels,
    decode=lambda x: task.tokenizer.decode(x),
    tokenize=lambda sentences: ModelInputs(kwargs=task.tokenizer(sentences, padding=True, return_tensors='pt')).map(lambda t: t.to(task.device)),
    # huggingface models can take as input the keyword args as per produced by their tokenizers.

    input_accessor=lambda x: x.kwargs['input_ids'],
    # for huggingface models, input/token ids are under input_ids key in the input dictionary

    output_accessor=lambda x: x['logits'],
    # and logits under 'logits' key in the output dictionary

    hidden_tokens=set([task.tokenizer.pad_token_id])
    # do not display these tokens
)

print("QOI = MAX PREDICTION")
display(V.token_attribution(texts, infl_max))

print("QOI = COMPLEX")
display(V.token_attribution(texts, infl_complex))

print("QOI = SIMPLE")
display(V.token_attribution(texts, infl_simple))

QOI = MAX PREDICTION


QOI = COMPLEX


QOI = SIMPLE


Baselines

In [72]:
from trulens.utils.nlp import token_baseline

inputs_baseline_ids, inputs_baseline_embeddings = token_baseline(
    keep_tokens=set([task.tokenizer.cls_token_id, task.tokenizer.sep_token_id]),
    # Which tokens to preserve.

    replacement_token=task.tokenizer.pad_token_id,
    # What to replace tokens with.

    input_accessor=lambda x: x.kwargs['input_ids'],

    ids_to_embeddings=task.model.get_input_embeddings()
    # Callable to produce embeddings from token ids.
    )


In [70]:
print("originals=", task.tokenizer.batch_decode(inputs['input_ids']))

baseline_word_ids = inputs_baseline_ids(model_inputs=ModelInputs(args=[], kwargs=inputs))
print("baselines=", task.tokenizer.batch_decode(baseline_word_ids))

originals= ['[CLS] however, ringe notes that while this etymology is semantically plausible, a word meaning " brown " of this form cannot be found in proto - indo - european. he suggests instead that " bear " is from the proto - indo - european word * gʰwer - ~ * gʰwer " wild animal ". [SEP]', '[CLS] the neolithic revolution was the first agricultural revolution. it was a gradual change from nomadic hunting and gathering communities to agriculture and settlement. [ 1 ] it changed the way of life of the communities which made the change. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']
baselines= ['[CLS] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD

In [79]:
infl_complex_baseline = IntegratedGradients(
    model = task.wrapper,
    resolution=10,
    baseline = inputs_baseline_embeddings,
    doi_cut=Cut('distilbert_embeddings_word_embeddings'),
    qoi=ClassQoI(task.COMPLEX),
    qoi_cut=OutputCut(accessor=lambda o: o['logits'])
)

infl_simple_baseline = IntegratedGradients(
    model = task.wrapper,
    resolution=10,
    baseline = inputs_baseline_embeddings,
    doi_cut=Cut('distilbert_embeddings_word_embeddings'),
    qoi=ClassQoI(task.SIMPLE),
    qoi_cut=OutputCut(accessor=lambda o: o['logits'])
)

print("QOI = COMPLEX WITH BASELINE")
display(V.token_attribution(texts, infl_complex_baseline))

print("QOI = SIMPLE WITH BASELINE")
display(V.token_attribution(texts, infl_simple_baseline))

QOI = COMPLEX WITH BASELINE


QOI = SIMPLE WITH BASELINE
