In [59]:
import spacy
nlp = spacy.load("en_core_web_sm")

from transformers import BertForMaskedLM, BertTokenizerFast
import torch
import copy
model = BertForMaskedLM.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
tokenizer = BertTokenizerFast.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
MAX_SEQ_LEN = 85
from mlm_utils.model_utils import TOKENIZER

from prepared_for_mlm import masking_sentence_word, get_tokens_for_words, mask_content_words, get_pos_tag
import torch.nn.functional as F
from mlm_utils.preprocess_functions import get_key, compare_tensors, get_pos_tag_word


Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.2 were not used when initializing BertForMaskedLM: ['bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## test mask content wword

In [None]:

origin_input_id = torch.tensor([  101,   189, 27226,  1116,  1114,  2549, 21521,  1198,   126,  1106,
         4252,  1320,   122,   117,  1137,  6531,  1121,   172,   118,  1139,
         1665,   117,  1138,  2999,   172,   118,  1139,  1665,   182, 11782,
          1116,  1104,   123,   119,  1512,  1105,   123,   119,   125,   180,
         1830,   117,  1134, 11271,  1120,  1147,   126,  3769,   117,  1229,
         189, 27226,  1116,  1114,  2549, 21521,  1439,  4252,  1320,   122,
         1137, 27553,  1179,   122, 13000,   172,   118,  1139,  1665,   182,
        11782,   1116,   113,   123,   119,   122,   118,   123,   119,   128,
          180,  1830,  1107,   171,   102])

origin_text = tokenizer.decode(origin_input_id, skip_special_tokens=True) 
    
doc = nlp(origin_text)
words_str = [word.text for word in [token for token in doc]]

token_test = tokenizer.encode_plus(
                ' '.join(words_str),
                max_length=MAX_SEQ_LEN,
                padding='max_length', 
                truncation=True,  
                add_special_tokens = True,
                return_tensors="pt",  
                return_attention_mask = True,
                return_offsets_mapping=True  
            )

word_dict = get_tokens_for_words(
    words_str, 
    token_test['input_ids'], 
    token_test['offset_mapping'][0]
    )

masked_sentences, label_ids = mask_content_words(origin_input_id, word_dict)
for i in range(len(masked_sentences)):
    print(label_ids[i], masked_sentences[i], tokenizer.decode(masked_sentences[i], skip_special_tokens=True))


## test is POs match


In [85]:
def get_pos_tag_word(word, text):
    doc = nlp(text)
    word_split = word.split()
    print("================")
    for token in doc:
        if token.text in word_split:
            print(token.text, token.pos_)
        #     return token.pos_
    return None

def get_key(dictionary, value, count_word):
    unique_set = torch.chunk(value, count_word)[0]
    for key, val in dictionary.items():
        if compare_tensors(val, unique_set):
            return key
    return None 
def is_POS_match(input_ids, lm_label_ids, b_count_word):
    '''
    Function to check if the POS tag of the masked token in the logits is the same as the POS tag of the masked token in the original text.
    Note: This function assumes that the logits are of shape # ([85, 28996]) 
    lm_label_ids: shape (batch_size, sequence_length)
    '''
    
    origin_input_id = input_ids.clone() # Origin input id:  torch.Size([85])
    result_ids = input_ids.clone() # torch.Size([85])
    
    # Find the index of the masked token from lm_label_ids
    masked_idx = torch.where(lm_label_ids != -100)[0]
    masked_idx_input = torch.where(input_ids == tokenizer.mask_token_id)[0]
    origin_input_id[masked_idx_input] = lm_label_ids[masked_idx] 
    
    
    # make sure masked_idx_input and masked_idx are the same using assert
    assert torch.equal(masked_idx, masked_idx_input), "Masked index and label index are not the same."
    print("lm_label_ids[masked_idx]  ", lm_label_ids )
    # get pos tag of origin text
    origin_text = tokenizer.decode(origin_input_id, skip_special_tokens=True) 
    
    doc = nlp(origin_text)
    words_str = [word.text for word in [token for token in doc]]
    
    token_test = tokenizer.encode_plus(
                    ' '.join(words_str),
                    max_length=MAX_SEQ_LEN,
                    padding='max_length', 
                    truncation=True,  
                    add_special_tokens = True,
                    return_tensors="pt",  
                    return_attention_mask = True,
                    return_offsets_mapping=True  
                )
    masked_sens_2d = torch.unsqueeze(input_ids, 0)
    label_ids_2d = torch.unsqueeze(lm_label_ids, 0)
    output = model(input_ids = masked_sens_2d, attention_mask = token_test['attention_mask'], token_type_ids=token_test['token_type_ids'], labels=label_ids_2d)
    
    
    word_dict = get_tokens_for_words(
        words_str, 
        token_test['input_ids'], 
        token_test['offset_mapping'][0]
        )
      
    # Get key for the masked token
    
    masked_word = get_key(word_dict, origin_input_id[masked_idx],b_count_word )
    print("masked word: ", masked_word)
    pos_tag_origin = get_pos_tag_word(masked_word, origin_text)
    print("pos tag origin: ", masked_word, pos_tag_origin)         
    
    # # Extract the logits for the masked position
    masked_logits = output.logits[0][masked_idx]
    # print("MASKED LOGITS: ", masked_logits) # torch.Size([28996])
    pred = [torch.argmax(output.logits[0][i]).item() for i in masked_idx]
    # Print top 10 masked tokens
        # print(tokenizer.convert_ids_to_tokens(torch.topk(outputs.logits[0, idx, :], 10).indices))
        
    # Replace the index of the masked token with the list of predicted tokens
    for i in range(len(masked_idx)):
        result_ids[masked_idx[i]] = pred[i]
        
    # print("result sentence: ", tokenizer.decode(result_ids, skip_special_tokens=True))
    for i in pred:
        print("PRED WORD: ",i,  tokenizer.decode(i))
    print("PRED WORD: ",pred,  tokenizer.decode(pred))
    logits_tag = get_pos_tag_word(tokenizer.decode(pred), tokenizer.decode(result_ids, skip_special_tokens=True) )
    print("POS TAG PRED WORD: ", logits_tag)
    
    
    # Cross-entropy term
    
    cross_entropy_term = F.cross_entropy(output.logits.view(-1, tokenizer.vocab_size), lm_label_ids.view(-1))
    print("Cross entropy term shape: ", cross_entropy_term.shape)      ##Cross entropy term shape:  torch.Size([2720])
  
    # Custom matching term
    matching_term_batch = (pos_tag_origin == logits_tag)

    # Combine terms to get the loss for 1 batch
    matching_term = torch.where(matching_term_batch == torch.tensor(True), torch.tensor(1.0), torch.tensor(0.0))
    print("Matching term shape: ", matching_term)  ##Matching term shape:  torch.Size([2720])
    print("cross entropy term: ", cross_entropy_term)
    
    loss = 0.5 * cross_entropy_term + (1 - matching_term)
    
    
    return loss

In [87]:
label_id = torch.tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 5190, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 5190, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100])

input_id= torch.tensor([101, 1748, 4982, 117, 1103, 3687, 26883, 1320, 1104, 194, 3031, 1162, 2423, 170, 15792, 19033, 103, 4789, 1107, 19255, 118, 4065, 3652, 117, 8783, 1115, 194, 3031, 1162, 2399, 170, 3607, 1648, 1107, 19255, 118, 4065, 103, 4789, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

print(is_POS_match(input_id, label_id, 2))

lm_label_ids[masked_idx]   tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, 5190, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, 5190, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100])
masked word:  acid
acid NOUN
acid NOUN
pos tag origin:  acid None
PRED WORD:  26825 insulin
PRED WORD:  3850 drug
PRED WORD:  [26825, 3850] insulin drug
insulin NOUN
drug NOUN
POS TAG PRED WORD:  None
Cross entropy term shape:  torch.Size([])
Matching term shape:  tensor(1.)
cross entropy term:  tensor(4.9894, grad_fn=<NllLossBackward0>)
tensor(2.4947, grad_fn=<AddBackward0>)


In [50]:
import torch

# Sample word dictionary and tensor
word_dict = {'tumours': [torch.tensor(189)], 'with': [torch.tensor(1114)]}
given_tensor = torch.tensor([189, 27226, 1116, 189, 27226, 1116])

# Function to check if any tensor in the dictionary values is in the given tensor
def is_in_tensor(word_dict, tensor):
    for word, tensors in word_dict.items():
        for tensor_val in tensors:
            if (tensor == tensor_val).any().item():
                print(f"{word} is in the given tensor.")
                break

# Check if any word in the dictionary is in the given tensor
is_in_tensor(word_dict, given_tensor)


tumours is in the given tensor.
