In [98]:
%load_ext autoreload
%autoreload 2
import bert
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import spike_queries
import spacy
from typing import List
from collections import defaultdict
import tqdm
from termcolor import colored

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [147]:


class color:
    """Named ANSI escape sequences used to style text printed by this notebook.

    Prefix a string with one of the codes and append ``END`` to reset the
    styling afterwards, e.g. ``color.BOLD + text + color.END``.
    """

    # Foreground colours.
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Reset: cancels every code emitted before it.
    END = '\033[0m'



In [2]:
def load_results(fname):
    """Read a tab-separated results file and split it into query and sentences.

    Every non-blank line is stripped and reduced to the text after its last
    tab. The first such line is returned as ``original`` and the remaining
    ones as ``results``.

    Blank lines (for example a trailing newline at the end of the file) are
    skipped so they do not turn into empty "sentences" downstream.

    Args:
        fname: path of a UTF-8 encoded text file.

    Returns:
        ``(original, results)`` where ``original`` is a string and ``results``
        is a list of strings.

    Raises:
        ValueError: if the file contains no non-blank line.
    """
    with open(fname, "r", encoding = "utf-8") as f:
        # A stripped, non-empty line never ends with a tab, so its last
        # tab-separated field is guaranteed to be non-empty.
        sents = [line.strip().split("\t")[-1] for line in f if line.strip()]

    if not sents:
        raise ValueError("no non-blank lines found in {!r}".format(fname))

    original, results = sents[0], sents[1:]
    return original, results


def get_spike_results_arguments_representations(model, spike_results, layers):
    """Average the encoder states of both arguments over all SPIKE results.

    For every row of ``spike_results`` the sentence is passed to
    ``model.encode``; the rows of the returned state matrix that fall inside
    each argument span are averaged, and those per-sentence averages are then
    averaged once more across sentences.

    Args:
        model: object exposing ``encode(sentence, layers=...)``. Only its
            first return value (indexed by row) and its fourth return value
            (used to map a word index to a row index) are read here.
        spike_results: table providing the columns ``sentence_text``,
            ``arg1_first_index``, ``arg1_last_index``, ``arg2_first_index``
            and ``arg2_last_index``.
        layers: forwarded unchanged to ``model.encode``.

    Returns:
        ``(arg1_mean, arg2_mean)``: the averaged representation of the first
        and of the second argument.
    """
    def as_int_array(column):
        # Indices are cast to int before they are used for look-ups/slicing.
        return spike_results[column].to_numpy().astype(int)

    sentences = spike_results["sentence_text"].tolist()
    starts1 = as_int_array("arg1_first_index")
    starts2 = as_int_array("arg2_first_index")
    ends1 = as_int_array("arg1_last_index")
    ends2 = as_int_array("arg2_last_index")

    per_sentence = []
    for sentence, start1, start2, end1, end2 in zip(sentences, starts1, starts2, ends1, ends2):
        states, _, _, word_to_row = model.encode(sentence, layers = layers)

        # Span runs from the row of the first word up to and including the
        # row that the last word maps to.
        span1 = states[word_to_row[start1]:word_to_row[end1] + 1]
        span2 = states[word_to_row[start2]:word_to_row[end2] + 1]

        per_sentence.append((np.mean(span1, axis = 0), np.mean(span2, axis = 0)))

    arg1_mean = np.mean([pair[0] for pair in per_sentence], axis = 0)
    arg2_mean = np.mean([pair[1] for pair in per_sentence], axis = 0)

    return arg1_mean, arg2_mean
    


def main(filename, layers = [-1]):
    """Align the two query arguments to words in each result sentence.

    Steps:
      1. Read the query and the result sentences from ``filename`` with
         ``load_results``.
      2. Run the query through ``spike_queries.perform_query`` and keep the
         rows that have a sentence text.
      3. Build one averaged representation per argument from those rows.
      4. For every result sentence, pick the row most cosine-similar to the
         first argument, then the most similar *other* row for the second
         argument, and map both rows back to word indices.

    Relies on the notebook-level global ``model`` for encoding.

    Args:
        filename: path handed to ``load_results``.
        layers: forwarded unchanged to ``model.encode`` (the default list is
            never mutated here).

    Returns:
        A list of ``(sentence, arg1_word_index, arg2_word_index)`` tuples.
        Sentences that cannot be aligned are left out.

    Raises:
        ValueError: if the query yields no row with a sentence text, since no
            argument representation can be computed from zero examples.
    """
    query, results1 = load_results(filename)
    spike_results = spike_queries.perform_query(query, dataset_name = "covid19", num_results = 50, query_type = "syntactic")
    spike_results = spike_results[spike_results['sentence_text'].notnull()]
    if len(spike_results) == 0:
        # Averaging over zero sentences would silently produce NaN vectors.
        raise ValueError("the query from {!r} returned no usable results".format(filename))
    arg1_rep, arg2_rep = get_spike_results_arguments_representations(model, spike_results, layers)

    alignments = []
    for s in results1:
        H, tokenized_text, tok_to_orig_map, orig2tok = model.encode(s, layers = layers)

        # Candidate rows: everything except the first row and the last two
        # rows (presumably special/terminal tokens -- TODO confirm against
        # bert.BertEncoder.encode).
        candidates = H[1:-2]
        if len(candidates) < 2:
            # Two distinct rows are needed, one per argument.
            continue

        sims_arg1 = cosine_similarity([arg1_rep], candidates)[0]
        sims_arg2 = cosine_similarity([arg2_rep], candidates)[0]

        best_arg1 = np.argmax(sims_arg1)
        # Mask the row already taken by the first argument so the second
        # argument is forced onto a different row.
        sims_arg2[best_arg1] = -1

        # +1 undoes the offset introduced by dropping the first row of H.
        arg1_ind = best_arg1 + 1
        arg2_ind = np.argmax(sims_arg2) + 1

        if arg1_ind not in tok_to_orig_map or arg2_ind not in tok_to_orig_map:
            # No word index is known for one of the selected rows.
            continue
        alignments.append((s, tok_to_orig_map[arg1_ind], tok_to_orig_map[arg2_ind]))

    return alignments

In [3]:
# Encoder shared by the cells below; `main` reads it as a global named `model`.
# NOTE(review): "cuda" presumably selects the GPU device and "scibert" the
# checkpoint -- confirm against bert.BertEncoder.
model = bert.BertEncoder("cuda", "scibert")

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d813f7395e7ea533039e02deb1723d8fd9d8ba655391a01a69ad6223d
Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d8

In [115]:
# Align the sentences listed in results1.txt and keep at most the first 500 alignments.
alignments1 = main("results1.txt", layers = [-1])[:500]

In [116]:
# Number of alignments kept (at most 500 because of the slice applied to `alignments1`).
len(alignments1)

500

## Parse

In [117]:
def parse_results(sentences: List[str]):
    """Parse pre-tokenized sentences with the "en_core_sci_lg" spaCy model.

    Each sentence is split on single spaces and wrapped in a ``Doc`` built
    directly from those words, so token positions follow the space-separated
    word positions. The pipeline components are then applied one after the
    other to the whole batch, printing each component name and showing a
    progress bar per component.

    Args:
        sentences: sentences whose words are separated by single spaces.

    Returns:
        A list with one processed ``Doc`` per input sentence, in input order.
    """
    # Alternative model: "en_core_sci_sm".
    pipeline = spacy.load("en_core_sci_lg")

    parsed = []
    for text in sentences:
        words = text.split(" ")
        parsed.append(spacy.tokens.Doc(vocab=pipeline.vocab, words = words))

    for component_name, component in pipeline.pipeline:
        print(component_name)
        progress = tqdm.tqdm(parsed)
        for position, current in enumerate(progress):
            # Replace each document by the output of this component.
            parsed[position] = component(current)

    return parsed



def get_ancestors(tok):
    """Return ``tok`` followed by its chain of heads up to the root.

    The walk follows ``.head`` repeatedly and stops at the first token that
    is its own head, which is included as the last element.

    Args:
        tok: object with a ``head`` attribute (a parsed token).

    Returns:
        A list starting with ``tok`` itself and ending with the root.
    """
    lineage = [tok]

    node = tok
    parent = node.head
    while parent != node:
        lineage.append(parent)
        node = parent
        parent = node.head

    return lineage
    
def get_path_between_tokens(tok1, tok2):
    """Describe how two tokens of one parsed document are connected.

    Two views are produced: the dependency path through the lowest common
    ancestor of the tokens, and the linear stretch of the sentence between
    them.

    Args:
        tok1: first token (must expose ``head``, ``i``, ``text`` and ``doc``).
        tok2: second token, from the same document as ``tok1``.

    Returns:
        A dict with the keys:
          * ``"ancestor"``: the lowest common ancestor of both tokens.
          * ``"path1"`` / ``"path2"``: the tokens from ``tok1`` / ``tok2`` up
            to and including that ancestor.
          * ``"tok1"`` / ``"tok2"``: the arguments themselves.
          * ``"linear_path"``: the tokens from the earlier to the later of the
            two, in document order, endpoints included.
          * ``"linear_path_str"``: their texts joined by spaces, with the
            first text coloured red and the last one blue.

    Raises:
        ValueError: if the tokens share no ancestor (for instance when they
            belong to different parse trees).
    """
    path1 = get_ancestors(tok1)
    path2 = get_ancestors(tok2)

    # First token on tok1's way to the root that also lies on tok2's way.
    lowest_ancestor = next((tok for tok in path1 if tok in path2), None)
    if lowest_ancestor is None:
        raise ValueError("tokens {!r} and {!r} share no common ancestor".format(tok1, tok2))

    path1 = path1[:path1.index(lowest_ancestor) + 1]
    path2 = path2[:path2.index(lowest_ancestor) + 1]

    linear_path_before = [tok1 if tok1.i < tok2.i else tok2]
    linear_path_after = [tok1 if tok1.i > tok2.i else tok2]
    # Iterate over the tokens' own document rather than a notebook-level
    # global, so the result does not depend on which cell ran last.
    low, high = min(tok1.i, tok2.i), max(tok1.i, tok2.i)
    linear_path_between = [t for t in tok1.doc if low < t.i < high]
    linear_path = linear_path_before + linear_path_between + linear_path_after

    linear_path_str = [t.text for t in linear_path]
    linear_path_str[0] = colored(linear_path_str[0], "red")
    linear_path_str[-1] = colored(linear_path_str[-1], "blue")
    linear_path_str = " ".join(linear_path_str)

    path = {"ancestor": lowest_ancestor, "path1": path1, "path2": path2, "tok1": tok1, "tok2": tok2, "linear_path": linear_path,
           "linear_path_str": linear_path_str}
    return path

In [119]:
# Keep only the sentence text of every (sentence, arg1 index, arg2 index) alignment.
sents = [sentence for sentence, _arg1_index, _arg2_index in alignments1]

In [120]:
# Parse every aligned sentence; `docs[k]` corresponds to `alignments1[k]`
# because `sents` was built from `alignments1` in order.
docs = parse_results(sents)

  5%|▌         | 27/500 [00:00<00:01, 269.15it/s]

tagger


100%|██████████| 500/500 [00:01<00:00, 339.37it/s]
  6%|▌         | 29/500 [00:00<00:01, 286.58it/s]

parser


100%|██████████| 500/500 [00:01<00:00, 309.31it/s]
  7%|▋         | 35/500 [00:00<00:01, 349.69it/s]

ner


100%|██████████| 500/500 [00:01<00:00, 383.70it/s]


In [126]:
from spacy import displacy
# Draw the dependency parse of `doc` inline in the notebook.
# NOTE(review): no earlier top-level cell in this file assigns `doc`; this
# relies on kernel state left over from another cell -- confirm which
# document is meant (e.g. an element of `docs`).
displacy.render(doc, style='dep', jupyter=True)

In [127]:
# Inspect the syntactic head of the token at index 7 of `doc`.
# NOTE(review): depends on whatever `doc` currently holds in the kernel.
doc[7].head

ACE2

In [162]:
def regularize(path1: List[str], path2: List[str]):
    """Merge two half-paths into one highlighted path without touching the inputs.

    The last element of ``path1`` is wrapped in bold, the first element of
    ``path2`` is coloured blue and the first element of ``path1`` red. The
    result is ``path1`` followed by ``path2`` reversed, minus the first
    element of that reversed list (i.e. the last element of ``path2``, which
    the caller in this notebook makes identical to the last element of
    ``path1``).

    Consequences of that layout:
      * if ``path1`` has a single element it ends up both bold and red;
      * if ``path2`` has a single element it contributes nothing, so no blue
        element appears in the result.

    Args:
        path1: labels from the first token up to the shared element.
        path2: labels from the second token up to the shared element.

    Returns:
        A new list of strings; ``path1`` and ``path2`` are left unmodified.
    """
    # Work on copies so the caller's lists are not altered as a side effect.
    head_path = list(path1)
    tail_path = list(path2)

    # Order matters: bold is applied before red so that a one-element
    # ``path1`` gets the bold marker nested inside the red one.
    head_path[-1] = color.BOLD + head_path[-1] + color.END
    tail_path[0] = colored(tail_path[0], "blue")
    head_path[0] = colored(head_path[0], "red")

    path = head_path + tail_path[::-1][1:]

    return path
        
    

# One entry per aligned sentence; the three lists are index-aligned with each other.
paths_lemmas = []
paths_deps = []
paths_lemmas_deps = []
i = 0  # counts processed sentences; not read anywhere else in this cell

# NOTE(review): `doc` stays bound at notebook level after this loop; check
# which other cells/functions read it before renaming the loop variable.
for doc, (sent, idx1, idx2) in zip(docs, alignments1):
    

        # Tokens that the two query arguments were aligned to in this sentence.
        tok1, tok2 = doc[idx1], doc[idx2]
        path_dict = get_path_between_tokens(tok1, tok2)
    
        # Token chains from each argument up to their lowest common ancestor.
        tok1_2ances = path_dict["path1"]
        tok2_2ances = path_dict["path2"]
        
        
        lemmas1 = [tok.lemma_ for tok in tok1_2ances]
        lemmas2 = [tok.lemma_ for tok in tok2_2ances]
        deps1 = [tok.dep_ for tok in tok1_2ances]
        deps2 = [tok.dep_ for tok in tok2_2ances]
        
        # Join both half-paths into a single highlighted sequence, once for
        # lemmas and once for dependency labels.
        lemmas = regularize(lemmas1, lemmas2)
        deps = regularize(deps1, deps2)
        # "lemma.dep" pairs joined with "-".
        zipped = "-".join([lemma+"."+dep for lemma, dep in zip(lemmas, deps)])
        paths_lemmas_deps.append(zipped)
        paths_lemmas.append("-".join(lemmas))
        paths_deps.append("-".join(deps))
        
        i += 1


In [170]:
from collections import Counter
# Tally identical lemma paths and rank them from most to least frequent;
# paths with equal counts keep the order in which they were first seen.
c = Counter(paths_lemmas)
items = c.most_common()
print(items[0][0])

[31minfection[0m-[1mcause[0m-[34mdisease[0m


In [171]:
# Show up to the 50 most frequent paths with their counts. Slicing (rather
# than indexing 0..49) avoids an IndexError when fewer than 50 distinct paths
# were found.
for path, count in items[:50]:
    print(path, count)

[31minfection[0m-[1mcause[0m-[34mdisease[0m 5
[31minfection[0m-[1mcharacterize[0m-[34mbe[0m 4
[31minfection[0m-[1mcause[0m 3
[31minfection[0m-[1mcharacterize[0m-[34minflammation[0m 3
[31minfection[0m-[1mresult[0m-[34mdisease[0m 3
[31minfection[0m-[1massociate[0m-[34mbe[0m 3
[31mibv[0m-[1minfection[0m 3
[31mvirus[0m-spread-[1mcharacterize[0m-[34mbe[0m 2
[31mpedv[0m-infection-[1mcause[0m-[34menteritis[0m 2
[31mhrsv[0m-[1minfection[0m 2
[31mprv[0m-[1minfection[0m 2
[31m[1mdisease[0m[0m-result-[34mdamage[0m 2
[31mbkv[0m-[1mestablish[0m-[34minfection[0m 2
[31mlasv[0m-[1minfection[0m 2
[31mpedv[0m-infection-[1mcharacterize[0m-[34mbe[0m 2
[31m[1mpneumonia[0m[0m-[34msevere[0m 2
[31m[1menteritis[0m[0m-[34macute[0m 2
[31mhev[0m-[1minfection[0m 2
[31minfection[0m-[1mcause[0m-spectrum-[34mdisease[0m 2
[31mpneumonia[0m-[1mcause[0m-[34mhowever[0m 2
[31minfection[0m-[1mresult[0m-loss-[34mf

In [137]:
# Display the document currently bound to `doc`.
# NOTE(review): this is whatever the kernel last assigned to `doc` (e.g. the
# last iteration of a loop over `docs`) -- confirm which document is intended.
doc

While many cases are mild , in some , SARS-CoV-2 is able to infect lower respiratory epithelial cells and induce a pathogenic immune response that can ultimately lead to hypoxic respiratory failure , acute respiratory distress syndrome ( ARDS ) , and death . 