In [4]:
%load_ext autoreload
%autoreload 2
import bert
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import spike_queries
from termcolor import colored
import random
from collections import Counter, defaultdict
from viterbi_trellis import ViterbiTrellis

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:

class color:
    """ANSI escape sequences used to style text printed by this notebook."""

    # Foreground colours.
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    DARKCYAN = "\033[36m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Resets every colour/attribute set above.
    END = "\033[0m"

In [63]:
def load_results(fname):
    """Read a tab-separated results file and keep the last field of every line.

    Args:
        fname: path of a UTF-8 text file.

    Returns:
        A pair ``(original, results)``: the last tab-separated field of the
        first line, and a list with the last field of every following line.
        Each line is whitespace-stripped before it is split.

    Raises:
        IndexError: if the file contains no lines.
    """
    with open(fname, "r", encoding="utf-8") as handle:
        last_fields = [line.strip().split("\t")[-1] for line in handle.readlines()]

    original = last_fields[0]
    results = last_fields[1:]
    return original, results


def get_spike_results_arguments_representations(model, spike_results, layers):
    """Average the encoder vectors of the two argument spans over all results.

    Args:
        model: object whose ``encode(sentence, layers=...)`` returns a 4-tuple;
            the first item is indexed by token position and the last item maps
            an original word index to a token position.
        spike_results: table with the columns ``sentence_text``,
            ``arg1_first_index``, ``arg2_first_index``, ``arg1_last_index`` and
            ``arg2_last_index`` (read via ``.tolist()`` / ``.to_numpy()``, so
            presumably a pandas DataFrame).
        layers: forwarded unchanged to ``model.encode``.

    Returns:
        ``(arg1_mean, arg2_mean)``: for each argument, the mean over sentences
        of the mean token vector inside the (inclusive) argument span.
    """
    def int_column(name):
        # Index columns are cast to int before they are used for slicing.
        return spike_results[name].to_numpy().astype(int)

    sentences = spike_results["sentence_text"].tolist()
    a1_first = int_column("arg1_first_index")
    a2_first = int_column("arg2_first_index")
    a1_last = int_column("arg1_last_index")
    a2_last = int_column("arg2_last_index")

    first_arg_vectors, second_arg_vectors = [], []
    for sentence, a1_lo, a2_lo, a1_hi, a2_hi in zip(sentences, a1_first, a2_first, a1_last, a2_last):
        hidden, _, _, word2tok = model.encode(sentence, layers=layers)

        # Spans are inclusive on both ends, hence the +1 on the upper bound.
        span1 = hidden[word2tok[a1_lo]:word2tok[a1_hi] + 1]
        span2 = hidden[word2tok[a2_lo]:word2tok[a2_hi] + 1]

        first_arg_vectors.append(np.mean(span1, axis=0))
        second_arg_vectors.append(np.mean(span2, axis=0))

    return np.mean(first_arg_vectors, axis=0), np.mean(second_arg_vectors, axis=0)
    

def print_sentence_nicely(sentence: str, ind1, ind2):
    """Return ``sentence`` with the words at ``ind1`` and ``ind2`` highlighted.

    The word at ``ind1`` is wrapped in ``**`` and coloured red, the word at
    ``ind2`` is wrapped in ``++`` and coloured blue. The sentence is split on
    single spaces and re-joined with single spaces, so the result starts with
    a space when the first highlighted word is the first word.

    Despite the name, nothing is printed; the formatted string is returned.
    """
    # (sign, colour, word index) for each argument, in left-to-right order.
    marks = [("**", "red", ind1), ("++", "blue", ind2)]
    if ind1 >= ind2:
        marks.reverse()
    (left_sign, left_color, left), (right_sign, right_color, right) = marks

    words = sentence.split(" ")
    left_word = words[left]
    right_word = words[right]

    pieces = [
        " ".join(words[:left]),
        left_sign + colored(left_word, left_color) + left_sign,
        " ".join(words[left + 1: right]),
        right_sign + colored(right_word, right_color) + right_sign,
        " ".join(words[right + 1:]),
    ]
    return " ".join(pieces)



def main(filename, layers = None, num_results_to_print = 75):
    """Collect argument representations and encoded result sentences for one query.

    The first line of ``filename`` is used as the query, the remaining lines as
    candidate result sentences (see ``load_results``). The query is run against
    the "covid19" dataset to obtain reference sentences, from which the mean
    representations of the two arguments are computed. Then up to
    ``num_results_to_print`` candidate sentences are encoded.

    Relies on the module-level ``model`` encoder.

    Args:
        filename: path of the results file.
        layers: encoder layers forwarded to ``model.encode``; defaults to
            ``[-1]``. A fresh list is created per call so the default is never
            shared with (or mutated by) the encoder.
        num_results_to_print: maximum number of candidate sentences to encode.

    Returns:
        ``(query, (arg1_rep, arg2_rep), (representations, mappings_to_orig,
        mappings_to_tok, tokenized_txts, orig_sents))`` where the five lists in
        the last tuple are aligned and hold one entry per encoded sentence.
    """
    if layers is None:
        layers = [-1]

    query, results1 = load_results(filename)
    spike_results = spike_queries.perform_query(query, dataset_name = "covid19", num_results = 100, query_type = "syntactic")
    # Rows without sentence text cannot be encoded.
    spike_results = spike_results[spike_results['sentence_text'].notnull()]
    arg1_rep, arg2_rep = get_spike_results_arguments_representations(model, spike_results, layers)

    representations = []
    mappings_to_orig = []
    mappings_to_tok = []
    tokenized_txts = []
    orig_sents = []

    # Slice up front: the previous "break when i > limit" check ran after the
    # append and therefore encoded two sentences more than requested.
    for s in results1[:max(num_results_to_print, 0)]:
        H, tokenized_text, tok_to_orig_map, orig2tok = model.encode(s, layers = layers)
        orig_sents.append(s)
        representations.append(H)
        mappings_to_orig.append(tok_to_orig_map)
        mappings_to_tok.append(orig2tok)
        tokenized_txts.append(tokenized_text)

    return query, (arg1_rep, arg2_rep), (representations, mappings_to_orig, mappings_to_tok, tokenized_txts, orig_sents)


def get_between_tokens_similarity(padded_representations):
    """Cosine similarity between every pair of token positions in the batch.

    Args:
        padded_representations: array of shape ``(num_sents, seq_len, dim)``.

    Returns:
        Array ``sims`` of shape ``(num_sents, seq_len, num_sents, seq_len)``
        where ``sims[a, i, b, j]`` is the cosine similarity between token ``i``
        of sentence ``a`` and token ``j`` of sentence ``b``.
    """
    n_sents, n_toks, dim = padded_representations.shape
    # One row per (sentence, token) pair.
    flat = padded_representations.reshape((n_sents * n_toks, dim))
    pairwise = cosine_similarity(flat, flat)
    return pairwise.reshape((n_sents, n_toks, n_sents, n_toks))

def get_between_token_similarity_prev_sentence(padded_representations):
    """Cosine similarity between the tokens of each sentence and the next one.

    Args:
        padded_representations: array of shape ``(num_sents, seq_len, dim)``.
            It is NOT modified; normalisation is done on a copy (the previous
            in-place division silently normalised the caller's array through
            the reshaped view).

    Returns:
        Array of shape ``((num_sents - 1) * seq_len, seq_len)``. Row
        ``k * seq_len + i``, column ``j`` holds the cosine similarity between
        token ``i`` of sentence ``k`` and token ``j`` of sentence ``k + 1``.
        All-zero token vectors get similarity 0 instead of NaN.
    """
    reps = np.asarray(padded_representations)
    if not np.issubdtype(reps.dtype, np.floating):
        # Integer input cannot be normalised without a float conversion.
        reps = reps.astype(float)

    num_sents, seq_len, bert_dim = reps.shape

    norms = np.linalg.norm(reps, axis=-1, keepdims=True)
    norms[norms == 0] = 1  # leave zero vectors as zeros rather than dividing by 0
    unit = reps / norms  # new array: the input stays untouched

    # Batched (seq_len x dim) @ (dim x seq_len) for every consecutive sentence pair.
    sims = np.matmul(unit[:-1], np.swapaxes(unit[1:], 1, 2))
    return sims.reshape((max(num_sents - 1, 0) * seq_len, seq_len))

def get_similarity_to_arguments(padded_representations, arg1, arg2):
    """Cosine similarity of every token to each of the two argument vectors.

    Args:
        padded_representations: array of shape ``(num_sents, seq_len, dim)``.
        arg1: vector of length ``dim`` representing the first argument.
        arg2: vector of length ``dim`` representing the second argument.

    Returns:
        Array of shape ``(2, num_sents, seq_len)``; index 0 holds the
        similarities to ``arg1`` and index 1 the similarities to ``arg2``.
    """
    num_sents, seq_len, bert_dim = padded_representations.shape
    flat = padded_representations.reshape((num_sents * seq_len, bert_dim))
    # Use the function's own parameters; the previous code read the notebook
    # globals ``arg1_rep``/``arg2_rep`` and ignored what the caller passed in.
    sims = cosine_similarity([arg1, arg2], flat)
    return sims.reshape((2, num_sents, seq_len))

In [3]:
# Shared encoder instance: main() reads this module-level `model` to encode sentences,
# so this cell has to run before main() is called.
# NOTE(review): the two arguments look like a device name and a model name — confirm
# against bert.BertEncoder.
model = bert.BertEncoder("cuda", "scibert")

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d813f7395e7ea533039e02deb1723d8fd9d8ba655391a01a69ad6223d
Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d8

In [68]:
from viterbi_trellis import ViterbiTrellis
# The trellis `grid` is intentionally not built in this cell: `seq_len` and
# `num_sents` only exist once the per-query loop has run, so building it here
# raised NameError on a fresh kernel. The per-query loop builds `grid` itself
# before any Viterbi call.

# Which argument to capture: used as the default row of `sims_args` in
# state_score_func (0 -> arg1, 1 -> arg2) and printed as "ARG{ARG_IND+1}".
ARG_IND = 1

def state_score_func(sent_ind_tok_ind: tuple, alpha: float = 2, alpha2=1, arg_ind: int = ARG_IND):
    """Score of the trellis state ``(sentence index, token index)``.

    Reads the module-level ``sims_args`` array. The score is
    ``-alpha * sim(token, target argument) + alpha2 * sim(token, other argument)``,
    so it gets smaller the closer the token is to the target argument and the
    further it is from the other one.
    NOTE(review): presumably ViterbiTrellis minimises this value — confirm.

    Args:
        sent_ind_tok_ind: ``(sent_ind, tok_ind)`` pair.
        alpha: weight of the similarity to the target argument.
        alpha2: weight of the similarity to the other argument.
        arg_ind: row of ``sims_args`` for the target argument; the other
            argument is row 1 when ``arg_ind`` is 0, and row 0 otherwise.
    """
    sent_ind, tok_ind = sent_ind_tok_ind
    other_arg = 1 if arg_ind == 0 else 0

    target_sim = sims_args[arg_ind][sent_ind][tok_ind]
    other_sim = sims_args[other_arg][sent_ind][tok_ind]
    return -alpha * target_sim + alpha2 * other_sim

def transition_score_func(sent_ind_tok_ind_source, sent_ind_tok_ind_dest):
    """Score of moving between two trellis states: minus their similarity.

    Both arguments are ``(sentence index, token index)`` pairs; the similarity
    is looked up in the module-level 4-d ``sims_token`` array.
    """
    src_sent, src_tok = sent_ind_tok_ind_source
    dst_sent, dst_tok = sent_ind_tok_ind_dest

    similarity = sims_token[src_sent, src_tok, dst_sent, dst_tok]
    return -similarity

def run_multiple_random_hmms(num_sents, n = 10, seed = None):
    """Run Viterbi over ``n`` random sentence orderings and collect the picks.

    For each run the rows of the module-level ``grid`` (one row of
    ``(sent_ind, tok_ind)`` states per sentence) are shuffled, a
    ``ViterbiTrellis`` is built with ``state_score_func`` and
    ``transition_score_func``, and the entry of the best path for each layer
    is recorded under the ORIGINAL index of the sentence in that layer.

    Args:
        num_sents: number of sentences (rows of ``grid``) to shuffle.
        n: number of random orderings to run.
        seed: optional seed for a private random generator, making the
            orderings reproducible. With ``None`` (default) the global
            ``random`` state is used, as before.

    Returns:
        ``defaultdict(list)`` mapping original sentence index to the list of
        captured path entries, one per run.
    """
    rng = random if seed is None else random.Random(seed)
    sent2captures = defaultdict(list)

    for _ in range(n):
        ordering = list(range(num_sents))
        rng.shuffle(ordering)
        # Place grid row k at position ordering[k]; keys are unique, so sorting
        # on the key alone gives the same order as sorting the pairs.
        grid_ordered = [row for _, row in sorted(zip(ordering, grid), key=lambda pair: pair[0])]

        # Every state in a row carries its original sentence index first.
        new2orig = [row[0][0] for row in grid_ordered]

        trellis = ViterbiTrellis(grid_ordered, state_score_func, transition_score_func)
        for position, token_ind in enumerate(trellis.viterbi_best_path()):
            sent2captures[new2orig[position]].append(token_ind)

    # The debugging printout that used to follow the return was unreachable
    # and has been removed.
    return sent2captures
    

In [72]:
# Per-query driver: for results1.txt .. results9.txt, encode the sentences, build the
# similarity tables, run Viterbi (naive and with transition costs) and print the captures.
# This loop deliberately assigns module-level names (`sims_args`, `sims_token`, `grid`)
# because state_score_func, transition_score_func and run_multiple_random_hmms read them
# as globals; renaming or moving them into a function would break those helpers.
for q in range(1,10):
    
    # collect representations of arguments & results
    
    fname = "results{}.txt".format(q)
    query, (arg1_rep, arg2_rep), (representations, mappings_to_orig, mappings_to_tok, tokenized_txts, orig_sents)  = main(fname, layers = [-1])
    # Overwrite (not zero) three token positions of every sentence in place: the first and
    # the last two. np.random.rand() returns one scalar, so each row becomes a constant
    # vector with a freshly drawn value. Presumably these positions are [CLS], the final
    # "." and [SEP] — confirm against the encoder's tokenisation.
    for i in range(len(representations)): # overwrite cls, ., sep (random constant, not zero)
        representations[i][0][:] = np.random.rand()
        representations[i][-1][:] = np.random.rand()
        representations[i][-2][:] = np.random.rand()
    
    
    # Same procedure for layer 0. NOTE(review): arg1_rep0, arg2_rep0 and everything derived
    # from representations0 are not used further down in this cell.
    _, (arg1_rep0, arg2_rep0), (representations0, _, _, _, _)  = main(fname, layers = [0])
    for i in range(len(representations0)): # overwrite cls, ., sep (random constant, not zero)
        representations0[i][0][:] = np.random.rand()
        representations0[i][-1][:] = np.random.rand()
        representations0[i][-2][:] = np.random.rand()
    
    # pad and collect similarity matrices
    
    # Pad every sentence to the longest one with rows of -1.
    # NOTE(review): the vector width 768 is hard-coded here rather than taken from the data.
    pad_width = max([len(s) for s in representations])
    padded_representations = np.array([np.concatenate([r, -np.ones((pad_width-len(r), 768))]) for r in representations])
    padded_representations0 = np.array([np.concatenate([r, -np.ones((pad_width-len(r), 768))]) for r in representations0])

    num_sents, seq_len, bert_dim = padded_representations.shape
    num_tokens = num_sents * seq_len  # not used below in this cell
    sims_token = get_between_tokens_similarity(padded_representations)  # global read by transition_score_func
    sims_args = get_similarity_to_arguments(padded_representations, arg1_rep, arg2_rep)  # global read by state_score_func
    sims_token0 = get_between_tokens_similarity(padded_representations0)  # not used below in this cell
        
    # run viterbi, naive
    
    # One trellis layer per sentence, one (sentence, token) state per padded position.
    # The transition score is constant 0, so only state scores distinguish the paths.
    grid = [[(i,j) for j in range(seq_len)] for i in range(num_sents)]
    v = ViterbiTrellis(grid, state_score_func, lambda x,y: 0)
    best_path_naive = v.viterbi_best_path()
    
    # run viterbi, with state transition cost
    
    # Uses the global `grid` built above; 100 random sentence orderings.
    sent2captures = run_multiple_random_hmms(num_sents,n=100)
    
    # evaluate
    print()
    print()
    print(color.BOLD+color.BLUE+"QUERY {}: {}".format(q, query)+color.END)
    print(color.BOLD+"==========================================================="+color.END)
    print(color.BOLD+"==========================================================="+color.END)

    # Walk the sentences in index order, pairing each with its naive pick, its text and
    # its token->word mapping; only sentences with index 0..15 are shown.
    for sent_ind,j_naive,orig_sent,tok2orig in zip(sorted(sent2captures.keys()),best_path_naive,orig_sents,mappings_to_orig):    
        if sent_ind > 15: break
            
        print(orig_sent)
        # Up to four most frequent captures across the random runs. Captures that are not
        # found in tok2orig are skipped silently.
        # NOTE(review): `j in tok2orig` is a key test only if tok2orig is a dict — confirm.
        for k, (j, freq) in enumerate(Counter(sent2captures[sent_ind]).most_common(4)):
            if j in tok2orig:
              
                print(color.BOLD+color.GREEN+"ARG{}, option {}: {}".format(ARG_IND+1,k, orig_sent.split(" ")[tok2orig[j]])+color.END)
        #print("ARG1 NAIVE: {}".format(orig_sent.split(" ")[tok2orig[j_naive]]))
        # The naive pick is printed when it maps back to a word; otherwise "none".
        if j_naive in tok2orig:
             print(color.BOLD+color.PURPLE+"ARG{}, naive: {}".format(ARG_IND+1,orig_sent.split(" ")[tok2orig[j_naive]])+color.END)
    
        else:
            print("none")
    
        print("---------------------------------------------")

(4620, 768)


[1m[94mQUERY 1: <>arg1:virus $infection $causes a <>arg2:condition .[0m
Hepatotropic virus , like MHV-3 infection in mice , can induce exaggerated inflammation in the liver and cause life-threatening viral FH .
[1m[92mARG2, option 0: inflammation[0m
[1m[92mARG2, option 1: can[0m
[1m[92mARG2, option 2: in[0m
[1m[95mARG2, naive: inflammation[0m
---------------------------------------------
In humans , CHIKV infections cause a debilitating disease with acute febrile illness and long-term polyarthralgia .
[1m[92mARG2, option 0: disease[0m
[1m[92mARG2, option 1: illness[0m
[1m[92mARG2, option 2: and[0m
[1m[92mARG2, option 3: acute[0m
[1m[95mARG2, naive: disease[0m
---------------------------------------------
MERS-CoV induces acute pneumonia similar to that caused by SARS-CoV , and is sometimes accompanied with renal failure ( Danielsson and Catchpole , 2012 ; Zaki et al. , 2012 ) .
[1m[92mARG2, option 0: pneumonia[0m
[1m[92mARG2, option 1: a

(4851, 768)


[1m[94mQUERY 3: a arg1:subset of patients $progress to arg2:hemorrhagic fever[0m
In severe cases , dyspnea and/or hypoxemia usually occur one week after the onset of the disease , and in severe cases , it rapidly progresses to acute respiratory distress syndrome , septic shock , metabolic acidosis that is difficult to correct , and bleeding and coagulation dysfunction .
[1m[92mARG2, option 0: severe[0m
[1m[92mARG2, option 1: to[0m
[1m[92mARG2, option 2: acute[0m
[1m[92mARG2, option 3: disease[0m
[1m[95mARG2, naive: severe[0m
---------------------------------------------
In severe cases , the disease may progress to respiratory , circulatory , and renal failure , and ultimately death due to multiorgan failure .
[1m[92mARG2, option 0: severe[0m
[1m[92mARG2, option 1: respiratory[0m
[1m[92mARG2, option 2: to[0m
[1m[92mARG2, option 3: may[0m
[1m[95mARG2, naive: severe[0m
---------------------------------------------
While disease is often mild

(5159, 768)


[1m[94mQUERY 5: arg1:[e]paracetamol is $not useful for treating arg2:[e]asthma.[0m
However , while tocilizumab is a promising agent against COVID-19 , it is not an appropriate agent in patients with active or latent tuberculosis , bacterial and fungal infections , multi-organ failure , and gastrointestinal perforation [ 7 ] .
[1m[92mARG2, option 0: [[0m
[1m[92mARG2, option 1: patients[0m
[1m[92mARG2, option 2: tuberculosis[0m
[1m[92mARG2, option 3: However[0m
[1m[95mARG2, naive: tuberculosis[0m
---------------------------------------------
There are cases of ADEM or even fulminant presentation such as AHL where steroids alone are not sufficient for suppressing inflammation and improving clinical findings .
[1m[92mARG2, option 0: inflammation[0m
[1m[95mARG2, naive: inflammation[0m
---------------------------------------------
Thus , ribavirin may not be useful for treating SARS infections because of its questionable efficacy and because of its known

(5313, 768)


[1m[94mQUERY 7: the recommended arg1:[w]quarantine period is arg2:14 :[w]days[0m
If we aim to control the failure rate of quarantine to be below 1 % with 95 % confidence , then the quarantine period must be at least 22 days .
[1m[92mARG2, option 0: 22[0m
[1m[92mARG2, option 1: days[0m
[1m[95mARG2, naive: 22[0m
---------------------------------------------
In general , the duration of the quarantine period should be 21 to 30 days .
[1m[92mARG2, option 0: 30[0m
[1m[92mARG2, option 1: days[0m
[1m[92mARG2, option 2: 21[0m
[1m[92mARG2, option 3: duration[0m
[1m[95mARG2, naive: 30[0m
---------------------------------------------
Although for extreme cases , the quarantine period should be extended up to three weeks .
[1m[92mARG2, option 0: three[0m
[1m[92mARG2, option 1: weeks[0m
[1m[92mARG2, option 2: the[0m
[1m[92mARG2, option 3: extreme[0m
[1m[95mARG2, naive: three[0m
---------------------------------------------
At the beginning and

(6776, 768)


[1m[94mQUERY 9: arg1:[e]COVID-19 activates the arg2:ATP $receptor[0m
Human angiotensin-converting enzyme 2 ( ACE2 ) is a functional receptor hijacked by SARS-CoV-2 for cell entry , similar to SARS-CoV [ 8 , 16 ] .
[1m[92mARG2, option 0: enzyme[0m
[1m[92mARG2, option 1: ACE2[0m
[1m[92mARG2, option 2: receptor[0m
[1m[92mARG2, option 3: ([0m
[1m[95mARG2, naive: enzyme[0m
---------------------------------------------
Molecularly , like SARS-CoV , the SARS-CoV-2 virus likely uses ACE-2 as entry receptor , which is highly expressed in the lung and gastrointestinal tract [ 10 ] [ 11 ] [ 12 ] .
[1m[92mARG2, option 0: receptor[0m
[1m[92mARG2, option 1: ACE-2[0m
[1m[92mARG2, option 2: uses[0m
[1m[92mARG2, option 3: uses[0m
[1m[95mARG2, naive: receptor[0m
---------------------------------------------
SARS-CoV-2 enters respiratory epithelial cells by attaching to angiotensin converting enzyme-2 ( ACE-2 ) via S-protein ; ACE-2 is also a receptor for SA