# Biomedical Relation Extraction from Scientific Literature

This notebook describes a baseline BERT model that can be trained to extract relationships from PubMed articles.

In [1]:
import sys, torch, logging

# fix random seed
torch.manual_seed(0)

# CUDA device if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "mps")

# log level for experiment
logger = logging.getLogger("BioRE")

# code for the baseline model
sys.path.append("./baseline/src")

## Batch processing of sequences and relations

In [2]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from module.data_loader import Dataloader

tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract', use_fast=True)

# 
train_loader = Dataloader('./baseline/data', tokenizer, batch_size=2,
        max_text_length=512, training=True, logger=logger, lowercase=True)

100%|██████████| 76942/76942 [05:35<00:00, 229.06it/s]
100%|██████████| 1521/1521 [00:08<00:00, 178.18it/s]
100%|██████████| 1939/1939 [00:13<00:00, 142.98it/s]
100%|██████████| 523/523 [00:03<00:00, 145.47it/s]
100%|██████████| 523/523 [00:03<00:00, 153.92it/s]


In [3]:
(num, return_data) = next(iter(train_loader))

(input_ids, attention_mask, ep_masks, e1_indicators, e2_indicators, label_arrays) = return_data

One way to approach this might be to pad the number of entity pairs to *max_num_ep*. This would limit the number of relationships that can be retrieved per document, but would allow us to process multiple sequences per batch and *not waste valuable GPU cycles*.

In [4]:
input_ids = input_ids.to(device)
input_ids

tensor([[    2,  3189,  5527,  ...,     0,     0,     3],
        [    2, 21651,  4507,  ...,     0,     0,     3]], device='cuda:0')

In [5]:
attention_mask = attention_mask.to(device)
attention_mask

tensor([[1, 1, 1,  ..., 0, 0, 1],
        [1, 1, 1,  ..., 0, 0, 1]], device='cuda:0')

In [6]:
ep_masks = ep_masks.to(device)
labels = label_arrays.to(device)

In [7]:
# TODO: put this in the data loader
ep_masks = ep_masks.unsqueeze(4)

## Constructing a baseline BERT model

In [8]:
from torchinfo import summary
from module.model import Model

# 
baseline = Model({'data_path': './baseline/data', 'learning_rate': 1e-05, 'mode': 'train', 'encoder_type': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
          'model': 'biaffine', 'output_path': '', 'load_path': '', 'multi_label': True, 'grad_accumulation_steps': 16, 'max_text_length': 512, 
          'dim': 128, 'weight_decay': 0.0001, 'dropout_rate': 0.1, 'max_grad_norm': 10.0, 'epochs': 10, 'patience': 5, 'log_interval': 0.25, 
          'warmup': -1.0, 'cuda': True})

#
pubmedbert = AutoModelForMaskedLM.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")


summary(baseline, input_size=[(2, 512), (2, 512)], dtypes=['torch.IntTensor', 'torch.IntTensor'], device="cpu")

Orthogonal pretrainer loss: 2.23e-12


Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'bert.pooler.dense.bias', 'cls.seq_relationship.weight', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Layer (type:depth-idx)                                  Output Shape              Param #
Model                                                   [2, 1, 512, 512, 15]      245,760
├─BertModel: 1-1                                        [2, 768]                  --
│    └─BertEmbeddings: 2-1                              [2, 512, 768]             --
│    │    └─Embedding: 3-1                              [2, 512, 768]             23,440,896
│    │    └─Embedding: 3-2                              [2, 512, 768]             1,536
│    │    └─Embedding: 3-3                              [1, 512, 768]             393,216
│    │    └─LayerNorm: 3-4                              [2, 512, 768]             1,536
│    │    └─Dropout: 3-5                                [2, 512, 768]             --
│    └─BertEncoder: 2-2                                 [2, 512, 768]             --
│    │    └─ModuleList: 3-6                             --                        85,054,464
│    └─BertPooler: 2-3      

## Overfitting one batch of training data

In [9]:
# Adam with integrated weight decay regularization
optimizer = torch.optim.AdamW(baseline.parameters(), lr=1e-05,
                  weight_decay=0.0001, eps=1e-8)

# y is 1 or 0, x is 1-d logit
bcelogitloss = torch.nn.BCEWithLogitsLoss() 

In [10]:
# 
baseline.encoder.to(device)

#
baseline.head_layer0.to(device)
baseline.head_layer1.to(device)
baseline.tail_layer0.to(device)
baseline.tail_layer1.to(device)

#
baseline.biaffine_mat = torch.nn.Parameter(baseline.biaffine_mat.to(device))

In [11]:
from tqdm import tqdm

epochs = 10
training_loss = []

# training a single batch
for i in tqdm(range(0, epochs), desc="Training"):
    
    # reset gradients
    optimizer.zero_grad()
    
    # pairwise relations (N, max_length, max_length, R+1)
    pairwise_scores = baseline(input_ids, attention_mask)
    
    # broadcast predictions (N, 1, max_num_eps, R+1)
    pairwise_scores = pairwise_scores + ep_masks
    
    # 
    pairwise_scores = torch.logsumexp(pairwise_scores, dim=[2,3])
    
    # multi-label logits (N, num_ep, R)
    scores = pairwise_scores[:, :, :-1]
    
    # sigmoid and binary-cross entropy
    loss = bcelogitloss(scores, labels)
    loss.backward()
    
    # change weights
    optimizer.step()
    
    # 
    training_loss.append(loss.item())

Training: 100%|██████████| 10/10 [00:04<00:00,  2.32it/s]


In [12]:
training_loss

[1.0754338502883911,
 0.8323897123336792,
 0.6189757585525513,
 0.43950217962265015,
 0.3031347393989563,
 0.2040204554796219,
 0.12550131976604462,
 0.07640735805034637,
 0.04648725315928459,
 0.031422968953847885]

In [15]:
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())

In [16]:
# TODO: visualize prediction
y_pred = torch.sigmoid(scores)

## Training one epoch on biochemical relations

preprocess training data ready to be sent to the GPU in time

In [None]:
train = []
for batch_num, return_data in enumerate(train_loader):
    (input_ids, attention_mask, ep_masks, e1_indicators, e2_indicators, label_arrays) = return_data[1]
    train.append(input_ids)

    # define training loop