In [8]:
import torch
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import head_view

utils.logging.set_verbosity_error()  # Suppress standard warnings

# Head view: per-layer, per-head attention for a single sentence.
model_version = "peterchou/ernie-gram"
tokenizer = AutoTokenizer.from_pretrained(model_version)
# output_attentions=True makes the forward pass also return attention weights.
model = AutoModel.from_pretrained(model_version, output_attentions=True)

inputs = tokenizer.encode("北京四维图新科技有限公司", return_tensors='pt')
# Inference only: no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
    outputs = model(inputs)
# Read the attentions by name; positional outputs[-1] only works while the
# attention weights happen to be the last field of the model output.
attention = outputs.attentions
tokens = tokenizer.convert_ids_to_tokens(inputs[0])
head_view(attention, tokens)

<IPython.core.display.Javascript object>

In [6]:
import torch
from bertviz import model_view
from transformers import AutoTokenizer, AutoModel, utils

utils.logging.set_verbosity_error()  # Suppress standard warnings

# Model view: attention across all layers/heads for a sentence pair.
model_version = 'peterchou/ernie-gram'
sentence_a = "北京四维图新科技有限公司"
sentence_b = "四维图新科技有限公司"

model = AutoModel.from_pretrained(model_version, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(model_version)
# Calling the tokenizer directly is the supported replacement for the
# deprecated encode_plus; it returns the same BatchEncoding for a pair.
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
input_ids = inputs['input_ids']
token_type_ids = inputs['token_type_ids']  # token type id is 0 for Sentence A and 1 for Sentence B

# Inference only: no_grad avoids building an unused autograd graph.
with torch.no_grad():
    attention = model(input_ids, token_type_ids=token_type_ids).attentions

# Sentence B starts at the first index whose token type id is 1.
type_ids = token_type_ids[0].tolist()  # batch index 0
if 1 not in type_ids:
    raise ValueError(
        "tokenizer produced no token_type_id == 1; cannot locate the start of sentence_b"
    )
sentence_b_start = type_ids.index(1)

token_ids = input_ids[0].tolist()  # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(token_ids)
model_view(attention, tokens, sentence_b_start)

<IPython.core.display.Javascript object>

In [None]:
from bertviz.transformers_neuron_view import BertModel, BertTokenizer
from bertviz.neuron_view import show

# Neuron view: drills into the query/key computation of a single attention head.
# It is driven by the BertModel/BertTokenizer classes from
# bertviz.transformers_neuron_view rather than the stock transformers Auto* classes.

# --- configuration ---
model_type = 'bert'
model_version = 'bert-base-chinese'
do_lower_case = True  # casing has no effect on the all-Chinese sentences below

sentence_a = "北京四维图新科技有限公司"
sentence_b = "四维图新科技有限公司"

# --- load tokenizer and model ---
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)
model = BertModel.from_pretrained(model_version)

# --- render layer 2, head 0 with the dark theme ---
show(
    model,
    model_type,
    tokenizer,
    sentence_a,
    sentence_b,
    display_mode='dark',
    layer=2,
    head=0,
)