In [1]:
# Automatically re-import edited modules (e.g. canlpy) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os

# Pretrained KnowBert model archives hosted on the AllenNLP S3 bucket:
# WordNet, Wikipedia, and the combined Wikipedia + WordNet model.
WORDNET_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wordnet_model.tar.gz"
WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_model.tar.gz"
WORDNET_WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz"

# Local WordNet entity-linker resources. These are relative paths, resolved against
# the current working directory; every *_DIR / *_FOLDER constant ends with '/'.
KNOWLEDGE_DIR = '../canlpy/knowledge/knowbert/'
WORDNET_DIR = f"{KNOWLEDGE_DIR}wordnet/"
WORDNET_LINKER_FOLDER = f"{WORDNET_DIR}entity_linker/"
WORDNET_LINKER_EMBEDDING_FILE = f"{WORDNET_LINKER_FOLDER}wordnet_synsets_mask_null_vocab_embeddings_tucker_gensen.hdf5"
WORDNET_LINKER_ENTITY_FILE = f"{WORDNET_LINKER_FOLDER}entities.jsonl"
WORDNET_LINKER_VOCAB_FILE = f"{WORDNET_LINKER_FOLDER}wordnet_synsets_mask_null_vocab.txt"

# State dict of the pretrained WordNet KnowBert, loaded into the custom model below.
PRE_TRAINED_DIR = '../canlpy/pretrained_models/knowbert/wordnet/'
WORDNET_MODEL_STATE_DICT_FILE = f"{PRE_TRAINED_DIR}weights.th"

In [3]:
# from kb.custom_knowbert import CustomKnowBert
# from kb.soldered_kg import CustomSolderedKG, CustomEntityLinkingWithCandidateMentions
# from kb.custom_knowledge import CustomWordNetAllEmbedding
from canlpy.core.models.knowbert.knowbert import CustomKnowBert
from canlpy.core.models.knowbert.soldered_kg import CustomSolderedKG, CustomEntityLinkingWithCandidateMentions
from canlpy.core.models.knowbert.knowledge import CustomWordNetAllEmbedding


# The span-attention and span-encoder transformers use the same small BERT-style
# shape. The encoder config is built as an independent copy, so mutating one dict
# never affects the other.
span_attention_config = dict(
    hidden_size=200,
    intermediate_size=1024,
    num_attention_heads=4,
    num_hidden_layers=1,
)
span_encoder_config = dict(span_attention_config)

# Index of '@@NULL@@' in the "entity" vocabulary. It is hard-coded because no
# vocabulary object is loaded here; the original code obtained it via
# model.vocab.get_token_index('@@NULL@@', "entity").
null_entity_id = 117662
# Size of the entity embedding vectors (also the size of the null embedding below).
entity_dim = 200

# WordNet entity embedder built from the entity-linker resource files.
# Entity vectors are read from the HDF5 file under the "tucker_gensen" key
# (presumably pretrained embeddings — TODO confirm in CustomWordNetAllEmbedding).
model_entity_embedder = CustomWordNetAllEmbedding(
    embedding_file=WORDNET_LINKER_EMBEDDING_FILE,
    entity_h5_key="tucker_gensen",
    entity_file=WORDNET_LINKER_ENTITY_FILE,
    vocab_file=WORDNET_LINKER_VOCAB_FILE,
    entity_dim=entity_dim,
    pos_embedding_dim=25,
    dropout=0.1,
    include_null_embedding=False,
)

# NOTE(review): the next two names are not referenced elsewhere in the visible notebook.
entity_embeddings = model_entity_embedder.entity_embeddings
# All-zero vector of size entity_dim for the null entity (as in the original WordNet code).
null_embedding = torch.zeros(entity_dim)

# Entity linker over candidate mention spans, using the WordNet entity embedder.
# contextual_embedding_dim=768 matches the hidden size of the bert-base-uncased encoder.
custom_entity_linker = CustomEntityLinkingWithCandidateMentions(
    entity_embedding=model_entity_embedder,
    null_entity_id=null_entity_id,
    contextual_embedding_dim=768,
    max_sequence_length=512,
    span_encoder_config=span_encoder_config,
    output_feed_forward_hidden_dim=100,
    loss_type='softmax',
    margin=0.2,
    decode_threshold=0.0,
    dropout=0.1,
    initializer_range=0.02,
)

# Knowledge component that is soldered into BERT: the entity linker plus span
# attention. Inverse kg-to-bert initialization is disabled, and freeze=False
# (presumably leaving its parameters trainable — TODO confirm).
custom_wordnet_kg = CustomSolderedKG(
    entity_linker=custom_entity_linker,
    span_attention_config=span_attention_config,
    should_init_kg_to_bert_inverse=False,
    freeze=False,
)

# Knowledge components keyed by KB name (presumably must match the keys of the
# soldered_layers mapping given to CustomKnowBert).
custom_soldered_kgs = {'wordnet': custom_wordnet_kg}

In [4]:
# The checkpoint stores the span extractor's global-attention weight and bias under
# a "._module" wrapper. Map each old key to the same key with that segment removed
# (presumably the name used by the custom model — TODO confirm).
span_extractor_global_attention_old_name = "wordnet_soldered_kg.entity_linker.disambiguator.span_extractor._global_attention._module.weight"
span_extractor_global_attention_bias_old_name = "wordnet_soldered_kg.entity_linker.disambiguator.span_extractor._global_attention._module.bias"
state_dict_map = {
    old_key: old_key.replace("._module", "")
    for old_key in (span_extractor_global_attention_old_name,
                    span_extractor_global_attention_bias_old_name)
}

# Full model: bert-base-uncased with the WordNet knowledge component soldered in at
# layer index 9. Weights come from WORDNET_MODEL_STATE_DICT_FILE; checkpoint keys are
# renamed via state_dict_map, and strict_load_archive=True presumably makes loading
# fail on missing/unexpected keys (TODO confirm in CustomKnowBert).
custom_model = CustomKnowBert(
    bert_model_name="bert-base-uncased",
    soldered_kgs=custom_soldered_kgs,
    soldered_layers={"wordnet": 9},
    mode=None,
    state_dict_file=WORDNET_MODEL_STATE_DICT_FILE,
    state_dict_map=state_dict_map,
    strict_load_archive=True,
    remap_segment_embeddings=None,
)



In [8]:
# Each entry holds an "input" batch and the "expected_output" the model must reproduce.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# map_location="cpu": custom_model is never moved to a GPU in this notebook, so keep
# the recorded tensors on CPU too. This also lets a GPU-saved file load on a
# CPU-only machine.
test_set = torch.load("test_set", map_location="cpu")

def custom_equal(a, b, atol=1e-04):
    """Return True when ``a`` and ``b`` have identical shapes and are element-wise close.

    Args:
        a, b: tensors or Python numbers (e.g. a scalar loss stored as a float).
        atol: absolute tolerance forwarded to ``torch.allclose``. The default of 1e-4
            leaves headroom over the ~3.4e-5 max difference previously observed
            between the reference and re-implemented outputs.

    Returns:
        bool: False if the shapes differ; otherwise the ``torch.allclose`` result.
    """
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    # torch.allclose broadcasts its inputs, so e.g. a (3,) tensor would be "close"
    # to a (1,) tensor. Reject shape mismatches explicitly instead of hiding them.
    if a.shape != b.shape:
        return False
    # allclose needs matching dtypes and devices. Promote so that float vs double
    # (or Python float vs tensor) comparisons do not raise.
    common_dtype = torch.promote_types(a.dtype, b.dtype)
    return torch.allclose(
        a.to(common_dtype),
        b.to(device=a.device, dtype=common_dtype),
        atol=atol,
    )
custom_model.eval()
# Largest element-wise difference observed so far between the recorded and the
# custom outputs: 3.4332275390625e-05.


def _assert_outputs_match(label, expected, actual):
    """Print whether ``actual`` matches ``expected`` under ``custom_equal``, then assert it.

    Both values go through ``torch.as_tensor``, so a scalar loss (possibly stored
    as a plain Python number — TODO confirm) is compared with the same tolerance as
    the tensor outputs rather than with exact ``==``.
    """
    equal = custom_equal(torch.as_tensor(expected), torch.as_tensor(actual))
    print(f"{label} are equal: {equal}")
    assert equal, f"{label} differ from the recorded expected output"


# Pure inference: disable autograd so no graphs are built for the comparisons.
with torch.no_grad():
    for test_case in test_set:
        custom_output = custom_model(**test_case["input"])
        expected_output = test_case["expected_output"]

        _assert_outputs_match("wordnet entity_attention_probs",
                              expected_output['wordnet']['entity_attention_probs'],
                              custom_output['wordnet']['entity_attention_probs'])
        _assert_outputs_match("Output linking scores",
                              expected_output['wordnet']['linking_scores'],
                              custom_output['wordnet']['linking_scores'])
        _assert_outputs_match("Loss",
                              expected_output['loss'],
                              custom_output['loss'])
        _assert_outputs_match("Pooled outputs",
                              expected_output['pooled_output'],
                              custom_output['pooled_output'])
        _assert_outputs_match("Contextual embeddings",
                              expected_output['contextual_embeddings'],
                              custom_output['contextual_embeddings'])

wordnet entity_attention_probs are equal: True
Output linking scores are equal: True
Loss are equal : True
Pooled outputs are equal : True
wordnet entity_attention_probs are equal: True
Output linking scores are equal: True
Loss are equal : True
Pooled outputs are equal : True
wordnet entity_attention_probs are equal: True
Output linking scores are equal: True
Loss are equal : True
Pooled outputs are equal : True
wordnet entity_attention_probs are equal: True
Output linking scores are equal: True
Loss are equal : True
Pooled outputs are equal : True


In [7]:
sentences = ["Paris is located in France.", "Michael Jackson is a great music singer"]

# `batcher` takes raw untokenized sentences and yields the batches of tensors
# KnowBert needs. No cell of this notebook creates it (running this cell raised
# NameError), so fail early with an actionable message.
# TODO: construct it in this notebook — presumably a KnowBert batchifier for the
# WordNet model (cf. WORDNET_ARCHIVE); confirm the canlpy equivalent.
if "batcher" not in globals():
    raise NameError(
        "name 'batcher' is not defined: create a KnowBert batcher for the WordNet model "
        "(providing iter_batches(sentences, verbose=...)) before running this cell"
    )

# Run the model built (and checked against recorded outputs) above. `model` is never
# defined in this notebook; it only existed as leftover state from deleted cells.
custom_model.eval()
with torch.no_grad():  # inference only, no autograd graphs needed
    for batch in batcher.iter_batches(sentences, verbose=True):

        print("\nInput\n")
        # Batch contains {tokens, segment_ids, candidates}
        print(f"Batch: {batch.keys()}")
        # Token indices (used to index an embedding). Sentences have different
        # lengths, so the tensor is zero-padded to the longest one.
        # shape: (batch_size (#sentences), max_seq_len)
        print(f"Tokens shape {batch['tokens']['tokens'].shape}")
        # Segment ids: 0 for the first segment and 1 for the second (usable for NSP).
        # shape: (batch_size, max_seq_len)
        print(f"Segment ids shape: {batch['segment_ids'].shape}")

        # Candidates hold, per knowledge base (only 'wordnet' here), the entities
        # detected with that knowledge base.
        wordnet_kb = batch['candidates']['wordnet']
        print(f"Wordnet kb: {wordnet_kb.keys()}")

        # For each detected mention, the prior probability of each candidate KB entity
        # (sums to 1 along axis 2, or to 0 for padding). Axis 1 is zero-padded to the
        # max number of mentions in a sentence; axis 2 to the max number of candidates.
        # shape: (batch_size, max # detected entities, max # KB candidate entities)
        print(f"Candidate entity_priors shape: {wordnet_kb['candidate_entity_priors'].shape}")
        # Ids of the candidate KB entities, zero-padded on axes 1 and 2 where needed.
        # shape: (batch_size, max # detected entities, max # KB candidate entities)
        print(f"Candidate entities ids shape: {wordnet_kb['candidate_entities']['ids'].shape}")
        # Inclusive [start, end] token span of each mention, e.g. [1, 2] for
        # "Michael Jackson"; padded with [-1, -1] when there are no more mentions.
        # shape: (batch_size, max # detected entities, 2)
        print(f"Candidate span shape: {wordnet_kb['candidate_spans'].shape}")
        # Segment id of each mention.
        # shape: (batch_size, max # detected entities)
        print(f"Candidate segments_ids shape: {wordnet_kb['candidate_segment_ids'].shape}")

        # Equivalent to custom_model(tokens=batch['tokens'], segment_ids=batch['segment_ids'],
        # candidates=batch['candidates']).
        model_output = custom_model(**batch)

        print("\nOutput\n")
        print(f"Model output keys: {model_output.keys()}")
        print(f"Output wordnet keys: {model_output['wordnet'].keys()}")
        # Span-attention scores for the WordNet KB.
        # shape: (batch_size, ?, max_seq_len, max # detected entities); the second axis
        # is presumably the attention heads (span_attention_config uses 4) — TODO confirm.
        print(f"Output wordnet entity_attention_probs shape: {model_output['wordnet']['entity_attention_probs'].shape}")
        # Entity-linker score for each (mention, candidate KB entity) pair; -1.0000e+04
        # marks padding where there is no score.
        # shape: (batch_size, max # detected entities, max # KB candidate entities)
        print(f"Output wordnet linking_scores shape: {model_output['wordnet']['linking_scores'].shape}")

        # Scalar loss over this batch (presumably 0 when not training — TODO confirm).
        print(f"Output loss: {model_output['loss']}")

        # Final [CLS] embedding of each sentence in the batch.
        # shape: (batch_size, hidden_size)
        print(f"Pooled output shape: {model_output['pooled_output'].shape}")

        # Final embedding of every token. NOTE: zero-padded positions still get an
        # embedding; ignore them (open question: does index 0 act like [MASK]?).
        print(f"Contextual embeddings: {model_output['contextual_embeddings'].shape}")

NameError: name 'batcher' is not defined

In [None]:
# TODO: work out how to add masking — should padded (index-0) tokens get their own embedding?
# TODO: work out how to recover the actual predicted tokens from the final contextual embeddings.
# TODO: copy all AllenNLP dependencies into a single allennlp.py module containing every needed class, to remove the AllenNLP dependency.

In [None]:
# List every parameter of the custom model with its shape. `model` is never defined
# in this notebook, so iterating over it only worked with leftover kernel state.
for name, param in custom_model.named_parameters():
    print(f"{name}:{param.shape}")

pretrained_bert.bert.embeddings.word_embeddings.weight:torch.Size([30522, 768])
pretrained_bert.bert.embeddings.position_embeddings.weight:torch.Size([512, 768])
pretrained_bert.bert.embeddings.token_type_embeddings.weight:torch.Size([2, 768])
pretrained_bert.bert.embeddings.LayerNorm.weight:torch.Size([768])
pretrained_bert.bert.embeddings.LayerNorm.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.query.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.query.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.key.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.key.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.value.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.value.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.output.dense.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.la