In [230]:
%load_ext autoreload
%autoreload 2
import bert
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import spike_queries
from termcolor import colored

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [235]:
def load_results(fname):
    
    with open(fname, "r", encoding = "utf-8") as f:
        sents = f.readlines()
    
    sents = [s.strip().split("\t")[-1] for s in sents]
    original, results = sents[0], sents[1:]
    return original, results


def get_spike_results_arguments_representations(model, spike_results):
    
    sents = spike_results["sentence_text"].tolist()
    arg1_idx_start = spike_results["arg1_first_index"].to_numpy().astype(int)
    arg2_idx_start = spike_results["arg2_first_index"].to_numpy().astype(int)
    arg1_idx_end = spike_results["arg1_last_index"].to_numpy().astype(int)
    arg2_idx_end = spike_results["arg2_last_index"].to_numpy().astype(int)
    
    arg1_rep = []
    arg2_rep = []
    
    for s, arg1_start, arg2_start, arg1_end, arg2_end in zip(sents, arg1_idx_start, arg2_idx_start, arg1_idx_end, arg2_idx_end):
        H, _, _, orig2tok = model.encode(s, layers = [-1])
        h1, h2 = H[orig2tok[arg1_start]:orig2tok[arg1_end] + 1], H[orig2tok[arg2_start]:orig2tok[arg2_end] + 1]
        
        h1 = np.mean(h1, axis = 0)
        h2 = np.mean(h2, axis = 0)
        
        arg1_rep.append(h1)
        arg2_rep.append(h2)
        
    arg1_mean = np.mean(arg1_rep, axis = 0)
    arg2_mean = np.mean(arg2_rep, axis = 0)
    
    return arg1_mean, arg2_mean
    

def get_sentence_representation(sent, model):
    
     H, _, tok_to_orig_map, orig2tok = model.encode(s, layers = [-1])
     return H

def print_sentence_nicely(sentence: str, ind1, ind2):
    
    arg1_sign = "**"
    arg2_sign = "++"
    
    if not ind2 > ind1: 
        nd1, ind2 = ind2, ind1
        arg1_sign, arg2_sign = arg2_sign, arg1_sign
        
    splitted = sentence.split(" ")
    before_arg1 = " ".join(splitted[:ind1])
    arg1 = splitted[ind1]
    arg2 = splitted[ind2]
    between = " ".join(splitted[ind1 + 1: ind2])
    suffix = " ".join(splitted[ind2+1:])
    
    return before_arg1 +  " " + arg1_sign + colored(arg1, "red") + arg1_sign + " " + between + " " + arg2_sign + colored(arg2, "blue") + arg2_sign + " " + suffix


In [173]:
model = bert.BertEncoder("cpu")

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d813f7395e7ea533039e02deb1723d8fd9d8ba655391a01a69ad6223d
Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

Model name 'allenai/scibert_scivocab_uncased' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-

In [225]:
query, results1 = load_results("results1.txt")

In [226]:
spike_results = spike_queries.perform_query(query, dataset_name = "covid19", num_results = 25, query_type = "syntactic")
spike_results = spike_results[spike_results['sentence_text'].notnull()]

In [227]:
arg1_rep, arg2_rep = get_spike_results_arguments_representations(model, spike_results)

In [236]:
for s in results1:
    H, tokenized_text, tok_to_orig_map, orig2tok = model.encode(s, layers = [-1])
    sims_arg1 = cosine_similarity([arg1_rep], H[1:-2])[0]
    sims_arg2 = cosine_similarity([arg2_rep], H[1:-2])[0]
    arg1_ind = np.argmax(sims_arg1) + 1
    arg2_ind = np.argmax(sims_arg2) + 1
    print(print_sentence_nicely(s, tok_to_orig_map[arg1_ind], tok_to_orig_map[arg2_ind]))


In humans , **[31mJEV[0m** infection can cause Japanese encephalitis ( JE ) with ++[34msevere[0m++ central nervous system disorders .
Hepatotropic **[31mvirus[0m** , like MHV-3 infection in mice , can induce exaggerated ++[34minflammation[0m++ in the liver and cause life-threatening viral FH .
Human metapneumovirus ( hMPV ) **[31minfection[0m** causes respiratory tract ++[34mdisease[0m++ similar to that observed during human respiratory syncytial virus infection ( hRSV ) .
In humans , CHIKV **[31minfections[0m** cause a debilitating ++[34mdisease[0m++ with acute febrile illness and long-term polyarthralgia .
 **[31mFIV[0m** infection causes ++[34mprogressive[0m++ immunosuppression that results in the development of the acquired immunodeficiency syndrome ( AIDS ) in cats , which , similar to human HIV infection , increases susceptibility to secondary and opportunistic infections .
 **[31mHuman[0m** infection of H7N9 avian infl uenza virus directly induces ++[34mpn

In [150]:
" ".join(tokenized_text).replace(" ##", "").replace("## ", "")

'[CLS] as discussed above , coxsackievirus b3 ( cvb3 ) infection causes myocarditis in human beings as well as in male balb / c mice . [SEP]'

In [151]:
results1[8]

'As discussed above , coxsackievirus B3 ( CVB3 ) infection causes myocarditis in human beings as well as in male BALB/c mice .'

In [153]:
orig2tok

{0: 1,
 1: 2,
 2: 3,
 3: 4,
 4: 9,
 5: 11,
 6: 12,
 7: 15,
 8: 16,
 9: 17,
 10: 18,
 11: 20,
 12: 21,
 13: 22,
 14: 23,
 15: 24,
 16: 25,
 17: 26,
 18: 27,
 19: 28,
 20: 31,
 21: 32,
 22: 33}

In [165]:
tok_to_orig_map = {}
keys = list(sorted(orig2tok.keys()))
for k, k2 in zip(keys, keys[1:]):
    for k3 in range(orig2tok[k], orig2tok[k2]):
        tok_to_orig_map[k3] = k

In [166]:
tok_to_orig_map

{1: 0,
 2: 1,
 3: 2,
 4: 3,
 5: 3,
 6: 3,
 7: 3,
 8: 3,
 9: 4,
 10: 4,
 11: 5,
 12: 6,
 13: 6,
 14: 6,
 15: 7,
 16: 8,
 17: 9,
 18: 10,
 19: 10,
 20: 11,
 21: 12,
 22: 13,
 23: 14,
 24: 15,
 25: 16,
 26: 17,
 27: 18,
 28: 19,
 29: 19,
 30: 19,
 31: 20,
 32: 21}

In [229]:
print("\x1b[31m\"red\"\x1b[0m")

[31m"red"[0m
