In [495]:
%load_ext autoreload
%autoreload 2
import bert
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import spike_queries
from termcolor import colored

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [455]:
class color:
    """Namespace of ANSI escape sequences used to style printed text.

    Concatenate one of the style constants in front of a string and append
    ``END`` after it to switch the styling off again.
    """

    # Foreground colours.
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Resets every colour/attribute set before it.
    END = '\033[0m'

In [493]:
def load_results(fname):
    """Read a results file and split it into the query and the result sentences.

    Every non-blank line contributes its last tab-separated field (with
    surrounding whitespace stripped). The first such field is returned as the
    original query, the remaining ones as the list of result sentences.

    Blank lines (for example a trailing empty line at the end of the file) are
    skipped, so they do not show up as empty "sentences" for the caller.

    Args:
        fname: path of the UTF-8 encoded text file to read.

    Returns:
        A pair ``(original, results)``: the query string and a list of
        sentence strings (possibly empty).

    Raises:
        ValueError: if the file contains no non-blank line.
    """
    with open(fname, "r", encoding = "utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        # Keep only the text column; any leading tab-separated fields are dropped.
        sents = [line.split("\t")[-1] for line in stripped_lines if line]

    if not sents:
        raise ValueError("no sentences found in {!r}".format(fname))

    original, results = sents[0], sents[1:]
    return original, results


def get_spike_results_arguments_representations(model, spike_results, layers):
    """Average the encoder vectors of both captured arguments over all results.

    For every row of ``spike_results`` the sentence is encoded with
    ``model.encode(sentence, layers=layers)``. The word-level first/last
    indices of each argument are translated through the fourth value returned
    by ``encode`` and the rows of the first returned value between those two
    positions (inclusive) are mean-pooled into one vector per argument. The
    per-sentence vectors are then averaged over all rows.

    Args:
        model: object exposing ``encode(sentence, layers=...)`` that returns
            four values; only the first and the fourth are used here.
        spike_results: table with the columns ``sentence_text``,
            ``arg1_first_index``, ``arg2_first_index``, ``arg1_last_index``
            and ``arg2_last_index``.
        layers: forwarded unchanged to ``model.encode``.

    Returns:
        ``(arg1_mean, arg2_mean)``: the averaged vectors of argument 1 and
        argument 2.
    """

    def _int_column(name):
        # Index columns are cast to int so they can be used for indexing.
        return spike_results[name].to_numpy().astype(int)

    sentences = spike_results["sentence_text"].tolist()
    first_of_arg1 = _int_column("arg1_first_index")
    first_of_arg2 = _int_column("arg2_first_index")
    last_of_arg1 = _int_column("arg1_last_index")
    last_of_arg2 = _int_column("arg2_last_index")

    arg1_vectors, arg2_vectors = [], []
    rows = zip(sentences, first_of_arg1, first_of_arg2, last_of_arg1, last_of_arg2)

    for sentence, begin1, begin2, end1, end2 in rows:
        hidden, _, _, word_to_token = model.encode(sentence, layers = layers)

        # Inclusive spans, hence the "+ 1" on the upper bound.
        span1 = hidden[word_to_token[begin1]:word_to_token[end1] + 1]
        span2 = hidden[word_to_token[begin2]:word_to_token[end2] + 1]

        arg1_vectors.append(np.mean(span1, axis = 0))
        arg2_vectors.append(np.mean(span2, axis = 0))

    return np.mean(arg1_vectors, axis = 0), np.mean(arg2_vectors, axis = 0)
    

def print_sentence_nicely(sentence: str, ind1, ind2):
    """Return ``sentence`` with the words at ``ind1`` and ``ind2`` highlighted.

    The sentence is split on single spaces. The word at ``ind1`` (argument 1)
    is wrapped in ``**`` and coloured red; the word at ``ind2`` (argument 2)
    is wrapped in ``++`` and coloured blue. The markers follow the argument,
    not the position, so argument 2 may appear before argument 1.

    Despite its name the function prints nothing: it returns the formatted
    string.

    The highlighted words are substituted in place and the words are joined
    once, so no extra space is produced when an argument is the first or last
    word or when the two arguments are adjacent. If both indices point at the
    same word it is emitted once, carrying both markers, instead of being
    repeated.

    Args:
        sentence: space-separated sentence.
        ind1: word index of argument 1.
        ind2: word index of argument 2.

    Returns:
        The sentence with both argument words marked and coloured.

    Raises:
        IndexError: if an index is outside the split sentence.
    """
    words = sentence.split(" ")

    if ind1 == ind2:
        # One word plays both roles: show it once with both markers.
        words[ind1] = "**++" + colored(words[ind1], "red") + "++**"
    else:
        marked_arg1 = "**" + colored(words[ind1], "red") + "**"
        marked_arg2 = "++" + colored(words[ind2], "blue") + "++"
        words[ind1], words[ind2] = marked_arg1, marked_arg2

    return " ".join(words)



def main(filename, layers = [-1]):
    """Run one alignment experiment and print its results.

    Steps:
      1. Load the query and the augmentation sentences from ``filename``
         (see ``load_results``).
      2. Run the query against the SPIKE "covid19" dataset (syntactic query,
         50 results) and average the encoder vectors of the two captured
         arguments over all returned sentences.
      3. For every augmentation sentence, pick the token most cosine-similar
         to the averaged argument-1 vector and a different token most similar
         to the averaged argument-2 vector, then print the sentence with both
         words highlighted.

    The encoder is read from the module-level name ``model``, which must be
    defined before this function is called.

    Args:
        filename: path of the results file to evaluate.
        layers: forwarded unchanged to ``model.encode``; not modified here.

    Returns:
        None. All results are printed.
    """
    query, augmented_sentences = load_results(filename)
    spike_results = spike_queries.perform_query(query, dataset_name = "covid19", num_results = 50, query_type = "syntactic")
    spike_results = spike_results[spike_results['sentence_text'].notnull()]

    query_header = color.BOLD + "QUERY" + color.END + ":\n{}".format(query)

    if spike_results.empty:
        # Without any SPIKE sentence the averaged argument vectors would be
        # NaN and the example lookup below would raise IndexError.
        print(query_header)
        print("No SPIKE results with sentence text were returned; nothing to align against.")
        return

    arg1_rep, arg2_rep = get_spike_results_arguments_representations(model, spike_results, layers)

    print(query_header)

    # NOTE(review): the example printed under "FIRST SPIKE RESULT" is the LAST
    # row of the result table (index -1) -- confirm whether that is intended.
    first = spike_results["sentence_text"].iloc[-1]
    first_ind1 = int(spike_results["arg1_first_index"].iloc[-1])
    first_ind2 = int(spike_results["arg2_first_index"].iloc[-1])
    print(color.BOLD + "\nFIRST SPIKE RESULT" + color.END + ":\n{}".format(print_sentence_nicely(first, first_ind1, first_ind2)))
    print(color.BOLD + "\nAUGMENTATION RESULTS:\n" + color.END)

    for s in augmented_sentences:
        if not s.strip():
            # Blank entries (e.g. an empty line in the results file) carry no sentence.
            continue

        H, _, tok_to_orig_map, _ = model.encode(s, layers = layers)

        # Candidate rows: everything except the first row and the last two rows of H.
        # NOTE(review): presumably this drops special tokens -- confirm against
        # bert.BertEncoder.encode that "-2" (rather than "-1") is intended.
        candidates = H[1:-2]
        if len(candidates) < 2:
            # Two distinct tokens are needed, one per argument.
            print("ERROR")
            continue

        sims_arg1 = cosine_similarity([arg1_rep], candidates)[0]
        sims_arg2 = cosine_similarity([arg2_rep], candidates)[0]

        best_arg1 = np.argmax(sims_arg1)
        # "+ 1" converts a position inside `candidates` back to a row of H.
        arg1_ind = best_arg1 + 1

        # Exclude the token already taken by argument 1. -inf is below every
        # possible similarity, so it can never be selected again.
        sims_arg2[best_arg1] = -np.inf
        arg2_ind = np.argmax(sims_arg2) + 1

        if arg1_ind not in tok_to_orig_map or arg2_ind not in tok_to_orig_map:
            print("ERROR")
            continue
        print(print_sentence_nicely(s, tok_to_orig_map[arg1_ind], tok_to_orig_map[arg2_ind]))
        print("-----------------------------------")

    print(color.BOLD + "=======================================================" + color.END)


In [457]:
# Shared encoder instance: main() reads it as the module-level name `model`,
# so this cell must be executed before any main(...) call.
# NOTE(review): "cpu" / "scibert" are presumably the device and the checkpoint
# name (the log below shows scibert_scivocab_uncased being loaded) -- confirm
# against bert.BertEncoder.
model = bert.BertEncoder("cpu", "scibert")

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d813f7395e7ea533039e02deb1723d8fd9d8ba655391a01a69ad6223d
Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d8

# Alignment Experiments

In [514]:
# Experiment 1 (results1.txt) -- query shown in the recorded output:
# "virus infection causes a condition".
main("results1.txt", layers = [-1])

[1mQUERY[0m:
<>arg1:virus $infection $causes a <>arg2:condition .
[1m
FIRST SPIKE RESULT[0m:
178 Thus , **[31mleishmania[0m** infection causes ++[34msecondary[0m++ infections , and activation/reactivation of dormant infections that can lead to severe or lethal outcomes through multiple mechanisms .
[1m
AUGMENTATION RESULTS:
[0m
In humans , **[31mJEV[0m** infection can cause Japanese encephalitis ( JE ) with ++[34msevere[0m++ central nervous system disorders .
-----------------------------------
Hepatotropic **[31mvirus[0m** , like MHV-3 infection in mice , can induce exaggerated ++[34minflammation[0m++ in the liver and cause life-threatening viral FH .
-----------------------------------
Human metapneumovirus ( hMPV ) **[31minfection[0m** causes respiratory tract ++[34mdisease[0m++ similar to that observed during human respiratory syncytial virus infection ( hRSV ) .
-----------------------------------
In humans , CHIKV **[31minfections[0m** cause a debilitating

Accuracy: 20/20

In [511]:
# Experiment 2 (results2.txt) -- query shown in the recorded output:
# "infected patients usually develop respiratory illness".
main("results2.txt", [-1])

[1mQUERY[0m:
infected arg1:patients $usually $develop <>arg2:[entity]respiratory illness.
[1m
FIRST SPIKE RESULT[0m:
12 , 27 , [ 33 ] [ 34 ] [ 35 ] **[31mPuppies[0m** that are nursing from CHV-seronegative dams usually develop the fatal ++[34mmultisystemic[0m++ disease , while puppies that suckle from CHV-seropositive dams remain asymptomatic but still become infected .
[1m
AUGMENTATION RESULTS:
[0m
After an incubation period of 5 to 14 days , SARS-CoV-2-infected **[31mpeople[0m** commonly manifest features of ++[34mpneumonia[0m++ , including fever , dry cough , dyspnoea , myalgia and fatigue .
-----------------------------------
After an incubation period of 2 to 3 days , **[31mpatients[0m** who have pneumonic plague typically develop fulminant ++[34mpneumonia[0m++ , with malaise , high fever , cough , hemoptysis , and septicemia with ecchymoses and extremity necrosis .
-----------------------------------
 **[31mPatients[0m** suffering from severe DENV infection of

Accuracy: 19/20

In [510]:
# Experiment 3 (results3.txt) -- query shown in the recorded output:
# "a subset of patients progress to hemorrhagic fever".
main("results3.txt", [-1])

[1mQUERY[0m:
a arg1:subset of patients  $progress to arg2:hemorrhagic fever
[1m
FIRST SPIKE RESULT[0m:
The **[31mdisease[0m** may progress across the bone of the skull base to the ++[34mjugular[0m++ foramen and hypoglossal canal , resulting in palsies of the ninth , tenth , eleventh and twelfth cranial nerves .
[1m
AUGMENTATION RESULTS:
[0m
 ++[34mSevere[0m++  **[31mcases[0m** may progress to a capillary leak syndrome with septic shock , rash , facial and neck swelling , and multi-organ system failure .
-----------------------------------
When symptoms occur , IPS **[31mmay[0m** rapidly progress to ++[34mpulmonary[0m++ dysfunction requiring mechanical ventilation .
-----------------------------------
 ++[34mClinical[0m++ signs are nonspecific in the early stages , but as the **[31mdisease[0m** progresses to involve more of the liver and impair regeneration , icterus , ascites , and hepatic encephalopathy may develop as typical correlates with hepatic insufficiency

Accuracy: 13/20

In [509]:
# Experiment 4 (results4.txt) -- query shown in the recorded output:
# "people with asthma are susceptible to asthma" (both arguments are entity captures).
main("results4.txt", [-1])

[1mQUERY[0m:
people with arg1:[e]asthma are $susceptible to arg2:[e]asthma
[1m
FIRST SPIKE RESULT[0m:
However , children with inborn **[31merrors[0m** in TLR3 signaling ( 54 ) ( 55 ) ( 56 ) or defects resulting in abnormal signaling through several antiviral TLRs including 3 , 7 , 8 , 9 ( e.g. UNC-93B deficiency)(57 ) are susceptible to ++[34mHSVE[0m++ .
[1m
AUGMENTATION RESULTS:
[0m
Patients with lung **[31mcancer[0m** are highly susceptible to ++[34minfection[0m++ compared to healthy individuals because of systemic immunosuppression induced by malignancy and anticancer therapy .
-----------------------------------
Patients with lung **[31mcancer[0m** are highly susceptible to ++[34minfection[0m++ compared to healthy individuals because of systemic immunosuppression induced by malignancy and anticancer therapy .
-----------------------------------
 **[31mPatients[0m** with ARDS triggered by viral ++[34minfection[0m++ , in particular influenza , are prone to invas

Accuracy: 16/20

In [508]:
# Experiment 5 (results5.txt) -- query shown in the recorded output:
# "paracetamol is not useful for treating asthma".
main("results5.txt", [-1])

[1mQUERY[0m:
arg1:[e]paracetamol is $not useful for treating arg2:[e]asthma.
[1m
FIRST SPIKE RESULT[0m:
There are cases of ADEM or even fulminant presentation such as AHL where **[31msteroids[0m** alone are not sufficient for suppressing ++[34minflammation[0m++ and improving clinical findings .
[1m
AUGMENTATION RESULTS:
[0m
However , while tocilizumab is a promising agent against COVID-19 , **[31mit[0m** is not an appropriate agent in patients with active or latent ++[34mtuberculosis[0m++ , bacterial and fungal infections , multi-organ failure , and gastrointestinal perforation [ 7 ] .
-----------------------------------
There are cases of ADEM or even fulminant presentation such as AHL where **[31msteroids[0m** alone are not sufficient for suppressing ++[34minflammation[0m++ and improving clinical findings .
-----------------------------------
Thus , **[31mribavirin[0m** may not be useful for treating ++[34mSARS[0m++ infections because of its questionable efficac

Accuracy: 15/20

In [507]:
# Experiment 6 (results6.txt) -- query shown in the recorded output:
# "stroke is a complication of COVID-19 infection".
main("results6.txt", [-1])

[1mQUERY[0m:
arg1:[e]stroke is a $complication of :[e]COVID-19 arg2:infection
[1m
FIRST SPIKE RESULT[0m:
In conclusion , hepatic **[31minjury[0m** is likely a complication of COVID-19 ++[34minfection[0m++ .
[1m
AUGMENTATION RESULTS:
[0m
Acute kidney **[31minjury[0m** is a possible manifestation in ++[34msevere[0m++ forms of COVID-19 , conferring a high mortality 7 .
-----------------------------------
Coagulation dysfunction is one of the major causes for **[31mdeath[0m** in ++[34mpatients[0m++ with severe COVID-19 [ 4 ] .
-----------------------------------
ERROR
 **[31mHypoxia[0m** is the most common manifestation of ++[34mpatients[0m++ with COVID-19 , especially in severely or critically ill patients .
-----------------------------------
Early-onset **[31mpneumonia[0m** is a common and severe complication that is related to aspiration in ++[34mpatients[0m++ with GBS .
-----------------------------------
 **[31mBarotrauma[0m** is a common complication in +

Accuracy: 16/20

In [505]:
# Experiment 7 (results7.txt) -- query shown in the recorded output:
# "the recommended quarantine period is 14 days".
main("results7.txt", [-1])

[1mQUERY[0m:
the recommended arg1:[w]quarantine period is arg2:14 :[w]days
[1m
FIRST SPIKE RESULT[0m:
The CERs of interventions was unstable when the **[31mquarantine[0m** delay-time was no less than ++[34mfive[0m++ days ( figure 2d ) .
[1m
AUGMENTATION RESULTS:
[0m
If we aim to control the failure rate of quarantine to be below 1 % with 95 % confidence , then the **[31mquarantine[0m** period must be at least ++[34m22[0m++ days .
-----------------------------------
In general , the duration of the **[31mquarantine[0m** period should be 21 to ++[34m30[0m++ days .
-----------------------------------
Although for extreme cases , the **[31mquarantine[0m** period should be extended up to ++[34mthree[0m++ weeks .
-----------------------------------
At the beginning and the end of a SARS epidemic , to control the epidemic completely , we recommend a period of ++[34m22[0m++ days for **[31mquarantine[0m** , which would capture 99 percent of all probable SARS cases with

Accuracy: 19/20

In [504]:
# Experiment 8 (results8.txt) -- query shown in the recorded output:
# "COVID-19 infects cells".
main("results8.txt", [-1])

[1mQUERY[0m:
arg1:[e]COVID-19 $infects arg2:cells
[1m
FIRST SPIKE RESULT[0m:
Besides the lung , the **[31mSARS-CoV-2[0m** virus also infects the ++[34mGI[0m++ tract , causing the patients also to experience diarrhea .
[1m
AUGMENTATION RESULTS:
[0m
 **[31mSARS-CoV-2[0m** infects ACE2positive ++[34mcells[0m++ in the oral mucosa and lungs , including ACE2 + AT2 cells in the alveoli .
-----------------------------------
SARS-CoV targets human angiotensin-converting enzyme **[31m2[0m** ( hACE2 ) and infects intrapulmonary epithelial ++[34mcells[0m++ more than cells of the upper airways [ 13 , 14 ] .
-----------------------------------
Initially , **[31mSARS-CoV-2[0m** infects ++[34mcells[0m++ in the respiratory system and causes inflammation and cell death .
-----------------------------------
 **[31mSARS-CoV-2[0m** infects ACE2 + ++[34mcells[0m++ in the oral mucosa and lungs , including ACE-2 cells in the alveoli [ 97 ] .
-----------------------------------
In the

Accuracy: 17/20

In [503]:
# Experiment 9 (results9.txt) -- query shown in the recorded output:
# "COVID-19 activates the ATP receptor".
main("results9.txt", [-1])

[1mQUERY[0m:
arg1:[e]COVID-19 activates the arg2:ATP $receptor
[1m
FIRST SPIKE RESULT[0m:
 **[31mSARS-CoV-2[0m** uses the angiotensin converting ++[34menzyme[0m++ 2 ( ACE2 ) receptor to enter the host and this receptor is highly expressed in both the respiratory and gastrointestinal tract 9 , 10 , 11 .
[1m
AUGMENTATION RESULTS:
[0m
 **[31mSARS-CoV-2[0m** enters respiratory epithelial cells by attaching to angiotensin converting enzyme-2 ( ACE-2 ) via S-protein ; ++[34mACE-2[0m++ is also a receptor for SARS-CoV-1 [ 27 ] .
-----------------------------------
Presumably , SARS-CoV infects cells through the S protein , which binds to cell surface receptor-angiotensin-converting ++[34menzyme[0m++  **[31m2[0m** [ 4 ] .
-----------------------------------
Human angiotensin-converting enzyme 2 ( ++[34mACE2[0m++ ) is a functional receptor hijacked by **[31mSARS-CoV-2[0m** for cell entry , similar to SARS-CoV [ 8 , 16 ] .
-----------------------------------
It is now known 

Accuracy: 19/20

In [502]:
# Experiment 10 (results10.txt) -- query shown in the recorded output:
# "low blood oxygen levels are a risk factor in COVID-19 infection".
main("results10.txt", [-1])

[1mQUERY[0m:
low blood oxygen arg1:levels are a $risk factor in arg2:[e]COVID-19 infection
[1m
FIRST SPIKE RESULT[0m:
Among all univariable parameters , **[31mAPACHE[0m** II , SOFA , lymphocytes , CRP , LDH , AST , cTnI , BNP , et al were significantly independent risk factors of ++[34mCOVID-19[0m++ severity .
[1m
AUGMENTATION RESULTS:
[0m
Similar to other studies ， our data indicated senior and hypertension are the high risk factors for severe ++[34mCOVID-19[0m++  **[31minfection[0m** .
-----------------------------------
These findings suggested that **[31mmen[0m** with severe ++[34mCOVID-19[0m++ is susceptible to secondary infections with virus or bacteria , resulting in higher utilization rate of advanced antiviral therapy and antibiotics .
-----------------------------------
In this study , preceding RVIs was one of the independent risk factors associated with severe ++[34mpneumococcal[0m++  **[31mpneumonia[0m** ( a higher PSI score ≥91 ) , suggesting a poten

Accuracy: 12/20