In [287]:
%load_ext autoreload
%autoreload 2
import bert
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import spike_queries
from termcolor import colored

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [242]:
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'

In [299]:
def load_results(fname):
    
    with open(fname, "r", encoding = "utf-8") as f:
        sents = f.readlines()
    
    sents = [s.strip().split("\t")[-1] for s in sents]
    original, results = sents[0], sents[1:]
    return original, results


def get_spike_results_arguments_representations(model, spike_results, layers):
    
    sents = spike_results["sentence_text"].tolist()
    arg1_idx_start = spike_results["arg1_first_index"].to_numpy().astype(int)
    arg2_idx_start = spike_results["arg2_first_index"].to_numpy().astype(int)
    arg1_idx_end = spike_results["arg1_last_index"].to_numpy().astype(int)
    arg2_idx_end = spike_results["arg2_last_index"].to_numpy().astype(int)
    
    arg1_rep = []
    arg2_rep = []
    
    for s, arg1_start, arg2_start, arg1_end, arg2_end in zip(sents, arg1_idx_start, arg2_idx_start, arg1_idx_end, arg2_idx_end):
        H, _, _, orig2tok = model.encode(s, layers = layers)
        h1, h2 = H[orig2tok[arg1_start]:orig2tok[arg1_end] + 1], H[orig2tok[arg2_start]:orig2tok[arg2_end] + 1]
        
        h1 = np.mean(h1, axis = 0)
        h2 = np.mean(h2, axis = 0)
        
        arg1_rep.append(h1)
        arg2_rep.append(h2)
        
    arg1_mean = np.median(arg1_rep, axis = 0)
    arg2_mean = np.median(arg2_rep, axis = 0)
    
    return arg1_mean, arg2_mean
    

def print_sentence_nicely(sentence: str, ind1, ind2):
    
    arg1_sign = "**"
    arg2_sign = "++"
    
    if not ind2 > ind1: 
        nd1, ind2 = ind2, ind1
        arg1_sign, arg2_sign = arg2_sign, arg1_sign
        
    splitted = sentence.split(" ")
    before_arg1 = " ".join(splitted[:ind1])
    arg1 = splitted[ind1]
    arg2 = splitted[ind2]
    between = " ".join(splitted[ind1 + 1: ind2])
    suffix = " ".join(splitted[ind2+1:])
    
    return before_arg1 +  " " + arg1_sign + colored(arg1, "red") + arg1_sign + " " + between + " " + arg2_sign + colored(arg2, "blue") + arg2_sign + " " + suffix



def main(filename, layers = [-1]):
    

    query, results1 = load_results(filename)
    spike_results = spike_queries.perform_query(query, dataset_name = "covid19", num_results = 50, query_type = "syntactic")
    spike_results = spike_results[spike_results['sentence_text'].notnull()]
    arg1_rep, arg2_rep = get_spike_results_arguments_representations(model, spike_results, layers)
    
    print(color.BOLD + "QUERY" + color.END + ":\n{}".format(query))
    print(color.BOLD + "RESULTS:\n" + color.END)
    for s in results1:
        H, tokenized_text, tok_to_orig_map, orig2tok = model.encode(s, layers = layers)
        sims_arg1 = cosine_similarity([arg1_rep], H[1:-2])[0]
        sims_arg2 = cosine_similarity([arg2_rep], H[1:-2])[0]
        arg1_ind = np.argmax(sims_arg1) + 1
        arg2_ind = np.argmax(sims_arg2) + 1
        print(arg1_ind, arg2_ind)
        print(print_sentence_nicely(s, tok_to_orig_map[arg1_ind], tok_to_orig_map[arg2_ind]))
        print("-----------------------------------")
    
    print(color.BOLD + "=======================================================" + color.END)


In [292]:
model = bert.BertEncoder("cpu")

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d813f7395e7ea533039e02deb1723d8fd9d8ba655391a01a69ad6223d
Model config AlbertConfig {
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "num_attention_heads": 12,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_s

In [301]:
main("results1.txt", layers = [-1])

[1mQUERY[0m:
<>arg1:virus $infection $causes a <>arg2:condition .
[1mRESULTS:
[0m
18 18
In humans , JEV infection can cause Japanese encephalitis ( JE ) with severe central nervous ++[31msystem[0m++  **[34msystem[0m** disorders .
-----------------------------------
28 6
Hepatotropic virus , like MHV-3 infection in mice , can induce exaggerated inflammation in the liver and cause life-threatening viral ++[31mFH[0m++  **[34mFH[0m** .
-----------------------------------
22 22
Human metapneumovirus ( hMPV ) infection causes respiratory tract disease similar to that observed during human ++[31mrespiratory[0m++  **[34mrespiratory[0m** syncytial virus infection ( hRSV ) .
-----------------------------------
12 12
In humans , CHIKV infections cause a ++[31mdebilitating[0m++  **[34mdebilitating[0m** disease with acute febrile illness and long-term polyarthralgia .
-----------------------------------
3 19
FIV **[31minfection[0m** causes progressive immunosuppression that re

In [258]:
main("results2.txt")

[1mQUERY[0m:
infected arg1:patients $usually $develop <>arg2:[entity]respiratory illness.
[1mRESULTS:
[0m
After an incubation period of 5 to 14 days , SARS-CoV-2-infected **[31mpeople[0m** commonly manifest features of ++[34mpneumonia[0m++ , including fever , dry cough , dyspnoea , myalgia and fatigue .
-----------------------------------
After an incubation period of 2 to 3 days , **[31mpatients[0m** who have pneumonic plague typically develop fulminant ++[34mpneumonia[0m++ , with malaise , high fever , cough , hemoptysis , and septicemia with ecchymoses and extremity necrosis .
-----------------------------------
 **[31mPatients[0m** suffering from severe DENV infection often exhibit encephalopathy and ++[34mencephalitis[0m++ .
-----------------------------------
Following an incubation period of usually 4 - 5 days , ++[31mpatients[0m++  **[34mpatients[0m** infected with SARS-CoV often present with symptoms of fever , headache , and myalgias .
--------------------

In [290]:
main("results3.txt")

[1mQUERY[0m:
a arg1:subset of patients  $progress to arg2:hemorrhagic fever
[1mRESULTS:
[0m
1 1
 ++[31mSevere[0m++  **[34mSevere[0m** cases may progress to a capillary leak syndrome with septic shock , rash , facial and neck swelling , and multi-organ system failure .
-----------------------------------
8 6
When symptoms occur , IPS may rapidly ++[31mprogress[0m++  **[34mprogress[0m** to pulmonary dysfunction requiring mechanical ventilation .
-----------------------------------
8 28
Clinical signs are nonspecific in the early **[31mstages[0m** , but as the disease progresses to involve more of the liver and impair regeneration , icterus ++[34m,[0m++ ascites , and hepatic encephalopathy may develop as typical correlates with hepatic insufficiency .
-----------------------------------
18 18
The patients can develop radiographic findings of bilateral pulmonary opacities [ 9 ] , and ++[31mthey[0m++  **[34mthey[0m** can progress to respiratory failure in severe cases .
