In [2]:
import torch
import pickle
from tqdm import tqdm
import json
# from transformers import (OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
#                           XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP)
# from transformers import BertModel, BertTokenizer
from transformers import AutoModel, AutoTokenizer
import argparse
import os
import numpy as np

def check_path(path):
    d = os.path.dirname(path)
    if not os.path.exists(d):
        os.makedirs(d)

In [None]:
# MODEL_CLASS_TO_NAME = {
#     'gpt': list(OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
#     'bert': list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
#     'xlnet': list(XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
#     'roberta': list(ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
#     'lstm': ['lstm'],
# }
model_class = AutoModel
tokenizer_class = AutoTokenizer
model_name = 'michiyasunaga/BioLinkBERT-large'
tokenizer = tokenizer_class.from_pretrained(model_name)
model = model_class.from_pretrained(model_name, output_hidden_states=True)
model.eval()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
outdir = '../data/umls/encoder_inputs/'
outfile = 'BioLinkBERT-inputs'

umls_csv_path = './umls.csv'
umls_vocab_path = './concept_names.txt'
umls_rel_path = './relations.txt'

concept2id = {}
id2concept = {}
with open(umls_vocab_path, "r", encoding="utf8") as fin:
    id2concept = {w.strip().split('\t')[0]: w.strip().split('\t')[1] for w in fin}
concept2id = {w: i for i, w in enumerate(id2concept)}

id2relation = [rel.strip() for rel in open(umls_rel_path)]
relation2id = {r: i for i, r in enumerate(id2relation)}
spaced_relations = [rel.replace('_', ' ') for rel in id2relation]

max_seq_length = 512

In [9]:
# BERT inputs
all_input_ids, all_input_mask, all_segment_ids, all_sub_span, all_rel_span, all_obj_span= [], [], [], [], [], []

loops = 0
repeats = 0
# keep track of input entity ids
nrow = sum(1 for _ in open(umls_csv_path, 'r', encoding='utf-8'))
entity_ids = []  # stores [sub, rel, obj] for each triple

with open(umls_csv_path, "r", encoding="utf8") as fin:
    
    attrs = set()
    
    for line in tqdm(fin, total=nrow):
        ls = line.strip().split('\t')
        rel = relation2id[ls[0]] if not(ls[0] == 'isa') else relation2id['is_a']  
        subj = ls[1]
        obj = ls[2]
        
        if subj == obj:  # delete loops
            loops = loops + 1
            continue

        if (subj, obj, rel) not in attrs:
            attrs.add((subj, obj, rel))
            entity_ids.append([subj, rel, obj])
            # tokenize inputs and format input data for BERT
            sub_tokens = tokenizer.tokenize(id2concept[subj])
            rel_tokens = tokenizer.tokenize(id2relation[rel].replace('_', ' '))
            obj_tokens = tokenizer.tokenize(id2concept[obj])
            triple_tokens = [tokenizer.cls_token] + sub_tokens + rel_tokens + obj_tokens + [tokenizer.sep_token]
            input_ids = tokenizer.convert_tokens_to_ids(triple_tokens)

            assert len(input_ids) <= max_seq_length
            pad_len = max_seq_length - len(input_ids)
            input_mask = [1] * len(input_ids) + [0] * pad_len
            input_ids += [0] * pad_len
            segment_ids = [0] * max_seq_length # all just one sentence (no sentence pair)
            # define span of sub, rel, and obj
            sub_span = [1, len(sub_tokens)]  
            rel_span = [sub_span[-1] + 1, sub_span[-1] + len(rel_tokens)]
            obj_span = [rel_span[-1] + 1, rel_span[-1] + len(obj_tokens)]
            
            all_input_ids.append(input_ids)
            all_input_mask.append(input_mask)
            all_segment_ids.append(segment_ids)
            all_sub_span.append(sub_span)
            all_rel_span.append(rel_span)
            all_obj_span.append(obj_span)
        else:
            repeats = repeats + 1
        
entity_ids = np.array(entity_ids)

cache_path = outdir+outfile+'.pkl'
check_path(cache_path)

with open(cache_path, 'wb') as fout:
    pickle.dump((all_input_ids, all_input_mask, all_segment_ids, all_sub_span, all_rel_span, all_obj_span, entity_ids), fout)
print('Inputs dumped')

# should be number of rows in umls.csv (1212586)
print(loops + repeats + len(all_input_ids))

100%|██████████| 1212586/1212586 [02:14<00:00, 9044.71it/s] 


Inputs dumped


In [8]:
cache_path = outdir+outfile+'.pkl'

with open(cache_path, 'rb') as fin:
    all_input_ids, all_input_mask, all_segment_ids, all_sub_span, all_rel_span, all_obj_span, entity_ids = pickle.load(fin)

all_input_ids, all_input_mask, all_segment_ids, all_sub_span, all_rel_span, all_obj_span = [torch.tensor(x, dtype=torch.long) for x in [all_input_ids, all_input_mask, all_segment_ids, all_sub_span, all_rel_span, all_obj_span]]
model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 1024, padding_idx=0)
    (position_embeddings): Embedding(512, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-23): 24 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inpl

In [None]:
# setup loop variables
n = entity_ids.shape[0]
batch_size = 64

assert n == all_input_ids.shape[0]

# which layer of BERT to use for embeddings
layer = -1
emb_dim = 1024
umls_concept_emb = torch.zeros((len(concept2id), emb_dim)).to(device)
umls_rel_emb = torch.zeros((len(relation2id), emb_dim)).to(device)

with torch.no_grad():
    mask = torch.arange(max_seq_length, device=device)[None, :]

    for a in tqdm(range(0, n, batch_size), total=n // batch_size + 1, desc='Extracting features'):
        b = min(a + batch_size, n)
        *batch, sub_span, rel_span, obj_span = [x.to(device) for x in [all_input_ids[a:b], all_input_mask[a:b], all_segment_ids[a:b], all_sub_span[a:b], all_rel_span[a:b], all_obj_span[a:b]]]
        outputs = model(*batch)
        
        hidden_states = outputs[-1][layer]
        
        sub_mask = (mask >= sub_span[:, 0, None]) & (mask <= sub_span[:, 1, None])
        rel_mask = (mask >= rel_span[:, 0, None]) & (mask <= rel_span[:, 1, None])
        obj_mask = (mask >= obj_span[:, 0, None]) & (mask <= obj_span[:, 1, None])
        # apply mask using the spans, and average the token rep by dividing by span length
        sub_pooled = (hidden_states * sub_mask.float().unsqueeze(-1)).sum(1) / (sub_span[:,1].float() - sub_span[:,0].float() + 1).unsqueeze(1)
        rel_pooled = (hidden_states * rel_mask.float().unsqueeze(-1)).sum(1) / (rel_span[:,1].float() - rel_span[:,0].float() + 1).unsqueeze(1)
        obj_pooled = (hidden_states * obj_mask.float().unsqueeze(-1)).sum(1) / (obj_span[:,1].float() - obj_span[:,0].float() + 1).unsqueeze(1)
        
        sub_ids = entity_ids[a:b, 0]
        rel_ids = entity_ids[a:b, 1]
        obj_ids = entity_ids[a:b, 2]
        for i, (sub_id, rel_id, obj_id) in enumerate(zip(sub_ids, rel_ids, obj_ids)):
            umls_concept_emb[concept2id[sub_id]] += sub_pooled[i]
            umls_concept_emb[concept2id[obj_id]] += obj_pooled[i]
            umls_rel_emb[int(rel_id)] += rel_pooled[i]

output_dir = '../data/umls/encoder_embs/'
check_path(output_dir)
# np.save(output_dir + 'umls_concept_embs_not_counted', umls_concept_emb.to('cpu').numpy())
# np.save(output_dir + 'umls_relation_embs_not_counted', umls_rel_emb.to('cpu').numpy())

prune = False
sub_unique, sub_counts = np.unique([concept2id[id] for id in entity_ids[:,0]], return_counts=True) 
rel_unique, rel_counts_unordered = np.unique(entity_ids[:,1].astype(int), return_counts=True) 
obj_unique, obj_counts = np.unique([concept2id[id] for id in entity_ids[:,2]], return_counts=True) 
rel_counts = np.zeros(len(id2relation))
rel_counts[rel_unique] += rel_counts_unordered
concept_counts = np.zeros(len(concept2id))
concept_counts[sub_unique] += sub_counts
concept_counts[obj_unique] += obj_counts 

umls_concept_emb = umls_concept_emb.to('cpu').numpy()
umls_rel_emb = umls_rel_emb.to('cpu').numpy()
umls_concept_emb = np.divide(umls_concept_emb, concept_counts[:,np.newaxis], out=np.zeros_like(umls_concept_emb), where=concept_counts[:,np.newaxis]!=0)
umls_rel_emb = umls_rel_emb / rel_counts[:,np.newaxis]
np.save(output_dir + 'bert-large_umls_concept_embs', umls_concept_emb)
np.save(output_dir + 'bert-large_umls_relation_embs', umls_rel_emb)

Extracting features: 100%|██████████| 18290/18290 [19:32:13<00:00,  3.85s/it]   
