# Bert Visualize
> Visualize a masked language modeling (MLM) transformer model

In [1]:
# default_exp bert_visualize

In [2]:
# !pip install transformers

In [3]:
from transformers import AutoModelForMaskedLM,AutoTokenizer
from forgebox.imports import *
from forgebox.config import Config

In [4]:
# Load the masked-LM model and the fast tokenizer from the same pretrained tag.
model = AutoModelForMaskedLM.from_pretrained(
    "bert-base-uncased",
)
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    use_fast=True,
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMaskedLM were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




A piece of sample text

In [5]:
# Sample passage containing three [MASK] placeholders, used as model input below.
text = """I must not [MASK].
Fear is the mind-killer.
Fear is the little [MASK] that brings total obliteration.
I will face my fear.
I will permit it to pass over me and through me.
And when it has gone past I will turn the inner [MASK] to see its path.
Where the fear has gone there will be nothing.
Only I will remain."""

In [6]:
class MLMVisualizer:
    """
    Keeps a masked language model together with its tokenizer,
        further methods are attached to this class later in the notebook
    """
    def __init__(self,model,tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model

    @classmethod
    def from_pretrained(cls,
                        tag:"str, like how you use from_pretrained from transformers"
                       ):
        """
        Alternative constructor
            loads the masked language model and the fast tokenizer
            from the same pretrained `tag`
        """
        pretrained_model = AutoModelForMaskedLM.from_pretrained(tag)
        fast_tokenizer = AutoTokenizer.from_pretrained(tag,use_fast=True)
        return cls(model = pretrained_model, tokenizer = fast_tokenizer)

    def tok(self,text:str,)->[
            torch.FloatTensor,
            torch.BoolTensor,
            list,
        ]:
        """
        Tokenize `text` for the masked language model.
        Returns 3 items, in this order
            the `input_ids`, tokenized with return_tensors="pt"
            a boolean mask, True where `input_ids` equals
                the tokenizer's mask_token_id, ie. the [MASK] positions
            the `offset_mapping` reported by the tokenizer
        """
        encoding = self.tokenizer(
            text, return_tensors = "pt", return_offsets_mapping=True)
        input_ids = encoding['input_ids']
        is_masked = input_ids==self.tokenizer.mask_token_id
        return input_ids,is_masked,encoding['offset_mapping']

In [7]:
# Build the visualizer; from_pretrained loads both the model and the tokenizer for this tag.
vis = MLMVisualizer.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMaskedLM were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
def li(x,)->list:
    """
    Convert tensor / array like data to a (nested) python list

    x: a torch tensor, or any object with a `.tolist()` method
        (eg. a numpy array), a plain list / tuple is accepted as well
    return: python list (a python scalar for 0-dim input)
    """
    if torch.is_tensor(x):
        # detach first, a tensor that requires grad refuses `.numpy()`;
        # Tensor.tolist() also spares the round trip through numpy
        return x.detach().cpu().tolist()
    if isinstance(x,(list,tuple)):
        return list(x)
    return x.tolist()

In [9]:
def infer_logits(
        vis,
        y_pred,
        mask) -> Config:
    """
    Turn the model output into predictions for the masked positions

    vis: the MLMVisualizer instance (this function is attached as a method),
        its own `tokenizer` decodes the predicted ids
    y_pred: model output scores, indexed by `mask`
        assumes shape (batch, seq_len, vocab_size) -- TODO confirm
    mask: boolean tensor marking the [MASK] positions, as returned by `vis.tok`
    return: Config with
        logits: softmax over the last dim of the scores at the masked positions
            (the key name is kept for compatibility, the values are softmax output)
        pred_idx: argmax over the last dim of `logits`
        pred_tokens: tokens decoded from `pred_idx`
    """
    # torch.softmax instead of a module level `softmax` object,
    # so the function does not depend on which cell has been run before
    logits = torch.softmax(y_pred[mask], dim=-1)
    pred_idx = logits.argmax(-1)
    return Config(
        logits=logits,
        pred_idx=pred_idx,
        # decode with the visualizer's own tokenizer, not the notebook
        # global `tokenizer`, which can belong to another checkpoint
        pred_tokens = vis.tokenizer.convert_ids_to_tokens(pred_idx)
    )


MLMVisualizer.infer_logits = infer_logits

In [10]:
def predict_text(
        vis,
        text,
           )->Config:
    """
    Run the masked language model on `text`, and collect
        what the visualization needs, converted by `li`

    vis: the MLMVisualizer instance (this function is attached as a method)
    text: str, the input text, passed to `vis.tok`
    return: Config with
        text: the input text, untouched
        x: the token ids from `vis.tok`
        mask: the [MASK] position mask from `vis.tok`
        mapper: the offset mapping from `vis.tok`
        pred_idx, pred_tokens: predictions from `vis.infer_logits`
        attention: list, `li` applied to each item of the model's attention output
    """
    with torch.no_grad():
        x,mask,mapper=vis.tok(text)
        outputs = vis.model(x,output_attentions=True)
        # index by position rather than unpack: this works for a plain tuple,
        # and for the dict like ModelOutput of newer transformers versions,
        # which yields its keys (strings) when unpacked
        y_pred,attention = outputs[0],outputs[1]
        infered = vis.infer_logits(y_pred,mask)
    # the raw y_pred and the softmax output are not part of the result
    return Config(
        text = text,
        x = li(x),
        mask = li(mask),
        mapper = li(mapper),
        pred_idx=li(infered.pred_idx),
        pred_tokens =infered.pred_tokens,
        attention = list(map(li,attention)),
    )
MLMVisualizer.predict_text = predict_text

In [31]:
from jinja2 import Template
from forgebox.html import DOM
import json
from uuid import uuid4

In [35]:
def visualize(vis,
              text):
    """
    Predict on `text`, then pass the result on to `vis.visualize_result`

    vis: the MLMVisualizer instance (this function is attached as a method)
    text: str, handed to `vis.predict_text`
    """
    vis.visualize_result(vis.predict_text(text))


def visualize_result(vis, result: Config):
    """
    Render a prediction result with the html / js templates

    vis: the MLMVisualizer instance (this function is attached as a method)
    result: Config as returned by `predict_text`, has to contain `text`,
        it is not modified, so the same result can be rendered again

    Reads `mlm_visual.html` (jinja2 template) and `mlm_visual.js`,
        both paths are relative to the current working directory
    """
    with open('mlm_visual.html', 'r') as f:
        template = Template(f.read())
    with open('mlm_visual.js', 'r') as f:
        js = f.read()
    # leave `text` out of the json payload by building a new dict,
    # instead of deleting the key from the caller's object
    data = {k: v for k, v in result.items() if k != 'text'}
    # a fresh id for every render
    output_id = str(uuid4())
    page = template.render(data=json.dumps(data),
                           text=result.text,
                           output_id=output_id,
                           mlm_visual_js=js)
    DOM(page, "div",)()


MLMVisualizer.visualize = visualize
MLMVisualizer.visualize_result = visualize_result

In [13]:
# Softmax over the last dimension, kept as a module level callable.
softmax = nn.Softmax(dim=-1)

In [25]:
%%time
# Time one prediction on the sample text and keep the returned Config in `result`.
result = predict_text(vis,text)

CPU times: user 789 ms, sys: 23.7 ms, total: 813 ms
Wall time: 231 ms


In [None]:
%%time
# Predict and hand the result to the renderer in one call.
vis.visualize(text)