In [45]:
from dataclasses import dataclass

@dataclass
class Mention:
    """Entity mention and its context"""
    mention: str
    doc_title: str
    left_context: str
    right_context: str
        
    def prepare_for_tokenizer(self):
#         return f"[CLS]{self.doc_title}[SEP]{self.left_context}[E]{self.mention}[/E]{self.right_context}[SEP]"
        return f"{self.doc_title}[SEP]{self.left_context}[E]{self.mention}[/E]{self.right_context}"

In [67]:
@dataclass
class Entity:
    """Entity and its description. For the future - maybe add synonyms and neighbours"""
#     entity: str
    description: str
#     qid: str
        
    def prepare_for_tokenizer(self):
        return f"{self.description}"

In [57]:
from transformers import BertTokenizer, BertModel

In [58]:
tok = BertTokenizer.from_pretrained("emanjavacas/GysBERT")

In [59]:
model = BertModel.from_pretrained("emanjavacas/GysBERT")

Some weights of the model checkpoint at emanjavacas/GysBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [60]:
special_tokens_dict = {'additional_special_tokens': ['[E]','[/E]']}
num_added_toks = tok.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tok))

Embedding(30002, 768)

In [61]:
ment = Mention(mention='Amsterdam', doc_title='Aan de Amsterdamse grachten',
              left_context='Er staat een huis aan de gracht in oud',
              right_context='waar ik als jochie van acht bij grootmoeder kwam')

In [None]:
ent = "stad in Noord-Holland, Nederland"

In [68]:
def tokenize_entry(entry):
    '''
    entry is either mention (with context) or entity (with description)
    '''
    input_line = mention.prepare_for_tokenizer()
    return tok(input_line,return_tensors='pt')['input_ids']

In [63]:
tokens = tokenize_mention(ment)
tokens

{'input_ids': tensor([[    2,  1490,  1448,  1945,  1803, 23511,     3,  1557,  2010,  1473,
          2026,  1490,  1448, 12050,  1464,  2602, 30000,  1945, 30001,  1562,
          1642,  1560,  4408, 20065,   905,  1455,  2408,  1534, 19927,  2601,
             3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1]])}

In [64]:
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30002, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
   

In [66]:
res = model(tokens)
res

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.6718,  0.2235, -0.2448,  ..., -0.5155,  0.3100,  0.4625],
         [ 0.0170, -1.0129, -0.6136,  ..., -0.7572,  0.8658, -0.0089],
         [-0.2262, -0.9541, -0.8106,  ..., -0.3831,  0.6160,  0.4601],
         ...,
         [ 0.4031,  0.5679, -0.7611,  ..., -0.4276, -0.0226,  1.6612],
         [ 0.5952, -0.2189, -0.3749,  ..., -0.8822,  0.8337,  0.3204],
         [ 0.4837, -0.2339, -0.6935,  ..., -0.4878,  0.6203,  0.1989]]],
       grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[-0.1874,  0.0314, -0.0723, -0.9831, -0.1143,  1.0000,  0.1045,  0.0174,
          0.1398, -0.0815, -0.2276, -0.4845, -0.8867,  0.9994, -0.0050, -0.6438,
         -0.5974, -0.4487, -0.1266,  0.2979,  0.5769,  0.0676, -0.0367,  0.8939,
          0.5296,  0.0498, -0.0557,  0.5008,  0.6975, -0.0699, -1.0000,  0.9631,
          1.0000,  0.1897, -0.2364, -0.9996, -1.0000,  0.0640, -0.1700,  0.9998,
          0.0656, -0.1328,  0.923