In [1]:
import re
import torch
import numpy as np
from tqdm import tqdm

In [2]:
import collections

In [3]:
from transformers import *

In [4]:
from util.util import *

In [5]:
# Load the BERT encoder from the biobert_v1.1_pubmed checkpoint directory and put it in
# eval mode right away (Module.eval() returns the module itself). The bare name on
# the last line keeps the module summary as this cell's displayed output.
model = BertModel.from_pretrained('/data/medg/misc/phuongpm/biobert_v1.1_pubmed').eval()
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [6]:
# Tokenizer loaded from the same checkpoint directory as `model`.
tokenizer = BertTokenizer.from_pretrained('/data/medg/misc/phuongpm/biobert_v1.1_pubmed')

In [7]:
# Path of the JSON data file ("test1.0.json") read by JsonDataset below.
filename = "/data/medg/misc/phuongpm/" + "test1.0.json"

In [8]:
# JsonDataset is presumably provided by the `util.util` star import; its implementation is not shown here.
data = JsonDataset(filename)

In [9]:
# Materialize every sample into a list. Later cells read each sample as a dict with
# keys 'p' (passage/document), 'q' (query), 'c' (candidates) and 'a' (gold answer).
# NOTE(review): the exact effect of remove_notfound / doc_ent lives in util.util — confirm there.
sample_data = list(data.json_to_plain(remove_notfound=True, doc_ent=True))

In [10]:
# Scratch: pick the first sample for manual inspection.
input0 = sample_data[0]

In [11]:
# Scratch: unpack the first sample. `document`, `query` and `candidates` are reassigned
# by the evaluation loop further down.
document, query, candidates, answers = input0['p'], input0['q'], input0['c'], input0['a']

In [12]:
def ent_to_plain_doc(document):
    """Turn entity-marked text back into plain text.

    Every whitespace-separated token that begins with ``@entity`` has all
    ``@entity`` substrings removed and its underscores turned into spaces;
    other tokens pass through untouched. Tokens are re-joined with single
    spaces, so runs of whitespace collapse.

    :param document: text whose entities look like ``@entitysome_entity_name``
    :return: the plain-text document
    """
    def _unmark(token):
        if not token.startswith('@entity'):
            return token
        return token.replace("@entity", "").replace("_", " ")

    return ' '.join(_unmark(token) for token in document.split())

In [13]:
# Gold answer of the first sample.
answers

'marginal zone lymphoma'

In [14]:
# Passage of the first sample; entities are marked with the '@entity' prefix.
document

"more than meets the eye : the ‘ pink salmon patch ’ summary @entityocular_adnexal_lymphomas account for 1 – 2 % of all @entitynon-hodgkin_'_s_lymphomas . conjunctiva is the primary site of involvement in one - third of cases . we present a case of a 47 - year - old hispanic woman who presented with @entityleft_eye_itching and @entityirritation associated with a @entitypainless_pink_mass . @entityphysical_examination revealed the presence of a ‘ pink salmon - patch ’ involving her left medial conjunctiva . @entityorbital_ct showed a @entitysubcentimeter_left_preseptal_soft_tissue_density . @entitybiopsy revealed a @entitydense_subepithelial_lymphoid_infiltrate comprised predominantly of b cells that did not coexpress cd5 or cd43 . these @entityfindings were consistent with @entityb_-_cell_marginal_zone_lymphoma . @entityfurther_staging_assessment did not reveal @entitydisseminated_disease . she had @entitystage_1e_extranodal_marginal_zone_lymphoma as per ann arbor staging system . she 

In [15]:
def get_embedding(sent, dot=True):
    """Run one sentence through BERT and return its token ids and last-layer states.

    The text is wrapped as ``[CLS] <sent> [SEP]`` (with `` .`` appended to the
    sentence first when ``dot`` is true), tokenized, and encoded as a single
    segment under ``torch.no_grad()``.

    :param sent: sentence text
    :param dot: append `` .`` before encoding (``doc_embedding`` splits on
        ``' . '``, which strips the period from all but the last sentence)
    :return: ``(ids, states)`` with the ``[CLS]``/``[SEP]`` positions removed;
        ``ids`` is a list of token ids and ``states`` is ``outputs[0][0, 1:-1, :]``
    """
    body = sent + ' .' if dot else sent
    wordpieces = tokenizer.tokenize('[CLS] {} [SEP]'.format(body))
    token_ids = tokenizer.convert_tokens_to_ids(wordpieces)

    ids_tensor = torch.tensor([token_ids])
    # Single-segment input: every position gets token type 0.
    type_tensor = torch.tensor([[0] * len(token_ids)])
    with torch.no_grad():
        # Element 0 of the model output is the hidden state of the last layer.
        last_hidden = model(ids_tensor, token_type_ids=type_tensor)[0]

    # Drop the [CLS] (first) and [SEP] (last) positions from both results.
    return token_ids[1:-1], last_hidden[0, 1:-1, :]
    

In [16]:
def doc_embedding(doc):
    """Embed a document sentence by sentence and stitch the pieces together.

    The document is split on ``' . '``. Each piece is encoded with
    ``get_embedding``, restoring the `` .`` for every piece except the last,
    which keeps whatever ending it already had.

    :param doc: plain-text document
    :return: ``(doc_tokens, doc_emb)``: the concatenated token-id list and the
        per-sentence state tensors concatenated along dim 0
    """
    sentences = doc.split(' . ')
    last = len(sentences) - 1
    doc_tokens = []
    pieces = []
    for idx, sentence in enumerate(sentences):
        ids, states = get_embedding(sentence, idx != last)
        doc_tokens.extend(ids)
        pieces.append(states)
    return doc_tokens, torch.cat(pieces, dim=0)
        
    
        

In [17]:
# def score(cand, doc_tokens, all_probs_start, all_probs_end):
#     """:param cand: list of tokens in a candidate answer
#        :param doc_tokens: list of tokens in the document
#        :param allprobs: tensor of probabilities
#     """
#     score = 0
#     for i, t in enumerate(doc_tokens):
#         j = i+len(cand)-1
#         if j < len(doc_tokens) and t == cand[0] and doc_tokens[j] == cand[-1]:
#             score += all_probs_start[i]*all_probs_end[j]
            
#     return score

In [18]:
def score(candidates, map_cands, doc_tokens, all_probs_start, all_probs_end, average=True):
    """Score candidate answers by span probability over their occurrences in the document.

    For every position ``i`` in ``doc_tokens`` holding a candidate's first
    token, the span ``[i, j]`` with ``j = i + len(cand) - 1`` is accepted when
    ``doc_tokens[j]`` equals the candidate's last token. Only the endpoints are
    compared, not the middle tokens. Each accepted span adds
    ``all_probs_start[i] * all_probs_end[j]`` to that candidate's score.

    :param candidates: list of token-id lists, one per candidate answer
    :param map_cands: map from a first token to the indices of ``candidates`` starting with it
    :param doc_tokens: list of tokens in the document
    :param all_probs_start: per-document-token start probabilities; products of
        two elements must support ``.item()`` (e.g. torch tensors)
    :param all_probs_end: per-document-token end probabilities (same convention)
    :param average: if True, divide each candidate's summed score by its number
        of matched spans (mean over occurrences)
    :return: float numpy array with one score per candidate. Candidates with no
        matched span score 0 rather than nan.
    """
    scores = np.zeros(len(candidates))
    counts = np.zeros(len(candidates))
    n_doc = len(doc_tokens)
    for i, t in enumerate(doc_tokens):
        for c in map_cands.get(t, []):
            cand = candidates[c]
            j = i + len(cand) - 1
            if j < n_doc and t == cand[0] and doc_tokens[j] == cand[-1]:
                scores[c] += (all_probs_start[i] * all_probs_end[j]).item()
                counts[c] += 1
    if average:
        # A plain scores / counts gives 0/0 = nan for unmatched candidates, and
        # np.argmax picks the first nan, so an absent candidate would "win".
        return np.divide(scores, counts, out=np.zeros_like(scores), where=counts > 0)
    return scores

In [19]:
def full_answer(answer, query, doc):
    """Expand a single-word abbreviation answer into the entity it abbreviates.

    If ``answer`` is a single word and the query contains ``( answer )``, the
    first ``( answer )`` in ``doc`` is located. The nearest preceding
    ``@entity`` token is then returned in plain form via ``ent_to_plain``.
    In every other case, including when the doc does not contain
    ``( answer )`` verbatim or no ``@entity`` token precedes it, ``answer`` is
    returned unchanged.

    :param answer: predicted answer string
    :param query: query text, searched for ``( answer )``
    :param doc: document with entities marked by the ``@entity`` prefix
    :return: the expanded answer, or ``answer`` itself
    """
    if len(answer.split()) > 1:  # multi-word answers are already complete
        return answer

    abbreviation = '( {} )'.format(answer)
    if abbreviation not in query:
        return answer

    # The abbreviation can appear in the query but not verbatim in the doc
    # (e.g. the doc writes it as "( @entityx )"); str.index would raise here.
    define_ind = doc.find(abbreviation)
    if define_ind == -1:
        return answer

    # Walk backwards from the definition to the closest marked entity.
    for word in reversed(doc[:define_ind].split()):
        if word.startswith("@entity"):
            return ent_to_plain(word)

    return answer
        

In [20]:
def get_answers(document, query, candidates):
    """Predict the candidate that fills the ``@placeholder`` slot of ``query``.

    The placeholder is replaced by two adjacent ``[MASK]`` tokens. The BERT
    state of the first mask is dotted with every document-token state and
    softmaxed into a distribution over answer-start positions. The second mask
    does the same for answer-end positions. Candidates are scored with
    ``score`` (averaged over their occurrences), the best one is chosen, and
    ``full_answer`` may expand it if it is an abbreviation.

    :param document: passage text with entities marked by the ``@entity`` prefix
    :param query: query text containing ``@placeholder``
    :param candidates: list of candidate answer strings
    :return: the predicted answer string
    :raises ValueError: if the masked query has no ``[MASK]`` token (e.g. no
        ``@placeholder`` in ``query``), or if ``candidates`` is empty
    """
    mask_id = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]

    # Query side: the two masks are adjacent, so the end mask directly follows the start mask.
    masked_query = query.replace('▶ ', '').replace('@placeholder', '[MASK] [MASK]')
    query_tokens, query_emb = get_embedding(masked_query)
    mask_ind_start = query_tokens.index(mask_id)
    mask_ind_end = mask_ind_start + 1
    mask_emb_start = query_emb[mask_ind_start:mask_ind_start + 1, :]
    mask_emb_end = query_emb[mask_ind_end:mask_ind_end + 1, :]

    # Document side: one state per document token, entity markers stripped first.
    doc_tokens, doc_emb = doc_embedding(ent_to_plain_doc(document))
    doc_emb_t = torch.transpose(doc_emb, 0, 1)
    all_probs_start = torch.nn.functional.softmax(torch.mm(mask_emb_start, doc_emb_t), dim=1).reshape(-1)
    all_probs_end = torch.nn.functional.softmax(torch.mm(mask_emb_end, doc_emb_t), dim=1).reshape(-1)

    # Candidates as token ids, indexed by their first id for fast lookup in `score`.
    cand_ans = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(c)) for c in candidates]
    maps_cand = collections.defaultdict(list)
    for i, ca in enumerate(cand_ans):
        if ca:  # a candidate that tokenizes to nothing can never match (and ca[0] would raise)
            maps_cand[ca[0]].append(i)

    cand_ans_prob = score(cand_ans, maps_cand, doc_tokens, all_probs_start, all_probs_end, True)
    # An averaged score can be 0/0 = nan for a candidate absent from the document,
    # and np.argmax returns the first nan's index; rank such candidates last instead.
    cand_ans_prob = np.where(np.isnan(cand_ans_prob), -np.inf, cand_ans_prob)
    ans_ind = int(np.argmax(cand_ans_prob))

    # The abbreviation expander gets the caller's query, not the masked copy.
    return full_answer(candidates[ans_ind], query, document)
    
    

In [21]:
# doc_tok, doc_emb = doc_embedding(document)
# print(len(doc_tok))
# print(doc_emb.shape)

In [None]:
# Predict an answer for every sample; each collected line has the form "<predicted>::<gold>".
# NOTE: this is the expensive cell (a full BERT pass per sentence per sample).
towrite = []
for pt in tqdm(sample_data):
    document, query, candidates, answer = pt['p'], pt['q'], pt['c'], pt['a']
    predicted = get_answers(document, query, candidates)
    towrite.append(f'{predicted}::{answer}\n')

  2%|▏         | 73/4147 [06:50<8:00:13,  7.07s/it]

In [None]:
# Persist the "<predicted>::<gold>" lines. The with-block closes the file, so no explicit
# close() is needed. UTF-8 is explicit because answers can contain non-ASCII characters
# (e.g. ‘ ’ –), which the platform default encoding may not be able to encode.
with open('../results/test_averaged.txt', 'w', encoding='utf-8') as f:
    f.writelines(towrite)