In [1]:
%load_ext autoreload
%autoreload 2
import bert
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import spike_queries
from termcolor import colored

In [490]:
def load_results(fname):
    
    with open(fname, "r", encoding = "utf-8") as f:
        sents = f.readlines()
    
    sents = [s.strip().split("\t")[-1] for s in sents]
    original, results = sents[0], sents[1:]
    return original, results


def get_spike_results_arguments_representations(model, spike_results, layers):
    
    sents = spike_results["sentence_text"].tolist()
    arg1_idx_start = spike_results["arg1_first_index"].to_numpy().astype(int)
    arg2_idx_start = spike_results["arg2_first_index"].to_numpy().astype(int)
    arg1_idx_end = spike_results["arg1_last_index"].to_numpy().astype(int)
    arg2_idx_end = spike_results["arg2_last_index"].to_numpy().astype(int)
    
    arg1_rep = []
    arg2_rep = []
    
    for s, arg1_start, arg2_start, arg1_end, arg2_end in zip(sents, arg1_idx_start, arg2_idx_start, arg1_idx_end, arg2_idx_end):
        #idx_to_mask = [arg1_start, arg2_start, arg1_end, arg2_end]
        H, _, _, orig2tok = model.encode(s, layers = layers)

        h1, h2 = H[orig2tok[arg1_start]:orig2tok[arg1_end] + 1], H[orig2tok[arg2_start]:orig2tok[arg2_end] + 1]
        
        h1 = np.mean(h1, axis = 0)
        h2 = np.mean(h2, axis = 0)
        
        arg1_rep.append(h1)
        arg2_rep.append(h2)
        
    arg1_mean = np.mean(arg1_rep, axis = 0)
    arg2_mean = np.mean(arg2_rep, axis = 0)
        
    return arg1_mean, arg2_mean
    

def print_sentence_nicely(sentence: str, ind1, ind2):
    
    arg1_sign = "**"
    arg2_sign = "++"
    arg1_color, arg2_color = "red", "blue"
    
    if not ind2 > ind1: 
        ind1, ind2 = ind2, ind1
        arg1_sign, arg2_sign = arg2_sign, arg1_sign
        arg1_color, arg2_color = arg2_color, arg1_color
        
    splitted = sentence.split(" ")
    before_arg1 = " ".join(splitted[:ind1])
    arg1 = splitted[ind1]
    arg2 = splitted[ind2]
    between = " ".join(splitted[ind1 + 1: ind2])
    suffix = " ".join(splitted[ind2+1:])
    
    return before_arg1 +  " " + arg1_sign + colored(arg1, arg1_color) + arg1_sign + " " + between + " " + arg2_sign + colored(arg2, arg2_color) + arg2_sign + " " + suffix



def main(filename, layers = [-1], num_results_to_print = 6):
    
    
    query, results1 = load_results(filename)
    spike_results = spike_queries.perform_query(query, dataset_name = "covid19", num_results = 100, query_type = "syntactic")
    spike_results = spike_results[spike_results['sentence_text'].notnull()]
    arg1_rep, arg2_rep = get_spike_results_arguments_representations(model, spike_results, layers)
    
    #print(color.BOLD + "QUERY" + color.END + ":\n{}".format(query))
    
    first, first_ind1, first_ind2 = spike_results["sentence_text"].tolist()[-1], int(spike_results["arg1_first_index"].tolist()[-1]), int(spike_results["arg2_first_index"].tolist()[-1])
    #print(color.BOLD + "\nFIRST SPIKE RESULT" + color.END + ":\n{}".format(print_sentence_nicely(first, first_ind1, first_ind2)))
    #print(color.BOLD + "\nAUGMENTATION RESULTS:\n" + color.END)
    
    representations = []
    mappings_to_orig = []
    mappings_to_tok = []
    tokenized_txts = []
    orig_sents = []
    
    for i,s in enumerate(results1):
        H, tokenized_text, tok_to_orig_map, orig2tok = model.encode(s, layers = layers)
        orig_sents.append(s)
        representations.append(H)
        mappings_to_orig.append(tok_to_orig_map)
        mappings_to_tok.append(orig2tok) 
        tokenized_txts.append(tokenized_text)
        
        if i > num_results_to_print: break
    
    return (arg1_rep, arg2_rep), (representations, mappings_to_orig, mappings_to_tok, tokenized_txts, orig_sents)

In [3]:
model = bert.BertEncoder("cuda", "scibert")

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d813f7395e7ea533039e02deb1723d8fd9d8ba655391a01a69ad6223d
Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d8

In [491]:
def get_between_tokens_similarity(padded_representations):
    
    num_sents, seq_len, bert_dim = padded_representations.shape
    padded_representations = padded_representations.reshape((num_sents*seq_len, bert_dim))
    sims = cosine_similarity(padded_representations, padded_representations)
    #sims = sims.reshape((num_sents, seq_len, num_sents, seq_len))
    
    return sims

def get_similarity_to_arguments(padded_representations, arg1, arg2):
    num_sents, seq_len, bert_dim = padded_representations.shape
    padded_representations = padded_representations.reshape((num_sents*seq_len, bert_dim))    
    sims = cosine_similarity([arg1_rep, arg2_rep], padded_representations)
    sims = sims.reshape((2, num_sents, seq_len))
    return sims

In [492]:
(arg1_rep, arg2_rep), (representations, mappings_to_orig, mappings_to_tok, tokenized_txts, orig_sents)  = main("results2.txt", layers = [-1])
for i in range(len(representations)): # zero cls, ., sep
    representations[i][0][:] = 0.0
    representations[i][-1][:] = 0.0
    representations[i][-2][:] = 0.0

In [493]:
pad_width = max([len(s) for s in representations])
padded_representations = np.array([np.concatenate([r, np.zeros((pad_width-len(r), 768))]) for r in representations])
num_sents, seq_len, bert_dim = padded_representations.shape
num_tokens = num_sents * seq_len
sims_token = get_between_tokens_similarity(padded_representations)
sims_args = get_similarity_to_arguments(padded_representations, arg1_rep, arg2_rep)
"""
for i in range(num_sents):
        sims_args[0][i][0] = 0.0
        sims_args[1][i][0] = 0.0
        sims_args[0][i][-1] = 0.0
        sims_args[1][i][-1] = 0.0
        sims_args[0][i][-2] = 0.0
        sims_args[1][i][-2] = 0.0
"""         

'\nfor i in range(num_sents):\n        sims_args[0][i][0] = 0.0\n        sims_args[1][i][0] = 0.0\n        sims_args[0][i][-1] = 0.0\n        sims_args[1][i][-1] = 0.0\n        sims_args[0][i][-2] = 0.0\n        sims_args[1][i][-2] = 0.0\n'

In [494]:
sims_token.shape

(416, 416)

In [495]:
from collections import defaultdict

In [496]:
import cvxpy as cp
import numpy as np
from collections import defaultdict
import tqdm

num_variables = sims_args.shape[1]
#x = [[cp.Variable(1, boolean=True) for j in range(seq_len)] for i in range(num_sents)]
X = cp.Variable((num_sents, seq_len), boolean = True)
#Q = defaultdict(dict)

#for i in range(num_sents):
#    for j in range(seq_len):
#        Q[i][j] = cp.Variable((num_sents, seq_len), boolean = True)
Q = cp.Variable(sims_token.shape)

similarity_to_arguments_component = cp.sum( cp.multiply(X, sims_args[0]))/(X.shape[0]*X.shape[1])
similarity_between_tokens_component = cp.sum(cp.multiply(Q, sims_token))/(sims_token.shape[0]*sims_token.shape[1])
#similarity_between_tokens_component = [ [[ [Q[i][j][k][l]*sims_token[i][j][k][l] for j in range(seq_len)] for  i 
#                                              in range(num_sents)] for l in range(seq_len)] for k in tqdm.tqdm(range(num_sents))]

In [None]:
objective = cp.Maximize(similarity_to_arguments_component + 0.25*similarity_between_tokens_component)
#objective = cp.Maximize(cp.sum(X * sims_args[0]))
constraints = []

for i in range(num_sents):
    constraints.append(cp.sum(X[i]) <= 1)
    constraints.append(cp.sum(X[i]) >= 1)
    
    
for i in tqdm.tqdm_notebook(range(sims_token.shape[0])):
    y = i % seq_len # row ind
    x = i // seq_len  # col ind
    
    for j in range(sims_token.shape[0]):
        w = j % seq_len # second row ind
        z = j // seq_len  # second col ind
        constraints.append(cp.sum([-X[x][y]-X[z][w]+Q[i][j]]) >= -1)
        constraints.append(cp.sum([X[x][y]-Q[i][j]]) >= 0)
        constraints.append(cp.sum([X[z][w]-Q[i][j]]) >= 0)

prob = cp.Problem(objective, constraints)

In [489]:
import time
start = time.time()
result = prob.solve(verbose=True)
print(time.time() - start)

KeyboardInterrupt: 

In [439]:

for i in range(num_sents):
    constraints.append(cp.sum(X[i]) <= 1)
    constraints.append(cp.sum(X[i]) >= 1)
    
    for j in range(seq_len):
        
        for i2 in range(num_sents):
            
            for j2 in range(seq_len):
                
                q = cp.Variable()

In [440]:
sims_token.shape, X.shape 

((260, 260), (5, 52))

In [441]:
27*52*27*52

1971216

In [442]:
for x,orig_sent,tok2orig in zip(X,orig_sents,mappings_to_orig):
    j = np.argmax(x.value)
    if j in tok2orig:
        print(orig_sent)
        print("Argument: {}".format(orig_sent.split(" ")[tok2orig[j]]))
    else:
        print("none")
    
    print("---------------------------------------------")

After an incubation period of 5 to 14 days , SARS-CoV-2-infected people commonly manifest features of pneumonia , including fever , dry cough , dyspnoea , myalgia and fatigue .
Argument: people
---------------------------------------------
After an incubation period of 2 to 3 days , patients who have pneumonic plague typically develop fulminant pneumonia , with malaise , high fever , cough , hemoptysis , and septicemia with ecchymoses and extremity necrosis .
Argument: patients
---------------------------------------------
Patients suffering from severe DENV infection often exhibit encephalopathy and encephalitis .
Argument: Patients
---------------------------------------------
Following an incubation period of usually 4 - 5 days , patients infected with SARS-CoV often present with symptoms of fever , headache , and myalgias .
Argument: patients
---------------------------------------------
Cats with FIP show nonspecific clinical signs such as fever , weight loss and anorexia , often 

In [286]:
sims_token.shape

(27, 52, 27, 52)

In [287]:
27*52*27*52

1971216

In [299]:
xval = X.value.astype(int)

In [300]:
xval

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       ...,
       [0, 1, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [317]:
sims_token[xval,None].shape

(27, 52, 1, 52, 27, 52)

In [334]:
data = np.ma.array(sims_token, mask=sims_token*xval[:,:,None,None])

In [335]:
data.shape

(27, 52, 27, 52)

In [336]:
xval[:,:,None,None].shape

(27, 52, 1, 1)