# Biomedical Relationship Extraction from Scientific Literature

This notebook describes a baseline BERT model that can be trained to extract relationships from PubMed articles.

In [1]:
import sys, torch, logging

# fix random seed
torch.manual_seed(0)

# MacOS Metal Framework (MPS)
device = torch.device("cpu")

# log level for experiment
logger = logging.getLogger("BioRE")

# code for the baseline model
sys.path.append("./baseline/src")

## Batch processing of sequences and relations

In [2]:
from transformers import AutoTokenizer
from module.data_loader import Dataloader

tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract', use_fast=True)

# 
train_loader = Dataloader('./baseline/data', tokenizer, batch_size=2,
        max_text_length=512, training=True, logger=logger, lowercase=True)

100%|██████████| 76942/76942 [03:16<00:00, 391.67it/s] 
100%|██████████| 1521/1521 [00:05<00:00, 267.43it/s]
100%|██████████| 1939/1939 [00:07<00:00, 251.95it/s]
100%|██████████| 523/523 [00:02<00:00, 255.79it/s]
100%|██████████| 523/523 [00:02<00:00, 256.29it/s]


In [3]:
(num, return_data) = next(iter(train_loader))

(input_ids, attention_mask, ep_masks, e1_indicators, e2_indicators, label_arrays) = return_data

One way to approach this might be to pad the number of entity pairs to *max_num_ep*. This would limit the number of relationships that can be retrieved per document, but would allow us to process multiple sequences per batch and *not waste valuable GPU cycles*.

In [4]:
input_ids = input_ids.to(device)
input_ids

tensor([[    2,  3189,  5527,  ...,     0,     0,     3],
        [    2, 21651,  4507,  ...,     0,     0,     3]])

In [5]:
attention_mask = attention_mask.to(device)
attention_mask

tensor([[1, 1, 1,  ..., 0, 0, 1],
        [1, 1, 1,  ..., 0, 0, 1]])

In [6]:
ep_masks = ep_masks.to(device)
labels = label_arrays.to(device)

In [7]:
# TODO: put this in the data loader
ep_masks = ep_masks.unsqueeze(4)

## Constructing a baseline BERT model

In [8]:
from torchinfo import summary
from module.model import Model

model = Model({'data_path': './baseline/data', 'learning_rate': 1e-05, 'mode': 'train', 'encoder_type': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
          'model': 'biaffine', 'output_path': '', 'load_path': '', 'multi_label': True, 'grad_accumulation_steps': 16, 'max_text_length': 512, 
          'dim': 128, 'weight_decay': 0.0001, 'dropout_rate': 0.1, 'max_grad_norm': 10.0, 'epochs': 10, 'patience': 5, 'log_interval': 0.25, 
          'warmup': -1.0, 'cuda': True})

# nn.Module
model.to(device)

summary(model, input_size=[(2, 512), (2, 512)], dtypes=['torch.IntTensor', 'torch.IntTensor'])

Orthogonal pretrainer loss: 3.20e-12


Layer (type:depth-idx)                                  Output Shape              Param #
Model                                                   [2, 1, 512, 512, 15]      245,760
├─BertModel: 1-1                                        [2, 768]                  --
│    └─BertEmbeddings: 2-1                              [2, 512, 768]             --
│    │    └─Embedding: 3-1                              [2, 512, 768]             23,440,896
│    │    └─Embedding: 3-2                              [2, 512, 768]             1,536
│    │    └─Embedding: 3-3                              [1, 512, 768]             393,216
│    │    └─LayerNorm: 3-4                              [2, 512, 768]             1,536
│    │    └─Dropout: 3-5                                [2, 512, 768]             --
│    └─BertEncoder: 2-2                                 [2, 512, 768]             --
│    │    └─ModuleList: 3-6                             --                        85,054,464
│    └─BertPooler: 2-3      

## Overfitting one batch of training data

In [9]:
# Adam with integrated weight decay regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-05,
                  weight_decay=0.0001, eps=1e-8)

# y is 1 or 0, x is 1-d logit
bcelogitloss = torch.nn.BCEWithLogitsLoss() 

In [10]:
# 
model.encoder.to(device)

#
model.head_layer0.to(device)
model.head_layer1.to(device)
model.tail_layer0.to(device)
model.tail_layer1.to(device)

#
model.biaffine_mat = torch.nn.Parameter(model.biaffine_mat.to(device))

In [11]:
from tqdm import tqdm

epochs = 10
training_loss = []

# training a single batch
for i in tqdm(range(0, epochs), desc="Training"):
    
    # reset gradients
    optimizer.zero_grad()
    
    # pairwise relations (N, max_length, max_length, R+1)
    pairwise_scores = model(input_ids, attention_mask)
    
    # broadcast predictions (N, 1, max_num_eps, R+1)
    pairwise_scores = pairwise_scores + ep_masks
    
    # 
    pairwise_scores = torch.logsumexp(pairwise_scores, dim=[2,3])
    
    # multi-label logits (N, num_ep, R)
    scores = pairwise_scores[:, :, :-1]
    
    # sigmoid and binary-cross entropy
    loss = bcelogitloss(scores, labels)
    loss.backward()
    
    # change weights
    optimizer.step()
    
    # 
    training_loss.append(loss.item())

Training: 100%|██████████| 10/10 [00:52<00:00,  5.24s/it]


In [12]:
training_loss

[0.8361715078353882,
 0.6210440993309021,
 0.4343566596508026,
 0.30835989117622375,
 0.18467891216278076,
 0.11713247746229172,
 0.06534256786108017,
 0.045416899025440216,
 0.027115827426314354,
 0.01943390630185604]

In [None]:
np.array(tokenizer.convert_ids_to_tokens(input_ids[0].numpy()))

In [18]:
# TODO: visualize prediction
y_pred = torch.sigmoid(scores)