In [20]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

In [3]:
# Hugging Face model identifier; handed to both from_pretrained calls below.
MODEL_NAME = "deepset/gbert-base"

In [5]:
# Load the tokenizer and the encoder registered under MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# Put the module in evaluation mode (disables the Dropout layers). Because this
# is the cell's last expression, the returned module is also displayed.
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(31102, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [35]:
def get_vectors(text):
    """Run ``text`` through the notebook's tokenizer and model.

    Uses the module-level ``tokenizer`` and ``model`` objects.

    Args:
        text: Input passed to the tokenizer (special tokens are added).

    Returns:
        A ``(embeddings, tokens)`` pair: ``embeddings`` is the model's
        ``last_hidden_state`` for the encoded input, and ``tokens`` is the
        list of token strings for the first row of ``input_ids``.
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=True)

    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state

    token_strings = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    return hidden_states, token_strings

In [107]:
# Encode the four example sentences in order; each get_vectors call yields an
# (embeddings, tokens) pair that is unpacked into its numbered names.
(
    (embeddings_1, tokens_1),
    (embeddings_2, tokens_2),
    (embeddings_3, tokens_3),
    (embeddings_4, tokens_4),
) = map(
    get_vectors,
    (
        "Ich sitze auf der Bank.",
        "Diese Bank ist aus Holz.",
        "Ich lege mein Geld auf dieser Bank in Wertpapiere an.",
        "Ich lege mein Geld bei dieser Bank in Wertpapiere an.",
    ),
)

In [108]:
# Remove dimension 0 from every embedding tensor (squeeze only drops it when
# its size is 1), so each row can be addressed by token position.
embeddings_1, embeddings_2, embeddings_3, embeddings_4 = (
    emb.squeeze(0)
    for emb in (embeddings_1, embeddings_2, embeddings_3, embeddings_4)
)

In [109]:
# One tensor shape per line, in sentence order.
print(
    *(emb.shape for emb in (embeddings_1, embeddings_2, embeddings_3, embeddings_4)),
    sep="\n",
)

torch.Size([9, 768])
torch.Size([8, 768])
torch.Size([15, 768])
torch.Size([15, 768])


In [110]:
# One token count per line, in sentence order.
print(*map(len, (tokens_1, tokens_2, tokens_3, tokens_4)), sep="\n")

9
8
15
15


In [113]:
# For each sentence, show the full token list followed by the token found at
# the position that is later used to pick the embedding row.
print(
    *(
        shown
        for toks, pos in ((tokens_1, 6), (tokens_2, 2), (tokens_3, 8), (tokens_4, 8))
        for shown in (toks, toks[pos])
    ),
    sep="\n",
)

['[CLS]', 'Ich', 'sit', '##ze', 'auf', 'der', 'Bank', '.', '[SEP]']
Bank
['[CLS]', 'Diese', 'Bank', 'ist', 'aus', 'Holz', '.', '[SEP]']
Bank
['[CLS]', 'Ich', 'leg', '##e', 'mein', 'Geld', 'auf', 'dieser', 'Bank', 'in', 'Wertpapier', '##e', 'an', '.', '[SEP]']
Bank
['[CLS]', 'Ich', 'leg', '##e', 'mein', 'Geld', 'bei', 'dieser', 'Bank', 'in', 'Wertpapier', '##e', 'an', '.', '[SEP]']
Bank


In [116]:
# Pick the embedding row of the target token in each sentence.
# The position is looked up in the token list instead of being hard-coded, so
# it stays correct if a sentence is edited. list.index returns the first
# occurrence and raises ValueError when the token is absent, which is safer
# than silently selecting an unrelated row.
TARGET_TOKEN = "Bank"
v_1 = embeddings_1[tokens_1.index(TARGET_TOKEN)]
v_2 = embeddings_2[tokens_2.index(TARGET_TOKEN)]
v_3 = embeddings_3[tokens_3.index(TARGET_TOKEN)]
v_4 = embeddings_4[tokens_4.index(TARGET_TOKEN)]

In [117]:
# Cosine similarity for every unordered pair of the four selected vectors,
# printed one pair per line as "<i>, <j> <similarity>".
print(
    *(
        f"{i}, {j} {F.cosine_similarity(left, right, dim=0).item()}"
        for i, left, j, right in (
            (1, v_1, 2, v_2),
            (1, v_1, 3, v_3),
            (1, v_1, 4, v_4),
            (2, v_2, 3, v_3),
            (2, v_2, 4, v_4),
            (3, v_3, 4, v_4),
        )
    ),
    sep="\n",
)

1, 2 0.6724907755851746
1, 3 0.6367990970611572
1, 4 0.6208581328392029
2, 3 0.6761319041252136
2, 4 0.693832278251648
3, 4 0.8559741377830505
