In [1]:
import os,sys,inspect
# Put the parent of the directory containing this code at the front of
# sys.path (presumably so the project packages imported below resolve).
_this_file = inspect.getfile(inspect.currentframe())
currentdir = os.path.dirname(os.path.abspath(_this_file))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

In [2]:
import numpy as np
import torch

from wiser.data.dataset_readers import NCBIDiseaseDatasetReader
from wiser.rules import TaggingRule, LinkingRule, UMLSMatcher, DictionaryMatcher
from wiser.generative import get_label_to_ix, get_rules
from labelmodels import *
from wiser.generative import train_generative_model
from labelmodels import LearningConfig
from wiser.generative import evaluate_generative_model
from wiser.data import save_label_distribution
from wiser.eval import *
from wiser.rules import ElmoLinkingRule
from collections import Counter

from nltk.tokenize import word_tokenize
from nltk.tokenize import sent_tokenize
from tokenizations import get_alignments, get_original_spans
from typing import List, Optional, Tuple

from transformers import AutoTokenizer, AutoModel

## Construct Weak Annotations Using AllenNLP

In [3]:
# Corpus split to process: 'train', 'dev' or 'test'.
DATA_PARTITION = 'train'

# Corpus file for each split; an unrecognised split leaves file_name empty.
file_name = {
    'dev': 'NCBIdevelopset_corpus.txt',
    'test': 'NCBItestset_corpus.txt',
    'train': 'NCBItrainset_corpus.txt',
}.get(DATA_PARTITION, '')

# Tags attached to entity spans (LABEL) and to linking-rule spans (LINK).
LABEL = 'DISEASE'
LINK = 'ENT'

In [4]:
# Read the corpus file selected above with the project's NCBI reader.
# `ncbi_docs` is the collection every tagging/linking rule below is applied to.
reader = NCBIDiseaseDatasetReader()
ncbi_docs = reader.read(f'../data/NCBI/{file_name}')

HBox(children=(HTML(value='reading instances'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width…

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6924.0), HTML(value='')))







## Tagging Rules

In [5]:
# Split the core dictionary into two term sets (terms are tuples of tokens):
#   dict_core       - multi-token terms, or single tokens longer than 3 chars
#   dict_core_exact - single-token terms of at most 3 characters
# The first whitespace-separated field of every line is dropped; only the
# remaining fields form the term.
dict_core = set()
dict_core_exact = set()
with open('../data/AutoNER_dicts/NCBI/dict_core.txt') as f:
    for line in f:
        term = tuple(line.strip().split()[1:])

        # A blank line, or a line holding only the leading field, carries no
        # term; skipping it avoids an IndexError on term[0] below.
        if not term:
            continue

        if len(term) > 1 or len(term[0]) > 3:
            dict_core.add(term)
        else:
            dict_core_exact.add(term)

# Prepends common modifiers (both capitalisations) to every core term
modifiers = ("inherited", "Inherited", "hereditary", "Hereditary")
to_add = {(modifier,) + term for term in dict_core for modifier in modifiers}
dict_core |= to_add

# Removes common FP; set.remove raises KeyError if an entry is absent
for false_positive in (("WT1",), ("VHL",)):
    dict_core_exact.remove(false_positive)


# Full dictionary: every line is kept whole as one term (a tuple of its
# whitespace-separated tokens).
with open('../data/AutoNER_dicts/NCBI/dict_full.txt') as f:
    dict_full = {tuple(line.strip().split()) for line in f}


# Dictionary rule over the longer / multi-token core terms, built with
# uncased=True and voting "I" on matches.
lf = DictionaryMatcher(
    "CoreDictionaryUncased",
    dict_core,
    uncased=True,
    i_label="I")
lf.apply(ncbi_docs)


# Dictionary rule over the short single-token core terms; `uncased` is left
# at the matcher's default here.
# NOTE(review): presumably the default is case-sensitive matching, as the rule
# name suggests — confirm against wiser's DictionaryMatcher.
lf = DictionaryMatcher("CoreDictionaryExact", dict_core_exact, i_label="I")
lf.apply(ncbi_docs)


class CancerLike(TaggingRule):
    """Votes 'I' for tokens ending in one of a few suffixes, optionally plural."""

    def apply_instance(self, instance):
        suffixes = ("edema", "toma", "coma", "noma")
        # Accept each suffix on its own and with a trailing "s".
        endings = suffixes + tuple(suffix + "s" for suffix in suffixes)

        return [
            'I' if token.text.lower().endswith(endings) else 'ABS'
            for token in instance['tokens']
        ]


lf = CancerLike()
lf.apply(ncbi_docs)


class CommonSuffixes(TaggingRule):
    """Votes 'I' for tokens whose lemma ends with one of ``suffixes``."""

    suffixes = {
        "agia",
        "cardia",
        "trophy",
        "toxic",
        "itis",
        "emia",
        "pathy",
        "plasia"}

    def apply_instance(self, instance):
        # str.endswith accepts a tuple and is true if any entry matches.
        endings = tuple(self.suffixes)
        return [
            'I' if token.lemma_.endswith(endings) else 'ABS'
            for token in instance['tokens']
        ]


lf = CommonSuffixes()
lf.apply(ncbi_docs)


class Deficiency(TaggingRule):
    """Votes 'I' on "<compound> deficiency" and "deficiency of <...>" phrases."""

    def apply_instance(self, instance):
        tokens = instance['tokens']
        n_tokens = len(tokens)
        labels = ['ABS'] * n_tokens

        # "___ deficiency"
        for head in range(1, n_tokens):
            modifier = head - 1
            if tokens[modifier].dep_ != 'compound' or tokens[head].lemma_ != 'deficiency':
                continue
            labels[modifier] = 'I'
            labels[head] = 'I'

            # Extends the phrase leftwards over any further compound tokens
            left = modifier - 1
            while left >= 0 and tokens[left].dep_ == 'compound':
                labels[left] = 'I'
                left -= 1

        # "deficiency of ___"
        for i in range(n_tokens - 2):
            if tokens[i].lemma_ != 'deficiency' or tokens[i + 1].lemma_ != 'of':
                continue
            labels[i] = 'I'
            labels[i + 1] = 'I'

            # Keeps labelling until the first non-noun token that follows at
            # least one NOUN/PROPN token.
            seen_noun = False
            for j in range(i + 2, n_tokens):
                is_noun = tokens[j].pos_ in ('NOUN', 'PROPN')
                if seen_noun and not is_noun:
                    break
                seen_noun = seen_noun or is_noun
                labels[j] = 'I'

        return labels


lf = Deficiency()
lf.apply(ncbi_docs)


class _CompoundHeadRule(TaggingRule):
    """Votes 'I' on "<compound ...> <head>" phrases for one fixed head lemma.

    A token whose dependency label is ``compound`` and whose successor has the
    lemma ``head_lemma`` starts a match; both tokens, plus any unbroken run of
    ``compound`` tokens immediately to the left, are labelled 'I'. Every other
    token is labelled 'ABS'.

    NOTE(review): this assumes TaggingRule stores votes under the concrete
    class name, so this shared base does not change how Disorder / Lesion /
    Syndrome are registered — confirm against wiser.
    """

    # Lemma the head token must have; set by each concrete subclass.
    head_lemma = None

    def apply_instance(self, instance):
        tokens = instance['tokens']
        labels = ['ABS'] * len(tokens)

        for i in range(len(tokens) - 1):
            if tokens[i].dep_ == 'compound' and tokens[i + 1].lemma_ == self.head_lemma:
                labels[i] = 'I'
                labels[i + 1] = 'I'

                # Adds any other compound tokens before the phrase
                for j in range(i - 1, -1, -1):
                    if tokens[j].dep_ == 'compound':
                        labels[j] = 'I'
                    else:
                        break

        return labels


class Disorder(_CompoundHeadRule):
    """Matches "<compound ...> disorder"."""

    head_lemma = 'disorder'


lf = Disorder()
lf.apply(ncbi_docs)


class Lesion(_CompoundHeadRule):
    """Matches "<compound ...> lesion"."""

    head_lemma = 'lesion'


lf = Lesion()
lf.apply(ncbi_docs)


class Syndrome(_CompoundHeadRule):
    """Matches "<compound ...> syndrome"."""

    head_lemma = 'syndrome'


lf = Syndrome()
lf.apply(ncbi_docs)


# One term per line, split on single spaces. The matcher writes a temporary
# 'TEMP' vote that the BodyTerms rule reads back.
with open('../data/umls/umls_body_part.txt', 'r') as f:
    terms = [line.strip().split(" ") for line in f]
lf = DictionaryMatcher("TEMP", terms, i_label='TEMP', uncased=True, match_lemmas=True)
lf.apply(ncbi_docs)


class BodyTerms(TaggingRule):
    """Votes 'I' on a 'TEMP'-tagged token followed by a generic disease word."""

    def apply_instance(self, instance):
        tokens = [token.text.lower() for token in instance['tokens']]
        labels = ['ABS'] * len(tokens)

        terms = {"cancer", "cancers", "damage", "disease", "diseases", "pain", "injury", "injuries"}

        # Pair each token position with the lower-cased token that follows it.
        for i, following in enumerate(tokens[1:]):
            if instance['WISER_LABELS']['TEMP'][i] == 'TEMP' and following in terms:
                labels[i] = "I"
                labels[i + 1] = "I"
        return labels


lf = BodyTerms()
lf.apply(ncbi_docs)

# The 'TEMP' votes were only an input to BodyTerms; drop them from every doc.
for doc in ncbi_docs:
    del doc['WISER_LABELS']['TEMP']


class OtherPOS(TaggingRule):
    """Votes 'O' for tokens whose part-of-speech tag is in ``other_pos``."""

    other_pos = {"ADP", "ADV", "DET", "VERB"}

    def apply_instance(self, instance):
        return [
            "O" if token.pos_ in self.other_pos else 'ABS'
            for token in instance['tokens']
        ]


lf = OtherPOS()
lf.apply(ncbi_docs)


# Lemmas that the StopWords rule votes 'O' on.
stop_words = {
    "a", "as", "be", "but", "do", "even",
    "for", "from",
    "had", "has", "have", "i", "in", "is", "its", "just",
    "my", "no", "not", "on", "or",
    "that", "the", "these", "this", "those", "to", "very",
    "what", "which", "who", "with",
}


class StopWords(TaggingRule):
    """Votes 'O' for tokens whose lemma is in the module-level ``stop_words``."""

    def apply_instance(self, instance):
        return [
            'O' if token.lemma_ in stop_words else 'ABS'
            for token in instance['tokens']
        ]


lf = StopWords()
lf.apply(ncbi_docs)


class Punctuation(TaggingRule):
    """Votes 'O' for tokens whose text is one of the symbols in ``other_punc``."""

    other_punc = {".", ",", "?", "!", ";", ":", "(", ")",
                  "%", "<", ">", "=", "+", "/", "\\"}

    def apply_instance(self, instance):
        return [
            'O' if token.text in self.other_punc else 'ABS'
            for token in instance['tokens']
        ]


lf = Punctuation()
lf.apply(ncbi_docs)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




## Linking Rules

In [6]:
class PossessivePhrase(LinkingRule):
    """Links a token to its predecessor when either of the two is "'s"."""

    def apply_instance(self, instance):
        texts = [token.text for token in instance['tokens']]
        # Position 0 has no predecessor and is never linked.
        return [
            1 if i > 0 and "'s" in (texts[i - 1], texts[i]) else 0
            for i in range(len(texts))
        ]


lf = PossessivePhrase()
lf.apply(ncbi_docs)


class HyphenatedPhrase(LinkingRule):
    """Links a token to its predecessor when either of the two is "-"."""

    def apply_instance(self, instance):
        texts = [token.text for token in instance['tokens']]
        # Position 0 has no predecessor and is never linked.
        return [
            1 if i > 0 and "-" in (texts[i - 1], texts[i]) else 0
            for i in range(len(texts))
        ]


lf = HyphenatedPhrase()
lf.apply(ncbi_docs)


# Project-provided linking rule, constructed with the value 0.8.
# NOTE(review): 0.8 is presumably a similarity threshold — confirm against
# wiser's ElmoLinkingRule.
lf = ElmoLinkingRule(.8)
lf.apply(ncbi_docs)


class CommonBigram(LinkingRule):
    """Links the two tokens of a lower-cased bigram seen 6+ times in the instance."""

    def apply_instance(self, instance):
        tokens = [token.text.lower() for token in instance['tokens']]
        links = [0] * len(instance['tokens'])

        # Bigram k pairs tokens k and k + 1, so its link is stored at k + 1.
        bigrams = list(zip(tokens, tokens[1:]))
        counts = Counter(bigrams)
        for i, bigram in enumerate(bigrams, start=1):
            if counts[bigram] >= 6:
                links[i] = 1

        return links


lf = CommonBigram()
lf.apply(ncbi_docs)


class ExtractedPhrase(LinkingRule):
    """Links the tokens of every dictionary term found in an instance.

    Matching is on lower-cased token text. At each position the longest
    matching term wins, and scanning resumes after the matched term.
    """

    def __init__(self, terms):
        # Maps a term's first (lower-cased) token to all terms starting with it.
        self.term_dict = {}

        for term in terms:
            lowered = [token.lower() for token in term]
            self.term_dict.setdefault(lowered[0], []).append(lowered)

        # Longest candidates first, so the longest match is tried first.
        for candidates in self.term_dict.values():
            candidates.sort(key=len, reverse=True)

    def apply_instance(self, instance):
        tokens = [token.text.lower() for token in instance['tokens']]
        links = [0] * len(instance['tokens'])

        i = 0
        while i < len(tokens):
            for c in self.term_dict.get(tokens[i], ()):
                # A slice that runs off the end is shorter than c and so
                # cannot compare equal; no separate bounds check is needed.
                if tokens[i:i + len(c)] == c:
                    # Link every token of the term except its first.
                    for j in range(i + 1, i + len(c)):
                        links[j] = 1
                    i = i + len(c) - 1
                    break
            i += 1

        return links

lf = ExtractedPhrase(dict_full)
lf.apply(ncbi_docs)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=592.0), HTML(value='')))




## Define Functions

In [7]:
def txt_to_token_span(tokens: List[str],
                      text: str,
                      txt_spans: List[tuple]):
    """
    Transfer text-domain (character offset) spans to token-domain spans
    :param tokens: tokens of `text`
    :param text: the text the character offsets refer to
    :param txt_spans: text spans tuples: (start, end, ...); only the first two
    elements are read
    :return: a list of (start_token, end_token) tuples, one per input span,
    with `end_token` exclusive.
    :raises AssertionError: if a span's start or end offset falls inside no
    located token
    """
    # One character span per token. Per the pytokenizations docs an entry may
    # be None when a token could not be located in the text.
    token_indices = get_original_spans(tokens, text)
    tgt_spans = list()
    for txt_span in txt_spans:
        char_start = txt_span[0]
        char_end = txt_span[1]
        start = None
        end = None
        for i, token_span in enumerate(token_indices):
            # Tokens without a character span cannot anchor either end.
            if token_span is None:
                continue
            s, e = token_span
            if s <= char_start < e:
                start = i
            # `<= e` lets an end offset sit exactly on a token's end boundary.
            if s <= char_end <= e:
                end = i + 1
            if (start is not None) and (end is not None):
                break
        assert (start is not None) and (end is not None), ValueError("input spans out of scope")
        tgt_spans.append((start, end))
    return tgt_spans

In [8]:
def respan(src_tokens: List[str],
           tgt_tokens: List[str],
           src_span: List[tuple]):
    """
    Transfer spans over the source tokenization to the target tokenization
    :param src_tokens: source tokens
    :param tgt_tokens: target tokens
    :param src_span: an iterable of span tuples (a dict keyed by spans also
    works, since only the keys are iterated). The first element in the tuple
    should be the start index and the second the (exclusive) end index
    :return: a list of transferred (start, end) tuples, `end` exclusive and
    always greater than `start`.
    :raises IndexError: if a span cannot be anchored to any aligned target token
    """
    # s2t[k] lists the target token indices aligned with source token k; the
    # list is empty when source token k has no counterpart.
    s2t, _ = get_alignments(src_tokens, tgt_tokens)
    n_src = len(s2t)

    def _first_target(src_idx):
        # First target index aligned with src_idx or the nearest later token.
        for k in range(src_idx, n_src):
            if s2t[k]:
                return s2t[k][0]
        return None

    def _last_target(src_idx):
        # Last target index aligned with src_idx or the nearest earlier token.
        for k in range(src_idx, -1, -1):
            if s2t[k]:
                return s2t[k][-1]
        return None

    tgt_spans = list()
    for spans in src_span:
        # Clamp the exclusive end to the sequence length. A span ending on the
        # last source token now also maps to an exclusive target end; the
        # original dropped the final target token in that case.
        last_src = min(spans[1], n_src) - 1
        start = _first_target(spans[0])
        last = _last_target(last_src)
        if start is None:
            start = last
        if last is None:
            last = start
        if start is None:
            raise IndexError(f"span {tuple(spans[:2])} has no aligned target token")

        end = last + 1
        # Never return an empty or inverted span.
        if end <= start:
            end = start + 1
        tgt_spans.append((start, end))

    return tgt_spans

In [9]:
def label_to_span(labels: List[str],
                  scheme: Optional[str] = 'BIO') -> dict:
    """
    convert labels to spans
    :param labels: a list of labels such as 'O', 'ABS', 'B-X', 'I-X'
    (and 'L-X', 'U-X' for BILOU)
    :param scheme: labeling scheme, in ['BIO', 'BILOU'].
    :return: labeled spans, a dict mapping (start_idx, end_idx) to the span's
    label, with `end_idx` exclusive. The label is the text after the first
    two characters of the opening tag.
    :raises AssertionError: if `scheme` is not one of the supported schemes
    """
    assert scheme in ['BIO', 'BILOU'], ValueError("unknown labeling scheme")

    labeled_spans = dict()
    n_labels = len(labels)
    i = 0
    while i < n_labels:
        label = labels[i]
        prefix = label[:1]

        if label == 'O' or label == 'ABS':
            i += 1
        elif prefix == 'B':
            start = i
            lb = label[2:]
            i += 1
            if scheme == 'BIO':
                # The span runs over all directly following I-* labels.
                while i < n_labels and labels[i][:1] == 'I':
                    i += 1
            else:
                # The span runs up to and including the closing L-* label; an
                # unterminated span extends to the end of the sequence.
                while i < n_labels and labels[i][:1] != 'L':
                    i += 1
                if i < n_labels:
                    i += 1
            labeled_spans[(start, i)] = lb
        elif scheme == 'BILOU' and prefix == 'U':
            labeled_spans[(i, i + 1)] = label[2:]
            i += 1
        else:
            # Stray continuation tags and unknown labels are skipped. Always
            # advancing guarantees the loop terminates.
            i += 1

    return labeled_spans

In [10]:
def build_bert_emb(sents: List[List[str]],
                   tokenizer,
                   model,
                   device: str):
    """
    Build one embedding matrix per tokenized sentence from the model's last
    hidden layer
    :param sents: a list of sentences, each a list of word tokens
    :param tokenizer: tokenizer providing `tokenize` and `encode`
    :param model: model whose first output is the last hidden states
    :param device: device (string or torch.device) the input ids are created
    on; the model is expected to already live there
    :return: a list of CPU tensors, one per sentence, with
    `len(sentence) + 1` rows: row 0 is the first ([CLS]) position and row k
    is the mean of the word-piece vectors aligned with word k - 1.
    """
    bert_embs = list()
    for sent in sents:

        joint_sent = ' '.join(sent)
        bert_tokens = tokenizer.tokenize(joint_sent)

        input_ids = torch.tensor([tokenizer.encode(joint_sent, add_special_tokens=True)], device=device)
        # calculate BERT last layer embeddings
        with torch.no_grad():
            last_hidden_states = model(input_ids)[0].squeeze(0).to('cpu')
        # drop the first and last (special token) positions
        trunc_hidden_states = last_hidden_states[1:-1, :]

        # ori2bert[k] lists the word-piece indices aligned with word k
        ori2bert, _ = get_alignments(sent, bert_tokens)

        # TODO: using the embedding of [CLS] may not be the best idea
        # It does not matter since that embedding is not used in the training
        emb_list = [last_hidden_states[0, :]]
        for idx in ori2bert:
            if idx:
                emb_list.append(trunc_hidden_states[idx, :].mean(dim=0))
            else:
                # A word with no aligned word-piece would average an empty
                # slice and give NaN; use a zero vector instead.
                emb_list.append(last_hidden_states.new_zeros(last_hidden_states.shape[-1]))

        bert_embs.append(torch.stack(emb_list))
    return bert_embs

## Manipulate text, true labels and weak labels

In [11]:
# Raw corpus lines, newline characters included.
with open(f'../data/NCBI/{file_name}', 'r') as f:
    lines = list(f)

In [12]:
# Group the corpus lines into documents; documents are separated by blank lines.
clusters = list()
clines = list()
for l in lines:
    if l == '\n':
        # Repeated separators must not produce empty documents.
        if clines:
            clusters.append(clines)
        clines = list()
    else:
        clines.append(l)
# Keep the final document when the file does not end with a blank line.
if clines:
    clusters.append(clines)

In [13]:
def _to_bio(raw_labels, tag):
    """Rewrite runs of plain 'I' labels as 'B-<tag>' followed by 'I-<tag>'.

    Labels other than 'I' are copied through unchanged; the input sequence is
    not modified.
    """
    bio_labels = list(raw_labels)
    previous = 'O'
    for idx, raw in enumerate(raw_labels):
        if raw == 'I':
            bio_labels[idx] = ('I-' if previous == 'I' else 'B-') + tag
        previous = raw
    return bio_labels


src_token_list = list()   # NLTK tokens of each document
src_anno_list = list()    # gold spans of each document: {(start, end): LABEL}
weak_anno_list = list()   # per document: {tagging rule: {(start, end): [(LABEL, 1.0)]}}
link_anno_list = list()   # per document: {linking rule: [(start, end), ...]}

allen_data = ncbi_docs
mapping_dict = {0:'O', 1:'I'}

for cluster, allen_annos in zip(clusters, allen_data):

    # handle the data read from the source text
    # The first two lines are '|'-separated with the text in the third field.
    # maxsplit=2 keeps text that itself contains '|' intact.
    src_txt = cluster[0].split('|', 2)[2] + cluster[1].split('|', 2)[2]
    # NLTK rewrites double quotes as `` and ''; restore a plain quote.
    src_tokens = ['"' if tk in ('``', "''") else tk for tk in word_tokenize(src_txt)]
    char_spans = list()
    for anno_line in cluster[2:]:
        # Tab-separated annotation line; fields 1 and 2 are character offsets.
        anno_info = anno_line.strip().split('\t')
        char_spans.append((int(anno_info[1]), int(anno_info[2])))
    src_spans = txt_to_token_span(src_tokens, src_txt, char_spans)

    src_token_list.append(src_tokens)
    src_anno_list.append({span: LABEL for span in src_spans})

    # handle the data constructed using Allennlp
    allen_tokens = list(map(str, allen_annos['tokens']))
    weak_anno = dict()

    for rule_name in allen_annos['WISER_LABELS']:
        std_lbs = _to_bio(allen_annos['WISER_LABELS'][rule_name], LABEL)
        weak_span = label_to_span(std_lbs)

        # Re-express the spans over the NLTK tokenization.
        src_weak_span = respan(allen_tokens, src_tokens, weak_span)
        weak_anno[rule_name] = {span: [(LABEL, 1.0)] for span in src_weak_span}
    weak_anno_list.append(weak_anno)

    linked_dict = dict()
    for rule_name, link_votes in allen_annos['WISER_LINKS'].items():
        entity_lbs = _to_bio([mapping_dict[vote] for vote in link_votes], LINK)

        entity_spans = label_to_span(entity_lbs)
        complete_span = dict()
        for (start, end), lb in entity_spans.items():
            # A link vote at position i ties token i to token i - 1, so the
            # linked phrase begins one token before the first voted position.
            if start != 0:
                start = start - 1
            complete_span[(start, end)] = lb
        linked_dict[rule_name] = respan(allen_tokens, src_tokens, complete_span)
    link_anno_list.append(linked_dict)

In [14]:
# Keep only linked phrases longer than one token that share at least one
# token with some tagging-rule span of the same document.
updated_link_anno_list = list()
for tag_anno, link_anno in zip(weak_anno_list, link_anno_list):
    # Token-index sets of every span proposed by any tagging rule.
    tag_spans = [
        set(range(span[0], span[1]))
        for spans in tag_anno.values()
        for span in spans
    ]

    link_entities = dict()
    for src, spans in link_anno.items():
        valid_anno = dict()
        for span in spans:
            if span[1] - span[0] == 1:
                continue
            span_set = set(range(span[0], span[1]))
            if any(span_set.intersection(tag_span) for tag_span in tag_spans):
                valid_anno[span] = [(LABEL, 1.0)]
        link_entities[src] = valid_anno
    updated_link_anno_list.append(link_entities)

In [15]:
# One dict per document holding tagging-rule annotations and then linking-rule
# annotations; on a key clash the linking-rule entry wins.
combined_anno_list = [
    {**tag_anno, **link_anno}
    for tag_anno, link_anno in zip(weak_anno_list, updated_link_anno_list)
]

## Build BERT Embeddings

In [16]:
# Tokenizer and model are loaded from the same pretrained checkpoint.
PRETRAINED_MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
model = AutoModel.from_pretrained(PRETRAINED_MODEL_NAME)

In [17]:
# Split documents that are too long for the model into three pieces cut at
# sentence boundaries. o2n_map[k] lists the indices in standarized_sents that
# document k was split into (one index if it was not split).
# NOTE(review): 510 is presumably the 512-position limit minus the two
# special tokens — confirm.
standarized_sents = list()
o2n_map = list()
n = 0
for sents in src_token_list:
    joint_sent = ' '.join(sents)
    len_bert_tokens = len(tokenizer.tokenize(joint_sent))
    if len_bert_tokens >= 510:
        sts = sent_tokenize(joint_sent)

        # Cumulative sentence ends in two units: word tokens (used to slice
        # `sents`) and word-pieces (used to choose the cuts). Comparing
        # word-piece ends against the word-piece targets keeps the units
        # consistent.
        word_ends = np.cumsum([len(word_tokenize(st)) for st in sts])
        bert_ends = np.cumsum([len(tokenizer.tokenize(st)) for st in sts])

        # Sentence boundaries closest to one third and two thirds of the text.
        nearest_end_idx1 = np.argmin((bert_ends - len_bert_tokens / 3) ** 2)
        nearest_end_idx2 = np.argmin((bert_ends - len_bert_tokens / 3 * 2) ** 2)
        cut_1 = int(word_ends[nearest_end_idx1])
        cut_2 = int(word_ends[nearest_end_idx2])

        split_1 = sents[:cut_1]
        split_2 = sents[cut_1:cut_2]
        split_3 = sents[cut_2:]
        standarized_sents.append(split_1)
        standarized_sents.append(split_2)
        standarized_sents.append(split_3)
        o2n_map.append([n, n+1, n+2])
        n += 3

    else:
        standarized_sents.append(sents)
        o2n_map.append([n])
        n += 1

In [18]:
# Sanity check: report any piece that is still at or above the 510 limit.
for i, sents in enumerate(standarized_sents):
    n_bert_tokens = len(tokenizer.tokenize(' '.join(sents)))
    if n_bert_tokens >= 510:
        print(i, len(sents), n_bert_tokens)

In [19]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
embs = build_bert_emb(standarized_sents, tokenizer, model, device)

In [20]:
# Stitch the embeddings of split documents back together. Every piece begins
# with its own leading row, so only the first piece keeps it.
combined_embs = list()
for piece_ids in o2n_map:
    doc_emb = embs[piece_ids[0]]
    if len(piece_ids) != 1:
        doc_emb = torch.cat(
            [doc_emb, embs[piece_ids[1]][1:], embs[piece_ids[2]][1:]],
            dim=0,
        )
    combined_embs.append(doc_emb)

# Each document must have one row per token plus the leading row.
for emb, sent in zip(combined_embs, src_token_list):
    assert len(emb) == len(sent) + 1

## Save Data

In [21]:
# Tokens, combined weak annotations and gold spans go into one file; the
# embeddings go into a second file. Both are named after the partition.
linked_path = f"NCBI-linked-{DATA_PARTITION}.pt"
emb_path = f"NCBI-emb-{DATA_PARTITION}.pt"

data = {
    "sentences": src_token_list,
    "annotations": combined_anno_list,
    "labels": src_anno_list,
}

torch.save(data, linked_path)
torch.save(combined_embs, emb_path)