In [2]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
#masked_index = 8
#tokenized_text[masked_index] = '[MASK]'
#assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


INFO:pytorch_pretrained_bert.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt not found in cache, downloading to /tmp/tmp09jbunph
100%|██████████| 231508/231508 [00:00<00:00, 1153497.00B/s]
INFO:pytorch_pretrained_bert.file_utils:copying /tmp/tmp09jbunph to cache at /home/viniciusoliveirasd/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
INFO:pytorch_pretrained_bert.file_utils:creating metadata file for /home/viniciusoliveirasd/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
INFO:pytorch_pretrained_bert.file_utils:removing temp file /tmp/tmp09jbunph
INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/viniciusoliveirasd/.py

In [29]:
tokenized_text
indexed_tokens

[101,
 2040,
 2001,
 3958,
 27227,
 1029,
 102,
 3958,
 103,
 2001,
 1037,
 13997,
 11510,
 102]

In [3]:
# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# If you have a GPU, put everything on cuda
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda')

# Predict hidden states features for each layer
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)
# We have a hidden states for each of the 12 layers in model bert-base-uncased
assert len(encoded_layers) == 12

INFO:pytorch_pretrained_bert.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz not found in cache, downloading to /tmp/tmp531y2g2a
100%|██████████| 407873900/407873900 [00:09<00:00, 41377426.43B/s]
INFO:pytorch_pretrained_bert.file_utils:copying /tmp/tmp531y2g2a to cache at /home/viniciusoliveirasd/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
INFO:pytorch_pretrained_bert.file_utils:creating metadata file for /home/viniciusoliveirasd/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
INFO:pytorch_pretrained_bert.file_utils:removing temp file /tmp/tmp531y2g2a
INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /home/viniciusoliveirasd/.pytorch_

In [26]:
def get_embedding(word):
    return encoded_layers[1][0][tokenized_text.index(word)]

get_embedding('who').tolist()

[-0.37139514088630676,
 0.7424401044845581,
 -0.4032290279865265,
 0.6047725081443787,
 0.274798184633255,
 -0.7608024477958679,
 -0.5744450688362122,
 -1.0486557483673096,
 -0.17043378949165344,
 0.5917986035346985,
 -0.08404596149921417,
 0.0378810316324234,
 0.6673643589019775,
 0.4159868359565735,
 0.3696862459182739,
 -0.4248730540275574,
 1.0591756105422974,
 0.6172420978546143,
 -1.445487141609192,
 0.9654020071029663,
 -0.6608734130859375,
 0.07307248562574387,
 -0.0422746017575264,
 1.1849228143692017,
 0.5931152701377869,
 -0.8163939118385315,
 -1.4496400356292725,
 -0.7578036189079285,
 0.8574758172035217,
 0.14887860417366028,
 0.045715589076280594,
 -0.4390874207019806,
 -0.0025805532932281494,
 0.7154507637023926,
 -0.4583241641521454,
 0.3590453267097473,
 0.09835094213485718,
 0.7209110260009766,
 -0.5158544182777405,
 0.23845377564430237,
 0.11905161291360855,
 -1.0053393840789795,
 -0.6091430187225342,
 -0.6154932975769043,
 -0.7832551002502441,
 0.8324711322784424,
 

In [None]:
def get_embedding():
    