In [1]:
%load_ext autoreload
%autoreload 2
import bert
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import spike_queries
import spacy
from typing import List
from collections import defaultdict
import tqdm

In [2]:
def load_results(fname):
    
    with open(fname, "r", encoding = "utf-8") as f:
        sents = f.readlines()
    
    sents = [s.strip().split("\t")[-1] for s in sents]
    original, results = sents[0], sents[1:]
    return original, results


def get_spike_results_arguments_representations(model, spike_results, layers):
    
    sents = spike_results["sentence_text"].tolist()
    arg1_idx_start = spike_results["arg1_first_index"].to_numpy().astype(int)
    arg2_idx_start = spike_results["arg2_first_index"].to_numpy().astype(int)
    arg1_idx_end = spike_results["arg1_last_index"].to_numpy().astype(int)
    arg2_idx_end = spike_results["arg2_last_index"].to_numpy().astype(int)
    
    arg1_rep = []
    arg2_rep = []
    
    for s, arg1_start, arg2_start, arg1_end, arg2_end in zip(sents, arg1_idx_start, arg2_idx_start, arg1_idx_end, arg2_idx_end):
        #idx_to_mask = [arg1_start, arg2_start, arg1_end, arg2_end]
        H, _, _, orig2tok = model.encode(s, layers = layers)

        h1, h2 = H[orig2tok[arg1_start]:orig2tok[arg1_end] + 1], H[orig2tok[arg2_start]:orig2tok[arg2_end] + 1]
        
        h1 = np.mean(h1, axis = 0)
        h2 = np.mean(h2, axis = 0)
        
        arg1_rep.append(h1)
        arg2_rep.append(h2)
        
    arg1_mean = np.mean(arg1_rep, axis = 0)
    arg2_mean = np.mean(arg2_rep, axis = 0)
        
    return arg1_mean, arg2_mean
    


def main(filename, layers = [-1]):
    
    
    query, results1 = load_results(filename)
    spike_results = spike_queries.perform_query(query, dataset_name = "covid19", num_results = 50, query_type = "syntactic")
    spike_results = spike_results[spike_results['sentence_text'].notnull()]
    arg1_rep, arg2_rep = get_spike_results_arguments_representations(model, spike_results, layers)
        
    first, first_ind1, first_ind2 = spike_results["sentence_text"].tolist()[-1], int(spike_results["arg1_first_index"].tolist()[-1]), int(spike_results["arg2_first_index"].tolist()[-1])
    
    alignments = []
    for s in results1:
        H, tokenized_text, tok_to_orig_map, orig2tok = model.encode(s, layers = layers)
        sims_arg1 = cosine_similarity([arg1_rep], H[1:-2])[0]
        sims_arg2 = cosine_similarity([arg2_rep], H[1:-2])[0]
        arg1_ind = np.argmax(sims_arg1) + 1
        
        sims_arg2[np.argmax(sims_arg1)] = -1
        arg2_ind = np.argmax(sims_arg2) + 1
        #print(arg1_ind, arg2_ind)
        if arg1_ind not in tok_to_orig_map or arg2_ind not in tok_to_orig_map:
            #print("ERROR")
            continue
        alignments.append((s, tok_to_orig_map[arg1_ind], tok_to_orig_map[arg2_ind]))
    
    return alignments

In [3]:
model = bert.BertEncoder("cuda", "scibert")

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d813f7395e7ea533039e02deb1723d8fd9d8ba655391a01a69ad6223d
Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivocab_uncased/config.json from cache at /home/shauli/.cache/torch/transformers/199e28e62d2210c23d63625bd9eecc20cf72a156b29e2a540d4933af4f50bda1.4b6b9f5d8

In [20]:
alignments1 = main("results6.txt", layers = [-1])
alignments1 = alignments1[:500]

In [21]:
len(alignments1)

500

## Parse

In [22]:
def parse_results(sentences: List[str]):
    
    #nlp = spacy.load("en_core_sci_sm")
    nlp = spacy.load("en_core_sci_lg")
    
    docs = [spacy.tokens.Doc(vocab=nlp.vocab, words = sentence.split(" ")) for sentence in sentences]
    for name, proc in nlp.pipeline:
        print(name)
        for i,doc in enumerate(tqdm.tqdm(docs)):
            docs[i] = proc(doc)
    return docs



def get_ancestors(tok):
    
    parents = []
    
    while tok.head != tok:
    
        parents.append(tok.head)
        tok = tok.head
        
    
    return parents
    
def get_path_between_tokens(tok1, tok2):
    
    path1 = get_ancestors(tok1)
    path2 = get_ancestors(tok2)
    #print(path1, path2)
    
    lowest_ancestor = None
    for tok in path1:
        if tok in path2:
            lowest_ancestor = tok
            break

    
    path = [tok1] + path1[:path1.index(lowest_ancestor)+1] + path2[::-1][path2.index(lowest_ancestor):] + [tok2]
    return path

In [23]:
sents = [s for s,ind1,ind2 in alignments1]

In [24]:
docs = parse_results(sents)

  9%|▊         | 43/500 [00:00<00:01, 422.52it/s]

tagger


100%|██████████| 500/500 [00:01<00:00, 446.00it/s]
  8%|▊         | 39/500 [00:00<00:01, 384.74it/s]

parser


100%|██████████| 500/500 [00:01<00:00, 386.21it/s]
  9%|▉         | 46/500 [00:00<00:00, 456.34it/s]

ner


100%|██████████| 500/500 [00:01<00:00, 455.34it/s]


In [25]:
docs[0][0].head

injury

In [26]:
get_ancestors(docs[0][6])

[]

In [28]:
w1 = docs[0][3]
w2 = docs[0][6]
print(w1,w2)
get_path_between_tokens(w1, w2)

is manifestation


ValueError: None is not in list

In [29]:
from spacy import displacy
displacy.render(docs[0], style='dep', jupyter=True)

In [34]:
paths_lemmas = []
paths_deps = []
paths_lemmas_deps = []

for doc, (sent, idx1, idx2) in zip(docs, alignments1):
        print("doc is {}".format(doc))
        print("idx are: {} and {}".format(idx1, idx2))

        tok1, tok2 = doc[idx1], doc[idx2]
        path = get_path_between_tokens(tok1, tok2)
    
        print([tok.ent_iob_ for tok in path])
        print([tok.lemma_ for tok in path])
        print([tok.ent_type_ for tok in path])
        paths_lemmas.append("-".join([tok.lemma_ for tok in path]))
        paths_deps.append("-".join([tok.dep_ for tok in path]))
        paths_lemmas_deps.append("-".join([tok.lemma_+"."+tok.dep_ for tok in path]))
        print("------------------------------------")


doc is Acute kidney injury is a possible manifestation in severe forms of COVID-19 , conferring a high mortality 7 . 
idx are: 2 and 8
['I', 'B', 'O', 'B']
['injury', 'manifestation', 'form', 'severe']
['ENTITY', 'ENTITY', '', 'ENTITY']
------------------------------------
doc is Coagulation dysfunction is one of the major causes for death in patients with severe COVID-19 [ 4 ] . 
idx are: 9 and 11
['B', 'O', 'O', 'B', 'B']
['death', 'cause', 'cause', 'death', 'patient']
['ENTITY', '', '', 'ENTITY', 'ENTITY']
------------------------------------
doc is Hypoxia is the most common manifestation of patients with COVID-19 , especially in severely or critically ill patients . 
idx are: 0 and 7
['B', 'O', 'B', 'B']
['hypoxia', 'be', 'manifestation', 'patient']
['ENTITY', '', 'ENTITY', 'ENTITY']
------------------------------------
doc is Early-onset pneumonia is a common and severe complication that is related to aspiration in patients with GBS . 
idx are: 1 and 14
['B', 'O', 'B', 'B', 'B']


ValueError: None is not in list

In [31]:
from collections import Counter
c = Counter(paths_lemmas_deps)
items = list(c.items())
items = sorted(items, key = lambda pair: -pair[1])

In [32]:
items[:50]

[('infection.nmod-patient.nsubj-risk.ROOT-risk.ROOT-patient.nsubj', 5),
 ('infection.nsubj-complication.ROOT-complication.ROOT-patient.nmod', 5),
 ('infection.nsubjpass-associate.ROOT-associate.ROOT-be.auxpass', 4),
 ('injury.nsubj-be.ROOT-condition.attr-patient.nmod', 3),
 ('injury.nsubjpass-report.ROOT-%.nmod-patient.nmod', 3),
 ('be.cop-complication.ROOT-complication.ROOT-patient.nmod', 3),
 ('pneumonia.nsubj-be.ROOT-complication.attr-patient.nmod', 3),
 ('pneumonia.nsubjpass-associate.ROOT-associate.ROOT-be.auxpass', 3),
 ('pneumonia.nsubj-complication.ROOT-complication.ROOT-patient.nmod', 3),
 ('pneumonia.nsubj-cause.ROOT-morbidity.nmod-patient.nmod', 3),
 ('pneumonia.nsubj-be.ROOT-complication.attr-infection.nmod', 2),
 ('AKI.nsubj-occur.ROOT-patient.nmod-disease.nmod', 2),
 ('disease.nmod-associate.acl-mortality.nmod-cause.nmod-one.ROOT-one.ROOT-be.cop',
  2),
 ('be.cop-common.ROOT-common.ROOT-patient.nmod', 2),
 ('pneumonia.nsubj-complication.ROOT-complication.ROOT-child.nmod',