In [136]:
%load_ext autoreload
%autoreload 2
import bert
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import spike_queries
from termcolor import colored
import random
from collections import Counter, defaultdict

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [339]:
def load_results(fname):
    
    with open(fname, "r", encoding = "utf-8") as f:
        sents = f.readlines()
    
    sents = [s.strip().split("\t")[-1] for s in sents]
    original, results = sents[0], sents[1:]
    return original, results


def get_spike_results_arguments_representations(model, spike_results, layers):
    
    sents = spike_results["sentence_text"].tolist()
    arg1_idx_start = spike_results["arg1_first_index"].to_numpy().astype(int)
    arg2_idx_start = spike_results["arg2_first_index"].to_numpy().astype(int)
    arg1_idx_end = spike_results["arg1_last_index"].to_numpy().astype(int)
    arg2_idx_end = spike_results["arg2_last_index"].to_numpy().astype(int)
    
    arg1_rep = []
    arg2_rep = []
    
    for s, arg1_start, arg2_start, arg1_end, arg2_end in zip(sents, arg1_idx_start, arg2_idx_start, arg1_idx_end, arg2_idx_end):
        #idx_to_mask = [arg1_start, arg2_start, arg1_end, arg2_end]
        H, _, _, orig2tok = model.encode(s, layers = layers)

        h1, h2 = H[orig2tok[arg1_start]:orig2tok[arg1_end] + 1], H[orig2tok[arg2_start]:orig2tok[arg2_end] + 1]
        
        h1 = np.mean(h1, axis = 0)
        h2 = np.mean(h2, axis = 0)
        
        arg1_rep.append(h1)
        arg2_rep.append(h2)
        
    arg1_mean = np.mean(arg1_rep, axis = 0)
    arg2_mean = np.mean(arg2_rep, axis = 0)
        
    return arg1_mean, arg2_mean
    

def print_sentence_nicely(sentence: str, ind1, ind2):
    
    arg1_sign = "**"
    arg2_sign = "++"
    arg1_color, arg2_color = "red", "blue"
    
    if not ind2 > ind1: 
        ind1, ind2 = ind2, ind1
        arg1_sign, arg2_sign = arg2_sign, arg1_sign
        arg1_color, arg2_color = arg2_color, arg1_color
        
    splitted = sentence.split(" ")
    before_arg1 = " ".join(splitted[:ind1])
    arg1 = splitted[ind1]
    arg2 = splitted[ind2]
    between = " ".join(splitted[ind1 + 1: ind2])
    suffix = " ".join(splitted[ind2+1:])
    
    return before_arg1 +  " " + arg1_sign + colored(arg1, arg1_color) + arg1_sign + " " + between + " " + arg2_sign + colored(arg2, arg2_color) + arg2_sign + " " + suffix



def main(filename, layers = [-1], num_results_to_print = 50):
    
    
    query, results1 = load_results(filename)
    spike_results = spike_queries.perform_query(query, dataset_name = "covid19", num_results = 100, query_type = "syntactic")
    spike_results = spike_results[spike_results['sentence_text'].notnull()]
    arg1_rep, arg2_rep = get_spike_results_arguments_representations(model, spike_results, layers)
    
    #print(color.BOLD + "QUERY" + color.END + ":\n{}".format(query))
    
    first, first_ind1, first_ind2 = spike_results["sentence_text"].tolist()[-1], int(spike_results["arg1_first_index"].tolist()[-1]), int(spike_results["arg2_first_index"].tolist()[-1])
    #print(color.BOLD + "\nFIRST SPIKE RESULT" + color.END + ":\n{}".format(print_sentence_nicely(first, first_ind1, first_ind2)))
    #print(color.BOLD + "\nAUGMENTATION RESULTS:\n" + color.END)
    
    representations = []
    mappings_to_orig = []
    mappings_to_tok = []
    tokenized_txts = []
    orig_sents = []
    
    for i,s in enumerate(results1):
        H, tokenized_text, tok_to_orig_map, orig2tok = model.encode(s, layers = layers)
        orig_sents.append(s)
        representations.append(H)
        mappings_to_orig.append(tok_to_orig_map)
        mappings_to_tok.append(orig2tok) 
        tokenized_txts.append(tokenized_text)
        
        if i > num_results_to_print: break
    
    return query, (arg1_rep, arg2_rep), (representations, mappings_to_orig, mappings_to_tok, tokenized_txts, orig_sents)

In [3]:
model = bert.BertEncoder("cuda", "scibert")

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d813f7395e7ea533039e02deb1723d8fd9d8ba655391a01a69ad6223d
Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d8

In [322]:
def get_between_tokens_similarity(padded_representations):
    
    num_sents, seq_len, bert_dim = padded_representations.shape
    padded_representations = padded_representations.reshape((num_sents*seq_len, bert_dim))
    sims = cosine_similarity(padded_representations, padded_representations)
    sims = sims.reshape((num_sents, seq_len, num_sents, seq_len))
    
    return sims

def get_between_token_similarity_prev_sentence(padded_representations):
    
    num_sents, seq_len, bert_dim = padded_representations.shape
    padded_representations = padded_representations.reshape((num_sents*seq_len, bert_dim))    
    padded_representations /= np.linalg.norm(padded_representations, keepdims = True, axis=1)
    padded_representations = padded_representations.reshape((num_sents, seq_len, bert_dim))
    tuples = list(zip(padded_representations, padded_representations[1:]))
    sims = np.array([tup[0].dot(tup[1].T) for tup in tuples])
    sims = sims.reshape(((num_sents-1)*seq_len, seq_len))
    return sims

def get_similarity_to_arguments(padded_representations, arg1, arg2):
    num_sents, seq_len, bert_dim = padded_representations.shape
    padded_representations = padded_representations.reshape((num_sents*seq_len, bert_dim))
    print(padded_representations.shape)
    sims = cosine_similarity([arg1_rep, arg2_rep], padded_representations)
    sims = sims.reshape((2, num_sents, seq_len))
    return sims

In [323]:
26*52*52

70304

In [357]:
fname = "results6.txt"

query, (arg1_rep, arg2_rep), (representations, mappings_to_orig, mappings_to_tok, tokenized_txts, orig_sents)  = main(fname, layers = [-1])
for i in range(len(representations)): # zero cls, ., sep
    representations[i][0][:] = np.random.rand()
    representations[i][-1][:] = np.random.rand()
    representations[i][-2][:] = np.random.rand()
    
    
_, (arg1_rep0, arg2_rep0), (representations0, _, _, _, _)  = main(fname, layers = [3])
for i in range(len(representations0)): # zero cls, ., sep
    representations0[i][0][:] = np.random.rand()
    representations0[i][-1][:] = np.random.rand()
    representations0[i][-2][:] = np.random.rand()

In [358]:
query

'arg1:[e]stroke is a $complication of :[e]COVID-19 arg2:infection'

In [359]:
pad_width = max([len(s) for s in representations])
padded_representations = np.array([np.concatenate([r, np.ones((pad_width-len(r), 768))]) for r in representations])
padded_representations0 = np.array([np.concatenate([r, np.ones((pad_width-len(r), 768))]) for r in representations0])

num_sents, seq_len, bert_dim = padded_representations.shape
num_tokens = num_sents * seq_len
sims_token = get_between_tokens_similarity(padded_representations)
sims_args = get_similarity_to_arguments(padded_representations, arg1_rep, arg2_rep)
sims_token0 = get_between_tokens_similarity(padded_representations0)
#sims_token = get_between_token_similarity_prev_sentence(padded_representations)

"""
for i in range(num_sents):
        sims_args[0][i][0] = 0.0
        sims_args[1][i][0] = 0.0
        sims_args[0][i][-1] = 0.0
        sims_args[1][i][-1] = 0.0
        sims_args[0][i][-2] = 0.0
        sims_args[1][i][-2] = 0.0
"""         

(2236, 768)


'\nfor i in range(num_sents):\n        sims_args[0][i][0] = 0.0\n        sims_args[1][i][0] = 0.0\n        sims_args[0][i][-1] = 0.0\n        sims_args[1][i][-1] = 0.0\n        sims_args[0][i][-2] = 0.0\n        sims_args[1][i][-2] = 0.0\n'

In [64]:
sims_token.shape, sims_args.shape

((27, 52, 27, 52), (2, 27, 52))

In [65]:
sims_token.shape

(27, 52, 27, 52)

In [11]:
from collections import defaultdict

In [12]:
import cvxpy as cp
import numpy as np
from collections import defaultdict
import tqdm

num_variables = sims_args.shape[1]
#x = [[cp.Variable(1, boolean=True) for j in range(seq_len)] for i in range(num_sents)]
X = cp.Variable((num_sents, seq_len), boolean = True)


Q = cp.Variable(sims_token.shape)
similarity_to_arguments_component = cp.sum( cp.multiply(X, sims_args[1]))/(X.shape[0]*X.shape[1])
similarity_between_tokens_component = cp.sum(cp.multiply(Q, sims_token))/(sims_token.shape[0]*sims_token.shape[1])

In [13]:
similarity_to_arguments_component.value, similarity_between_tokens_component.value

(None, None)

In [14]:
Q.shape, sims_token.shape

((1352, 52), (1352, 52))

In [15]:
objective = cp.Maximize(similarity_to_arguments_component + 0.001*similarity_between_tokens_component)
#objective = cp.Maximize(cp.sum(X * sims_args[0]))
constraints = []

for i in range(num_sents):
    constraints.append(cp.sum(X[i]) <= 1)
    constraints.append(cp.sum(X[i]) >= 1)

for i in tqdm.tqdm_notebook(range(sims_token.shape[0])):
    y = i % seq_len # row ind
    x = i // seq_len  # col ind
    
    for j in range(seq_len):
       constraints.append(cp.sum([-X[x][y]-X[x][x+1]+Q[i][j]]) >= -1)
       constraints.append(cp.sum([X[x][y]-Q[i][j]]) >= 0)
       constraints.append(cp.sum([X[x][x+1]-Q[i][j]]) >= 0)
    
#constraints.append(cp.sum_squares(Q) <= 1e-6)
"""
for i in tqdm.tqdm_notebook(range(sims_token.shape[0])):
    y = i % seq_len # row ind
    x = i // seq_len  # col ind
    
    for j in range(sims_token.shape[0]):
        w = j % seq_len # second row ind
        z = j // seq_len  # second col ind
        #constraints.append(cp.sum([-X[x][y]-X[z][w]+Q[i][j]]) >= -1)
        #constraints.append(cp.sum([X[x][y]-Q[i][j]]) >= 0)
        #constraints.append(cp.sum([X[z][w]-Q[i][j]]) >= 0)
"""
prob = cp.Problem(objective, constraints)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if __name__ == '__main__':


HBox(children=(FloatProgress(value=0.0, max=1352.0), HTML(value='')))




In [16]:
import time
start = time.time()
result = prob.solve(verbose=True)
print(time.time() - start)
print(objective.value, similarity_to_arguments_component.value, similarity_between_tokens_component.value)

2341.159997701645
0.015026756072061104 0.01532884604809937 -0.3020899760382654


In [25]:
2341/60

39.016666666666666

In [17]:
print("QUERY:", query)
print("===========================================================")
print("===========================================================")

for x,orig_sent,tok2orig in zip(X,orig_sents,mappings_to_orig):
    j = np.argmax(x.value)
    if j in tok2orig:
        print(orig_sent)
        print("ARG1: {}".format(orig_sent.split(" ")[tok2orig[j]]))
    else:
        print("none")
    
    print("---------------------------------------------")

QUERY: infected arg1:patients $usually $develop <>arg2:[entity]respiratory illness.
After an incubation period of 5 to 14 days , SARS-CoV-2-infected people commonly manifest features of pneumonia , including fever , dry cough , dyspnoea , myalgia and fatigue .
ARG1: pneumonia
---------------------------------------------
After an incubation period of 2 to 3 days , patients who have pneumonic plague typically develop fulminant pneumonia , with malaise , high fever , cough , hemoptysis , and septicemia with ecchymoses and extremity necrosis .
ARG1: pneumonia
---------------------------------------------
Patients suffering from severe DENV infection often exhibit encephalopathy and encephalitis .
ARG1: encephalitis
---------------------------------------------
Following an incubation period of usually 4 - 5 days , patients infected with SARS-CoV often present with symptoms of fever , headache , and myalgias .
ARG1: patients
---------------------------------------------
Cats with FIP show 

# Viterbi

In [353]:
from viterbi_trellis import ViterbiTrellis

In [354]:
from viterbi_trellis import ViterbiTrellis
grid = [[(i,j) for j in range(seq_len)] for i in range(num_sents)]
ARG_IND = 0

def state_score_func(sent_ind_tok_ind: tuple, alpha: float = 2,alpha2=1, arg_ind: int = ARG_IND):
    
    sent_ind, tok_ind = sent_ind_tok_ind
    return -alpha*sims_args[arg_ind][sent_ind][tok_ind] + alpha2*sims_args[1 if arg_ind==0 else 0][sent_ind][tok_ind]

def transition_score_func(sent_ind_tok_ind_source, sent_ind_tok_ind_dest):
    
    sent_ind_source, tok_ind_source = sent_ind_tok_ind_source
    sent_ind_dest, tok_ind_dest = sent_ind_tok_ind_dest
        
    return -sims_token[sent_ind_source,tok_ind_source,sent_ind_dest,tok_ind_dest]

def run_multiple_random_hmms(num_sents, n = 10):
    
    sent2captures = defaultdict(list)
    
    for i in range(n):
        
        ordering = list(range(num_sents))
        random.shuffle(ordering)
        grid_ordered = [x for _,x in sorted(zip(ordering,grid))]
        #print(grid_ordered[0][:6])

        new2orig = {i:grid_ordered[i][0][0] for i in range(len(grid_ordered))}
        
        v = ViterbiTrellis(grid_ordered, state_score_func, transition_score_func)
        best_path = v.viterbi_best_path()
        for j, token_ind in enumerate(best_path):
            
            sent_ind = new2orig[j]
            sent2captures[sent_ind].append(token_ind)
    
    return sent2captures

    for sentind in sent2captures.keys():
        counter = Counter(sent2captures[sentind])
        print(counter.most_common(2))
    
v = ViterbiTrellis(grid, state_score_func, transition_score_func)
best_path = v.viterbi_best_path()

v = ViterbiTrellis(grid, state_score_func, lambda x,y: 0)
best_path_naive = v.viterbi_best_path()

In [355]:
sent2captures = run_multiple_random_hmms(num_sents,n=100)

## Eval

In [356]:
print("QUERY:", query)
print("===========================================================")
print("===========================================================")

for sent_ind,j_naive,orig_sent,tok2orig in zip(sorted(sent2captures.keys()),best_path_naive,orig_sents,mappings_to_orig):    
    print(orig_sent)
    for k, (j, freq) in enumerate(Counter(sent2captures[sent_ind]).most_common(4)):
        if j in tok2orig:
              
            print("ARG{}, option {}: {}".format(ARG_IND+1,k, orig_sent.split(" ")[tok2orig[j]]))
        #print("ARG1 NAIVE: {}".format(orig_sent.split(" ")[tok2orig[j_naive]]))
    if j_naive in tok2orig:
         print("ARG{}, naive: {}".format(ARG_IND+1,orig_sent.split(" ")[tok2orig[j_naive]]))
    
    else:
        print("none")
    
    print("---------------------------------------------")

QUERY: arg1:[e]paracetamol is $not useful for treating arg2:[e]asthma.
However , while tocilizumab is a promising agent against COVID-19 , it is not an appropriate agent in patients with active or latent tuberculosis , bacterial and fungal infections , multi-organ failure , and gastrointestinal perforation [ 7 ] .
ARG1, option 0: agent
ARG1, option 1: tocilizumab
ARG1, option 2: it
ARG1, option 3: [
ARG1, naive: tocilizumab
---------------------------------------------
There are cases of ADEM or even fulminant presentation such as AHL where steroids alone are not sufficient for suppressing inflammation and improving clinical findings .
ARG1, option 0: steroids
ARG1, naive: steroids
---------------------------------------------
Thus , ribavirin may not be useful for treating SARS infections because of its questionable efficacy and because of its known toxicity ( reviewed by van Vonderen et al. , 2003 ; Lai , 2005 ) .
ARG1, option 0: ribavirin
ARG1, naive: ribavirin
---------------------