In [3]:
import json
import spacy
import re

# Name of the spaCy pipeline used for sentence splitting and tokenization below.
spacy_model = 'en_core_sci_md'
# Loaded once here; later cells and helpers reuse this `nlp` object.
nlp = spacy.load(spacy_model)

In [15]:
# Path of the PubTator-format input file that make_doc() opens.
filename =  "data/BioRED/Train.PubTator"

In [79]:
class Document:
    """A document identified by ``id`` that owns a list of text instances."""

    def __init__(self, id):
        # Starts empty; callers assign or extend the passages/sentences later.
        self.text_instances = []
        self.id = id

class PubtatorDocument(Document):
    """Document subclass that also holds relation data parsed from PubTator."""

    def __init__(self, id):
        super().__init__(id)
        # Empty mapping by default; not populated anywhere in the visible code.
        self.variant_gene_pairs = {}
        # Both start unset; the loader assigns relation_pairs a dict keyed by
        # (id1, id2) tuples.
        self.nary_relations = None
        self.relation_pairs = None

class TextInstance:
    """A span of text (passage or sentence) with annotations and parse info."""

    def __init__(self, text):
        # Raw text plus its start offset inside the enclosing text.
        self.text = text
        self.offset = 0
        # Annotation objects whose positions are relative to ``text``.
        self.annotations = []

        # Filled by the tokenization step: space-joined tokens and per-token
        # parallel lists.
        self.tokenized_text = ''
        self.stems = []
        self.head_indexes = []
        self.head = []
        self.pos_tags = []

class AnnotationInfo:
    """One entity mention: its span, surface text, type and identifiers."""

    def __init__(self, position, length, text, ne_type):
        # Span of the mention and its surface form.
        self.position, self.length = position, length
        self.text = text

        # ne_type may be a normalized type; orig_ne_type keeps the type as it
        # appeared before normalization (set by the loader).
        self.ne_type = ne_type
        self.orig_ne_type = ''

        # Identifier bookkeeping, all empty until the loader fills them in.
        self.ids = set()
        self.corresponding_variant_ids = set()
        self.corresponding_gene_id = ''

In [9]:
# Regex used to split the identifier column of an annotation line.
re_id_spliter_str = r'\,'
# Entity types that are rewritten to another type when annotations are loaded.
normalized_type_dict = {'SequenceVariant': 'GeneOrGeneProduct'}
# When True, relation labels carry a second "novelty" field.
has_novelty = True
# (source type, target type) combinations for which candidate pairs are built.
src_tgt_pairs = {
    ('ChemicalEntity', 'ChemicalEntity'),
    ('ChemicalEntity', 'DiseaseOrPhenotypicFeature'),
    ('ChemicalEntity', 'GeneOrGeneProduct'),
    ('DiseaseOrPhenotypicFeature', 'GeneOrGeneProduct'),
    ('GeneOrGeneProduct', 'GeneOrGeneProduct'),
}

In [25]:
def add_annotations_2_text_instances(text_instances, annotations):
    """Attach each annotation to the text instance that fully contains it.

    The text instances are treated as laid end to end with one separator
    character between consecutive instances; each instance's ``offset`` is set
    accordingly. Every annotation whose span lies entirely inside one instance
    is appended to that instance's ``annotations`` and its ``position`` is
    rewritten (in place) to be relative to that instance.

    Args:
        text_instances: objects with ``text``, ``offset`` and ``annotations``.
        annotations: objects with ``position``, ``length`` and ``text``;
            ``position`` is relative to the start of the first instance.

    Raises:
        RuntimeError: if an annotation does not fit inside any single text
            instance. (The previous bare ``raise`` also surfaced as a
            RuntimeError, but with an unrelated "No active exception"
            message; the details are now carried by the exception itself.)
    """
    running_offset = 0
    for text_instance in text_instances:
        text_instance.offset = running_offset
        # +1 accounts for the single separator between consecutive instances.
        running_offset += len(text_instance.text) + 1

    for annotation in annotations:
        span_start = annotation.position
        span_end = annotation.position + annotation.length

        for text_instance in text_instances:
            instance_start = text_instance.offset
            instance_end = instance_start + len(text_instance.text)
            if instance_start <= span_start and span_end <= instance_end:
                annotation.position = span_start - instance_start
                text_instance.annotations.append(annotation)
                break
        else:
            # No instance contains the whole span.
            raise RuntimeError(
                'annotation %r (position=%s, length=%s) cannot be mapped to original text'
                % (annotation.text, annotation.position, annotation.length))

In [175]:
# Extract the info from the PubTator file
# all_documents = load_pubtator_into_documents(in_pubtator_file, 
                                                #  normalized_type_dict = normalized_type_dict,
                                                #  re_id_spliter_str = re_id_spliter_str)

def make_doc(max_lines=20000):
    """Parse the PubTator file at the module-level ``filename`` into documents.

    The file is read line by line:

    * ``pmid|t|...`` / ``pmid|a|...`` lines become ``TextInstance`` objects;
    * tab-separated lines with 6 fields become ``AnnotationInfo`` objects;
    * tab-separated lines with 4 or 5 fields become entries of the document's
      ``relation_pairs`` dict, keyed by ``(id1, id2)``; a 5th field is
      appended to the relation type after a ``|``;
    * a blank line closes the current document.

    Args:
        max_lines: stop after this many input lines have been processed
            (default 20000, the limit that was previously hard-coded).
            Pass ``None`` to read the whole file.

    Returns:
        list of ``PubtatorDocument``. The number of documents is also printed.
    """
    # Optional lookup used to rewrite variant ids; indexed below as
    # [pmid]["start|end"] -> (var_id, gene_id). It is always None here, so the
    # branches guarded by it are currently inactive.
    pmid_2_index_2_groupID_dict = None

    # Strips a trailing parenthesised suffix from the entity-type column.
    # Raw string: '\s' and '\(' are invalid escapes in a normal string literal.
    ne_type_suffix_re = re.compile(r'\s*\(.*?\)\s*$')

    def _build_document(pmid, text_instances, annotations, relation_pairs):
        # Assemble one document from the state accumulated for a PMID.
        document = PubtatorDocument(pmid)
        add_annotations_2_text_instances(text_instances, annotations)
        document.text_instances = text_instances
        document.relation_pairs = relation_pairs
        return document

    documents = []
    pmid = ''
    annotations = []
    text_instances = []
    relation_pairs = {}
    index2normalized_id = {}
    id2index = {}

    with open(filename, 'r', encoding='utf8') as pub_reader:
        for line_no, line in enumerate(pub_reader, 1):
            line = line.rstrip()

            if line == '':
                # Blank line: close the current document. Skip when nothing was
                # accumulated (e.g. consecutive blank lines) so no empty
                # document with a stale pmid is emitted.
                if text_instances or annotations or relation_pairs:
                    documents.append(_build_document(
                        pmid, text_instances, annotations, relation_pairs))

                annotations = []
                text_instances = []
                relation_pairs = {}
                index2normalized_id = {}
                id2index = {}
            else:
                # maxsplit=2 keeps any '|' inside the title/abstract text, so
                # annotation offsets still refer to the full text.
                tks = line.split('|', 2)
                if len(tks) > 1 and (tks[1] == 't' or tks[1] == 'a'):
                    pmid = tks[0]
                    text_instances.append(TextInstance(tks[2]))
                else:
                    _tks = line.split('\t')
                    if len(_tks) == 6:
                        #2234245	250	270	audiovisual toxicity	Disease	D014786|D006311
                        start = int(_tks[1])
                        end = int(_tks[2])
                        index = _tks[1] + '|' + _tks[2]
                        text = _tks[3]
                        orig_ne_type = ne_type_suffix_re.sub('', _tks[4])
                        ne_type = normalized_type_dict.get(orig_ne_type, orig_ne_type)

                        _anno = AnnotationInfo(start, end - start, text, ne_type)

                        ids = [x.strip('*') for x in re.split(re_id_spliter_str, _tks[5])]

                        # If the annotation has a group entry, replace its ids
                        # with the group's variant id and record the gene id.
                        if pmid_2_index_2_groupID_dict is not None and index in pmid_2_index_2_groupID_dict[pmid]:
                            var_id = pmid_2_index_2_groupID_dict[pmid][index][0]
                            gene_id = pmid_2_index_2_groupID_dict[pmid][index][1]
                            if orig_ne_type == 'SequenceVariant':
                                print("pmid_2_index_2_groupID_dict != None")
                                index2normalized_id[index] = var_id
                            for _id in ids:
                                id2index[_id] = index
                            ids = [var_id] * len(ids)
                            _anno.corresponding_gene_id = gene_id

                        _anno.orig_ne_type = orig_ne_type
                        _anno.ids = set(ids)
                        annotations.append(_anno)

                    # ==5 is the line of relation
                    elif len(_tks) == 4 or len(_tks) == 5:
                        id1 = _tks[2]
                        id2 = _tks[3]

                        if pmid_2_index_2_groupID_dict is not None:
                            if (id1 in id2index) and (id2index[id1] in index2normalized_id):
                                id1 = index2normalized_id[id2index[id1]]
                            if (id2 in id2index) and (id2index[id2] in index2normalized_id):
                                id2 = index2normalized_id[id2index[id2]]
                        rel_type = _tks[1]
                        if len(_tks) == 5:
                            rel_type += '|' + _tks[-1]
                        relation_pairs[(id1, id2)] = rel_type

            # Checked for every line, including blank ones, so the limit cannot
            # be skipped when the limit line happens to be a separator.
            if max_lines is not None and line_no >= max_lines:
                break

        # Flush the last document when the input does not end with a blank line.
        if text_instances:
            documents.append(_build_document(
                pmid, text_instances, annotations, relation_pairs))

    print(len(documents))
    return documents

In [127]:
def _spacy_split_sentence(text, nlp):
    offset = 0
    offsets = []
    doc = nlp(text)
    
    do_not_split = False
    start = 0
    end = 0
    for sent in doc.sents:
        # look for a sentence ending with a lower char followed by a dot, an upper char followed by a space,
        # and then a greater char, a word 'del' followed by a dot, and a word 'viz' followed by a dot

        # if finid, do not split the sentence, otherwise, split the sentence
        if re.search(r'\b[a-z]\.$|[A-Z] ?\>$|[^a-z]del\.$| viz\.$', sent.text):
            if not do_not_split:
                start = offset
            end = offset + len(sent.text)
            offset = end
            for c in text[end:]:
                if c == ' ':
                    offset += 1
                else:
                    break
            do_not_split = True
        else:
            if do_not_split:                
                do_not_split = False
                end = offset + len(sent.text)
                offset = end
                for c in text[end:]:
                    if c == ' ':
                        offset += 1
                    else:
                        break
                offsets.append((start, end))
            else:
                start = offset
                end = offset + len(sent.text)
                offsets.append((start, end))
                
                offset = end
                for c in text[end:]:
                    if c == ' ':
                        offset += 1
                    else:
                        break
        
    if do_not_split:
        offsets.append((start, end))
    # there must be a space at the end, no space at the beginning
    return offsets


def split_sentence(document, nlp):
    """Replace the document's text instances with one instance per sentence.

    Each original text instance is cut at the spans returned by
    ``_spacy_split_sentence``. An annotation is moved to the first sentence
    that fully contains it, and its ``position`` is rewritten (in place) to be
    relative to that sentence. An annotation that fits in no sentence is
    reported on stdout and is not attached to any new instance.
    """
    sentence_instances = []
    for passage in document.text_instances:
        sentences = []
        for begin, stop in _spacy_split_sentence(passage.text, nlp):
            sentence = TextInstance(passage.text[begin:stop])
            sentence.offset = begin
            sentences.append(sentence)

        for annotation in passage.annotations:
            for sentence in sentences:
                starts_inside = sentence.offset <= annotation.position
                ends_inside = (annotation.position + annotation.length) - sentence.offset <= len(sentence.text)
                if starts_inside and ends_inside:
                    annotation.position = annotation.position - sentence.offset
                    sentence.annotations.append(annotation)
                    break
            else:
                # No single sentence holds the whole mention: dump the mention
                # and every sentence span for inspection.
                print(annotation.position, annotation.length, annotation.text)
                print(" splited by Spacy' sentence spliter is failed to be loaded into TextInstance\n")
                for sentence in sentences:
                    print(sentence.offset, len(sentence.text), sentence.text)

        sentence_instances.extend(sentences)

    document.text_instances = sentence_instances

In [289]:
# Load the documents, then split each one into sentences and tokenize them.
documents = make_doc()

for document in documents:
    # Sentence-level text instances, each carrying its own entities.
    split_sentence(document, nlp)

    for text_instance in document.text_instances:
        # Collapse whitespace runs to a single space before parsing.
        doc = nlp(re.sub(r'\s+', ' ', text_instance.text))

        tokens = []
        for token in doc:
            tokens.append(token.text)
            # Per-token parallel lists: coarse POS tag, dependency label,
            # index of the syntactic head, and lemma.
            text_instance.pos_tags.append(token.pos_)
            text_instance.head.append(token.dep_)
            text_instance.head_indexes.append(token.head.i)
            text_instance.stems.append(token.lemma_)

        # Tokens joined by single spaces (so punctuation is space-separated).
        text_instance.tokenized_text = ' '.join(tokens)


    # break
    

400


In [290]:
# Alias used by the export cell below; this is the same list object as `documents`.
all_documents = documents
print('=======>len(all_documents)', len(all_documents))



In [224]:
def enumerate_all_id_pairs_by_specified(document,
                                        src_tgt_pairs,
                                        only_pair_in_same_sent):
    """Enumerate candidate (source, target) identifier pairs for a document.

    Unique ``(id, ne_type)`` entries are collected in order of first
    appearance (annotations sorted by position within each text instance).
    Every unordered pair of entries whose types match an entry of
    ``src_tgt_pairs`` — in either orientation — yields one tuple
    ``(src_id, tgt_id, src_type, tgt_type)``.

    Args:
        document: object with ``text_instances``; each instance's
            ``annotations`` list is replaced by a position-sorted copy.
        src_tgt_pairs: iterable of ``(src_ne_type, tgt_ne_type)``.
        only_pair_in_same_sent: when true, pairs are built per text instance;
            otherwise across the whole document.

    Returns:
        set of 4-tuples as described above.
    """
    def _unique_id_infos(instances):
        # (id, type) entries in first-seen order, without duplicates.
        seen = set()
        ordered = []
        for instance in instances:
            instance.annotations = sorted(instance.annotations, key=lambda anno: anno.position)
            for anno in instance.annotations:
                for entity_id in anno.ids:
                    info = (entity_id, anno.ne_type)
                    if info not in seen:
                        seen.add(info)
                        ordered.append(info)
        return ordered

    def _typed_pairs(id_infos):
        # Pair every entry with each later entry and keep the type-compatible ones.
        found = set()
        for i, (first_id, first_type) in enumerate(id_infos):
            for second_id, second_type in id_infos[i + 1:]:
                for src_ne_type, tgt_ne_type in src_tgt_pairs:
                    if first_type == src_ne_type and second_type == tgt_ne_type:
                        found.add((first_id, second_id, first_type, second_type))
                        break
                    elif second_type == src_ne_type and first_type == tgt_ne_type:
                        found.add((second_id, first_id, second_type, first_type))
                        break
        return found

    if only_pair_in_same_sent:
        groups = [[instance] for instance in document.text_instances]
    else:
        groups = [document.text_instances]

    all_pairs = set()
    for group in groups:
        all_pairs |= _typed_pairs(_unique_id_infos(group))
    return all_pairs

In [248]:
def convert_text_instance_2_iob2(text_instance, id1, id2, do_mask_other_nes):
    """Label the tokens of a text instance in IOB2 style for one id pair.

    Tokens come from ``text_instance.tokenized_text`` split on single spaces.
    Annotation spans are aligned to tokens by counting characters with spaces
    removed. Mentions whose ids contain ``id1`` get ``<ne_type>Src`` labels,
    those containing ``id2`` get ``<ne_type>Tgt``; any other mention gets a
    plain ``<ne_type>`` label only when ``do_mask_other_nes`` is true.

    Returns:
        (tokens, labels): two lists of equal length.
    """
    tokens = list(text_instance.tokenized_text.split(' '))
    labels = ['O'] * len(tokens)

    for annotation in text_instance.annotations:
        # Span boundaries measured in non-space characters.
        start = len(text_instance.text[:annotation.position].replace(' ', ''))
        end = start + len(annotation.text.replace(' ', ''))

        # Role suffix for this mention; None means "leave labels untouched".
        if id1 in annotation.ids:
            role = 'Src'
        elif id2 in annotation.ids:
            role = 'Tgt'
        elif do_mask_other_nes:
            role = ''
        else:
            role = None

        offset = 0
        for i, token in enumerate(tokens):
            if offset == start:
                prefix = 'B-'
            elif start < offset and offset < end:
                prefix = 'I-'
            elif offset < start and start < offset + len(token):
                # Mention starts inside the token, ex: renin-@angiotensin$
                prefix = 'B-'
            else:
                prefix = None

            if prefix is not None and role is not None:
                labels[i] = prefix + annotation.ne_type + role

            offset += len(token)

    return tokens, labels

In [272]:
def shift_neighbor_indices_and_add_end_tag(tagged_sent,
                                           ne_positions,
                                           ne_list, 
                                           neighbor_indices,
                                           has_end_tag):
    """Re-insert entity text after each entity tag and adjust neighbor indices.

    Args:
        tagged_sent: space-separated sentence in which each entity is
            represented by a single tag token.
        ne_positions: token indices of the entity tags in ``tagged_sent``.
            Reversed in place by this function.
        ne_list: entity texts, parallel to ``ne_positions``; each text is
            split on single spaces into the tokens to insert.
            Reversed in place by this function.
        neighbor_indices: list (one entry per token of ``tagged_sent``) of
            lists of neighbor token indices, or an empty list to skip all
            neighbor bookkeeping. Modified in place.
        has_end_tag: when true, a closing tag (the opening tag with '@'
            replaced by '@/') is inserted after each entity's text.

    Returns:
        The sentence with entity text (and optional end tags) inserted,
        joined with single spaces.
    """
                    
    new_tagged_sent = tagged_sent.split(' ')
    
    # if parsing sentence fail => len(neighbor_indices) == 0
    if len(neighbor_indices) > 0:
        # update indices by using ne_positions if indices > ne_positions then shift NE's length
        for _neighbor_indices in neighbor_indices:
            
            for i, _indice in enumerate(_neighbor_indices):
                
                # NOTE(review): with has_end_tag the base shift of 1 is applied to
                # every index, including those not after any entity — confirm intended.
                if not has_end_tag:
                    _shift_num = 0
                else:
                    # we consider "end tag" as part of ne text
                    _shift_num = 1
                # Add the token count of every entity whose tag precedes this index.
                for j, shift_point_indice in enumerate(ne_positions):
                    if _indice > shift_point_indice:
                        _shift_num += len(ne_list[j].split(' '))
                _neighbor_indices[i] += _shift_num
            
    # From here on entities are processed in reverse order; presumably so that
    # insertions do not disturb the positions of entities still to be handled.
    ne_positions.reverse()
    ne_list.reverse()
        
    for ne_position, ne_text in zip(ne_positions, ne_list):
        
        if len(neighbor_indices) > 0:
            # Neighbor list of the tag token itself; appended to each inserted entry.
            ne_tag_neighbor_indices = neighbor_indices[ne_position]
        
        # add ne into neighbor and tagged sent
        for i, _ne_token in enumerate(ne_text.split(' ')):
            
            if len(neighbor_indices) > 0:
                # ne text point to ne tag
                neighbor_indices.insert(ne_position + i, [ne_position])
                neighbor_indices[ne_position + i] += ne_tag_neighbor_indices
            
            # insert ne text
            new_tagged_sent.insert(ne_position + 1 + i, _ne_token)
        
        
        if has_end_tag:
            # ne text point to ne tag
            end_tag_index = ne_position + len(ne_text.split(' ')) + 1
            if len(neighbor_indices) > 0:
                neighbor_indices.insert(end_tag_index, [ne_position])
                neighbor_indices[end_tag_index] += ne_tag_neighbor_indices
            # The end tag is derived from the opening tag: '@' becomes '@/'.
            new_tagged_sent.insert(end_tag_index, new_tagged_sent[ne_position].replace('@', '@/'))
                        
    return ' '.join(new_tagged_sent)

In [270]:
def convert_iob2_to_tagged_sent(
        tokens, 
        labels, 
        in_neighbors_list,
        token_offset,
        to_mask_src_and_tgt = False,
        has_end_tag = False):
    """Turn IOB2-labelled tokens into a sentence with ``@Type$`` entity tags.

    Each labelled entity collapses to one tag token ``@<label>$``; ``O``
    tokens are kept. Unless ``to_mask_src_and_tgt`` is true, the entity text
    (and optionally end tags) is re-inserted by
    ``shift_neighbor_indices_and_add_end_tag``.

    Args:
        tokens, labels: parallel lists (IOB2 labels: 'O', 'B-*', 'I-*').
        in_neighbors_list: per original token, a list of neighbor token
            indices; an empty list disables neighbor handling.
        token_offset: value added to every emitted neighbor index.
        to_mask_src_and_tgt: when ``== False`` entity text is re-inserted.
        has_end_tag: forwarded to the re-insertion helper.

    Returns:
        (tagged sentence, space-joined neighbor strings with '|' between the
        neighbors of one token, token_offset + number of neighbor entries).
    """
    old_to_new_index = []   # original token index -> index in the tagged sentence
    new_index = -1
    last_label = 'O'

    pieces = []
    entity_text = ''
    entity_texts = []

    # Pass 1: collapse every entity into a single "@Type" token and close it
    # with "$" when the next non-inside token arrives.
    for token, label in zip(tokens, labels):
        closes_entity = last_label != 'O'
        if label == 'O':
            if closes_entity:
                pieces.append('$ ' + token)
                entity_texts.append(entity_text)
                entity_text = ''
            else:
                pieces.append(' ' + token)
            new_index += 1
        elif label.startswith('B-'):
            if closes_entity:
                pieces.append('$ @' + label[2:])
                entity_texts.append(entity_text)
            else:
                pieces.append(' @' + label[2:])
            entity_text = token
            new_index += 1
        elif label.startswith('I-'):
            entity_text += ' ' + token
        last_label = label
        old_to_new_index.append(new_index)

    if entity_text != '':
        entity_texts.append(entity_text)
    tagged_sent = ''.join(pieces).strip()
    if last_label != 'O':
        # The sentence ended inside an entity: close its tag.
        tagged_sent += '$'

    # Positions of the entity tags in the collapsed sentence.
    ne_positions = [
        position
        for position, tagged_token in enumerate(tagged_sent.split(' '))
        if tagged_token.startswith('@') and tagged_token.endswith('$')
    ]

    # Pass 2: re-express neighbor indices in collapsed-sentence coordinates,
    # merging the neighbor lists of tokens that collapsed into the same tag.
    remapped_neighbors = []
    if len(in_neighbors_list) != 0:
        bucket = []
        previous_new_index = 0
        for old_index in range(len(tokens)):
            current_new_index = old_to_new_index[old_index]
            if current_new_index != previous_new_index:
                remapped_neighbors.append(list(set(bucket)))
                bucket = []
            for neighbor_idx in in_neighbors_list[old_index]:
                bucket.append(old_to_new_index[neighbor_idx])
            previous_new_index = current_new_index
        remapped_neighbors.append(list(set(bucket)))

    # Pass 3: put the entity text back (and shift neighbor indices again).
    if to_mask_src_and_tgt == False:
        tagged_sent = shift_neighbor_indices_and_add_end_tag(
            tagged_sent,
            ne_positions,
            entity_texts,
            remapped_neighbors,
            has_end_tag)

    # Apply the running token offset and serialise each neighbor list.
    neighbor_strings = [
        '|'.join([str(idx + token_offset) for idx in set(neighbors)])
        for neighbors in remapped_neighbors
    ]

    return (tagged_sent.strip(),
            ' '.join(neighbor_strings),
            token_offset + len(neighbor_strings))

In [291]:
# Export every candidate entity pair of every document as one TSV row:
#   pmid, type1, type2, id1, id2, same-sentence flag, min sentence distance,
#   tagged document text, relation label (and novelty when has_novelty).
# NOTE(review): the concatenation below yields "data/BioRED/ttrain.tsv";
# confirm the leading "t" is intended before changing the path.
out_bert_file = out_tsv_file = "data/BioRED/t" + 'train.tsv'
only_pair_in_same_sent = False
neg_label = 'None'
pos_label = ''
do_mask_other_nes = False
to_mask_src_and_tgt = False
has_end_tag = True
# dump_documents_2_bert_format(
#     all_documents = all_documents, 
#     out_bert_file = out_tsv_file, 
#     src_tgt_pairs = src_tgt_pairs,
#     has_novelty = has_novelty)

num_seq_lens = []

_index = 0

with open(out_bert_file, 'w', encoding='utf8') as bert_writer:
            
    number_unique_YES_instances = 0
    for document in all_documents:
        pmid = document.id
        # if only extract the pairs within one sentence
        # (identifier1, identifier2, ne_type1, ne_type2)
        all_pairs = enumerate_all_id_pairs_by_specified(document,
                                                src_tgt_pairs,
                                                only_pair_in_same_sent)
        unique_YES_instances = set()

        # for pairs have two entities
        for relation_pair in all_pairs:
    
            # Default (negative) label; with novelty the label has two fields.
            if not has_novelty:
                relation_label = neg_label
            else:
                relation_label = neg_label + '|None' # rel_type|novelty novelty => ['None', 'No', 'Novel']
            
            if not document.relation_pairs:
                document.relation_pairs = {}
            
            # Annotated relations are looked up in both orientations.
            if (relation_pair[0], relation_pair[1]) in document.relation_pairs:
                relation_label = document.relation_pairs[(relation_pair[0], relation_pair[1])]
                if pos_label != '' and (not has_novelty):
                    relation_label = pos_label
            elif (relation_pair[1], relation_pair[0]) in document.relation_pairs:
                relation_label = document.relation_pairs[(relation_pair[1], relation_pair[0])]
                if pos_label != '' and (not has_novelty):
                    relation_label = pos_label
            id1 = relation_pair[0]
            id2 = relation_pair[1]    
            id1type = relation_pair[2]
            id2type = relation_pair[3]
            
            tagged_sents = []
            all_sents_in_neighbors = []
            #all_sents_out_neighbors = []
            
            is_in_same_sent = False
            
            src_sent_ids = []
            tgt_sent_ids = []
                            
            token_offset = 0

            
            for sent_id, text_instance in enumerate(document.text_instances):
                #  labels: [O, O, O, O, ..., O], length = len(tokens)
                tokens, labels = convert_text_instance_2_iob2(text_instance, id1, id2, do_mask_other_nes)
                
                # One neighbor per token: the index of its syntactic head.
                in_neighbors_list = []
                in_neighbors_head_list = []
                for current_idx, (head, head_idx) in enumerate(zip(
                                                    text_instance.head,
                                                    text_instance.head_indexes)):
                    neighbors = []
                    neighbors_head = []
                    
                    neighbors.append(head_idx)
                    neighbors_head.append(head)
                    
                    in_neighbors_list.append(neighbors)
                    in_neighbors_head_list.append(neighbors_head)
                #out_neighbors_list, _ = get_out_neighbors_list(text_instance)
                
                # Token/parse length mismatch: report it and drop the neighbor info
                # for this sentence instead of failing.
                if len(tokens) != len(in_neighbors_list):
                    print('==================>')
                    print('len(tokens)', len(tokens))
                    print('len(in_neighbors_list)', len(in_neighbors_list))
                    print('tokens', tokens)
                    print(document.id, sent_id)
                    in_neighbors_list = []
                #
                
                # check if Source and Target are in the same sentence
                # (sentence ids stop being collected once a shared sentence is found)
                if not is_in_same_sent:
                    is_Src_in = False
                    is_Tgt_in = False
                    for _label in labels:
                        if 'Src' in _label:
                            is_Src_in = True
                            src_sent_ids.append(sent_id)
                            break
                    for _label in labels:
                        if 'Tgt' in _label:
                            is_Tgt_in = True
                            tgt_sent_ids.append(sent_id)
                            break
                    if is_Src_in and is_Tgt_in:
                        is_in_same_sent = True
                #
                    
                
                tagged_sent, in_neighbors_str, token_offset =\
                    convert_iob2_to_tagged_sent(
                        tokens,
                        labels,
                        in_neighbors_list,
                        #out_neighbors_list,
                        token_offset,
                        to_mask_src_and_tgt,
                        has_end_tag)
                    
                only_co_occurrence_sent = False
                if only_co_occurrence_sent:
                    if is_in_same_sent:
                        tagged_sents.append(tagged_sent)
                        all_sents_in_neighbors.append(in_neighbors_str)
                else:
                    tagged_sents.append(tagged_sent)
                    all_sents_in_neighbors.append(in_neighbors_str)
                #all_sents_out_neighbors.append(out_neighbors_str)
                
            # Smallest sentence distance between any source and target mention
            # (stays at 100 when either list is empty).
            min_sents_window = 100
            for src_sent_id in src_sent_ids:
                for tgt_sent_id in tgt_sent_ids:
                    _min_sents_window = abs(src_sent_id - tgt_sent_id)
                    if _min_sents_window < min_sents_window:
                        min_sents_window = _min_sents_window
                        
            # NOTE(review): this measures only the LAST sentence of the document
            # (tagged_sent from the final loop iteration), once per pair — confirm
            # whether out_sent was meant.
            num_seq_lens.append(float(len(tagged_sent.split(' '))))

            out_sent = ' '.join(tagged_sents)
            
            if id1 == '-1' or id2 == '-1':
                continue
            if ' '.join(tagged_sents) == '':
                continue
            has_ne_type = True
            if has_ne_type:
                instance = document.id + '\t' +\
                            id1type + '\t' +\
                            id2type + '\t' +\
                            id1 + '\t' +\
                            id2 + '\t' +\
                            str(is_in_same_sent) + '\t' +\
                            str(min_sents_window) + '\t' +\
                            out_sent
                            #' '.join(all_sents_in_neighbors)
                        #' '.join(all_sents_in_neighbors) + '\t' +\
                        #' '.join(all_sents_out_neighbors)
            else:
                instance = document.id + '\t' +\
                            id1 + '\t' +\
                            id2 + '\t' +\
                            str(is_in_same_sent) + '\t' +\
                            str(min_sents_window) + '\t' +\
                            out_sent
                            #' '.join(all_sents_in_neighbors)
                
            # Compare only the relation-type field: with novelty enabled the
            # negative label is "None|None", which never equals neg_label, so
            # the plain comparison counted every instance as positive.
            if relation_label.split('|')[0] != neg_label:
                unique_YES_instances.add(instance)
            
            is_test_set = False
            if is_test_set or (id1 != '-' and id2 != '-'):
                if has_novelty:
                    relation_label = relation_label.replace('|', '\t')
                else:
                    relation_label = relation_label.split('|')[0]
                bert_writer.write(instance + '\t' + 
                                    relation_label + '\n')
                            
        number_unique_YES_instances += len(unique_YES_instances)
                
        bert_writer.flush()
        # break

    # average length of tagged sentences (skipped when no pair was produced,
    # which would otherwise divide by zero)
    if num_seq_lens:
        print(sum(num_seq_lens) / len(num_seq_lens))

29.54131731351024


In [294]:
# Debug view: split the most recently built `instance` row on tabs to check its
# columns (the writing loop joins the fields of `instance` with '\t').
# NOTE(review): relies on `instance` still being in the kernel from a previous
# run of that loop — confirm this cell survives Restart & Run All.
instance.split("\t")

['24927617',
 'ChemicalEntity',
 'ChemicalEntity',
 'C486464',
 'D019821',
 'True',
 '0',
 "Rhabdomyolysis in a hepatitis C virus infected patient treated with @ChemicalEntitySrc$ telaprevir @/ChemicalEntitySrc$ and @ChemicalEntityTgt$ simvastatin @/ChemicalEntityTgt$ . A 46-year old man with a chronic hepatitis C virus infection received triple therapy with ribavirin , pegylated interferon and @ChemicalEntitySrc$ telaprevir @/ChemicalEntitySrc$ . The patient also received @ChemicalEntityTgt$ simvastatin @/ChemicalEntityTgt$ . One month after starting the antiviral therapy , the patient was admitted to the hospital because he developed rhabdomyolysis . At admission @ChemicalEntityTgt$ simvastatin @/ChemicalEntityTgt$ and all antiviral drugs were discontinued because toxicity due to a drug-drug interaction was suspected . The creatine kinase peaked at 62,246 IU/L and the patient was treated with intravenous normal saline . The patient 's renal function remained unaffected . Fourteen day

In [301]:
# Debug view of `tgt_sent_ids`.
# NOTE(review): the name is not assigned in the visible cells — presumably the
# sentence indexes of the target entity left over from the generation loop; confirm.
tgt_sent_ids

[0]

In [297]:
# Debug view of `neg_label`: the writing loop adds an instance to
# `unique_YES_instances` only when its `relation_label` differs from this value.
neg_label

'None'

In [296]:
# Debug view of `relation_label` as left by the writing loop. When `has_novelty`
# is truthy the loop rewrites it with '|' replaced by a tab before writing, so
# the value seen here can already be in that tab-separated form.
relation_label

'Drug_Interaction\tNovel'

In [292]:
# Debug view of `out_bert_file` — presumably the path behind `bert_writer`; confirm.
# NOTE(review): the recorded value ends in 'ttrain.tsv', which may be a typo for
# 'train.tsv' — check where `out_bert_file` is assigned.
out_bert_file

'data/BioRED/ttrain.tsv'

In [None]:
# Rebind the notebook-level `has_novelty` flag to False. The instance-writing
# loop branches on a name `has_novelty`: when it is falsy only the part of
# `relation_label` before the first '|' is written, otherwise '|' becomes a tab.
# NOTE(review): an earlier cell sets `has_novelty = True`, so results depend on
# which cell ran last — confirm the intended value.
has_novelty = False

# Disabled call kept for reference.
# NOTE(review): it passes `has_novelty = True`, which disagrees with the
# assignment above; `gen_biored_dataset`, `in_data_dir` and `out_data_dir` are
# not defined in the visible part of the notebook.
# gen_biored_dataset(
#     in_data_dir  = in_data_dir,
#     out_data_dir = out_data_dir,
#     spacy_model  = spacy_model,
#     re_id_spliter_str = re_id_spliter_str,
#     normalized_type_dict = normalized_type_dict,
#     has_novelty          = True)



In [276]:
# Debug view of `tagged_sent`, the last single sentence handled by the loop; its
# space-separated token count is what gets appended to `num_seq_lens`.
tagged_sent

'Hepatocyte nuclear factor-6 : associations between genetic variability and type II diabetes and between genetic variability and estimates of insulin secretion .'

In [285]:
# Debug view of `tagged_sents`, the list the loop joins with spaces into `out_sent`.
# NOTE(review): in the recorded output one sentence ends in "... gene . W" and the
# next starts with "e then examined", i.e. a sentence boundary landed one
# character late — check the sentence-offset computation upstream.
tagged_sents

['Hepatocyte nuclear factor-6 : associations between genetic variability and type II diabetes and between genetic variability and estimates of insulin secretion .',
 'The transcription factor hepatocyte nuclear factor (HNF)-6 is an upstream regulator of several genes involved in the pathogenesis of maturity-onset diabetes of the young .',
 'We therefore tested the hypothesis that variability in the HNF-6 gene is associated with subsets of Type II ( non-insulin-dependent ) diabetes mellitus and estimates of insulin secretion in @ChemicalEntitySrc$ glucose @/ChemicalEntitySrc$ tolerant subjects .',
 'We cloned the coding region as well as the intron-exon boundaries of the HNF-6 gene . W',
 'e then examined them on genomic DNA in six MODY probands without mutations in the MODY1 , MODY3 and MODY4 genes and in 54 patients with late-onset Type II diabetes by combined single strand conformational polymorphism-heteroduplex analysis followed by direct sequencing of identified variants .',
 'An 

In [283]:
# Debug view of `token_offset`.
# NOTE(review): not assigned in the visible cells; presumably a running offset
# from the entity-tagging code — confirm.
token_offset

22

In [135]:
# Spot-check the sentence splitter on the document's second text instance.
# The recorded result is a list of (int, int) pairs — presumably (start, end)
# character offsets per sentence; confirm against `_spacy_split_sentence`.
# NOTE(review): the gap between one pair's end and the next pair's start varies
# (0, 1 or 3 in the recorded output) — verify the spans line up with the text.
_spacy_split_sentence(document.text_instances[1].text, nlp)

[(0, 170),
 (171, 385),
 (388, 474),
 (474, 767),
 (767, 865),
 (868, 932),
 (932, 1180),
 (1180, 1467),
 (1470, 1640)]

In [6]:
"""
Description of train data:
pmid
id1 type
id2 type
identifier (in many formats, depending on the source database):
    Type | (Normalized component or identifier) | Database
    Gene: (19), NCBI Gene
    Variant: (p|SUB|S|276|T), dbSNP
    Variant: (RS#:2234671), dbSNP
    Species: (3175), NCBI Taxonomy
    Disease: (D003409), MEDIC (a combination of MESH and OMIM)
    Chemical: (D013726), MESH (Chemicals and Drugs Category)
    CellLine: (CVCL_1452), Cellosaurus
id1
id2
is_in_same_sent
min_sents_window
sentence
relation_label
novelty
"""

'\nDescription of train data:\n?\npmid,\nentity1 type\nentity2 type\n4 x \nidentifier(in many formats related to many database):\n    Type | (Normalized component or identifier) | Database\n    Gene: (19), NCBI Gene\n    Variant: (p|SUB|S|276|T), dbSNP\n    Variant: (RS#:2234671), dbSNP\n    Species:(3175), NCBI Taxonomy\n    Disease: (D003409), MEDIC (a combination of MESH and OMIM)\n    CHemical: (D013726), MESH (Chemicals and Drugs Category)\n    CellLine: (CVCL_1452), Cellosaurus\nhas_novelty: if the entity showed in abstract is novel\nneg_label\n?\ntext\nPositive_correlation\n'