# BERT inference

In [7]:
import torch
from transformers import BertModel, BertTokenizer

# Load BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Function to get embeddings for each token in the input query
def get_token_embeddings(query):
    #inputs = tokenizer(query, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
    inputs = tokenizer(query, return_tensors='pt', truncation=True, padding=False, max_length=128)
    
    outputs = bert_model(**inputs)
    # outputs.last_hidden_state contains the embeddings for all tokens
    token_embeddings = outputs.last_hidden_state
    return token_embeddings

# Example query
query = "Find books about deep learning"
token_embeddings = get_token_embeddings(query)

print(token_embeddings.shape)  # Shape: (batch_size, sequence_length, hidden_size)

# To get the [CLS] token's embedding
cls_embedding = token_embeddings[:, 0, :]
print(cls_embedding.shape)  # Shape: (batch_size, hidden_size)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([1, 7, 768])
torch.Size([1, 768])


In [25]:
print(outputs.last_hidden_state.shape)

torch.Size([1, 7, 768])


In [8]:
cls_embedding.shape

torch.Size([1, 768])

In [10]:
token_embeddings.shape

torch.Size([1, 7, 768])

In [16]:
query = "Find books about deep learning"
token_embeddings = get_token_embeddings(query)

print(token_embeddings[0, 2, :10])

query = "Find books about deep water"
token_embeddings = get_token_embeddings(query)

print(token_embeddings[0, 2, :10])


tensor([ 0.1819,  0.7634,  0.3809, -0.1409,  1.0556, -1.1763,  0.1290,  0.3763,
         0.4035,  0.5729], grad_fn=<SliceBackward0>)
tensor([ 0.3422,  0.6524,  0.4581, -0.1221,  1.0180, -1.2297,  0.0068,  0.5418,
         0.3652,  0.4390], grad_fn=<SliceBackward0>)
