# Biomedical Relation Extraction from Scientific Literature

Baseline BERT model to extract relationships from PubMed articles.

In [1]:
import sys, torch, logging

# fix random seed
torch.manual_seed(0)

# CUDA device if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "mps")

# log level for experiment
logger = logging.getLogger("BioRE")

# code for the baseline model
sys.path.append("./baseline/src")

In [2]:
import wandb

# experiment tracking
wandb.login()

run = wandb.init(
    # Set the project where this run will be logged
    project="biomed-bert-re",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": 1e-05,
        'weight_decay': 0.0001,
        'dropout_rate': 0.1,
        "architecture": "BRAN",
        "dataset": "ChemDisGene",
        "epochs": 100,
    }
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mramonreszat[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Batch processing of sequences and relations

In [2]:
from module.data_loader import Dataloader
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract', use_fast=True)
chemdisgene = Dataloader('./baseline/data', tokenizer, training=False, logger=logger, lowercase=True)

100%|██████████| 1521/1521 [00:04<00:00, 305.96it/s]
100%|██████████| 1939/1939 [00:07<00:00, 265.62it/s]
100%|██████████| 523/523 [00:02<00:00, 257.89it/s]
100%|██████████| 523/523 [00:01<00:00, 264.01it/s]


In [3]:
chemdisgene.val[0].keys()

dict_keys(['input', 'pad', 'docid', 'input_length', 'label_vectors', 'label_names', 'e1_indicators', 'e2_indicators', 'e1s', 'e2s', 'e1_types', 'e2_types'])

## Constructing a baseline BERT model

In [4]:
from torchinfo import summary
from module.model import Model

config = {'data_path': './baseline/data', 'learning_rate': 1e-05, 'mode': 'train', 'encoder_type': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
          'model': 'biaffine', 'output_path': '', 'load_path': '', 'multi_label': True, 'grad_accumulation_steps': 16, 'max_text_length': 512, 
          'dim': 128, 'weight_decay': 0.0001, 'dropout_rate': 0.1, 'max_grad_norm': 10.0, 'epochs': 10, 'patience': 5, 'log_interval': 0.25, 
          'warmup': -1.0, 'cuda': True}

model = Model(config)

summary(model, input_size=[(2, 512), (2, 512)], dtypes=['torch.IntTensor', 'torch.IntTensor'], device="cpu")

Orthogonal pretrainer loss: 4.67e-10


Layer (type:depth-idx)                                  Output Shape              Param #
Model                                                   [2, 1, 512, 512, 15]      245,760
├─BertModel: 1-1                                        [2, 768]                  --
│    └─BertEmbeddings: 2-1                              [2, 512, 768]             --
│    │    └─Embedding: 3-1                              [2, 512, 768]             23,440,896
│    │    └─Embedding: 3-2                              [2, 512, 768]             1,536
│    │    └─Embedding: 3-3                              [1, 512, 768]             393,216
│    │    └─LayerNorm: 3-4                              [2, 512, 768]             1,536
│    │    └─Dropout: 3-5                                [2, 512, 768]             --
│    └─BertEncoder: 2-2                                 [2, 512, 768]             --
│    │    └─ModuleList: 3-6                             --                        85,054,464
│    └─BertPooler: 2-3      

In [None]:
#
pubmedbert = AutoModelForMaskedLM.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")

## Training one epoch on biochemical relations

Preload training data to send them to GPU

In [5]:
# Adam with integrated weight decay regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-05,
                  weight_decay=0.0001, eps=1e-8)

# y is 1 or 0, x is 1-d logit
criterion = torch.nn.BCEWithLogitsLoss() 

In [5]:
model.encoder.to(device)

model.head_layer0.to(device)
model.head_layer1.to(device)
model.tail_layer0.to(device)
model.tail_layer1.to(device)

model.biaffine_mat = torch.nn.Parameter(model.biaffine_mat.to(device))

In [7]:
import psutil

train_dataset = []
for batch_num, return_data in enumerate(chemdisgene):

    # Get the virtual memory status
    memory_info = psutil.virtual_memory()

    # Convert used memory from bytes to GB
    used_memory_gb = memory_info.used / (1024 ** 3)

    train_dataset.append(return_data[1])

    if used_memory_gb>=24: # Break if more than 24 GB is collected
        break

    if batch_num>=10000: # Break if more than 1000 batches are collected
        break

In [6]:
def model_forward(input_ids, attention_mask, ep_masks):
    pairwise_scores = model(input_ids, attention_mask)
    ep_masks = ep_masks.unsqueeze(4)
    pairwise_scores = pairwise_scores + ep_masks
    pairwise_scores = torch.logsumexp(pairwise_scores, dim=[2,3])
    outputs = pairwise_scores[:, :, :-1]
    return outputs

In [None]:
import numpy as np
from tqdm import tqdm
from sklearn import metrics

for epoch in tqdm(range(wandb.config.epochs), desc="Training"):
    model.train()
    train_loss = 0.0

    # training the model
    for batch_idx, batch in tqdm(enumerate(chemdisgene)):
        (input_ids, attention_mask, ep_masks, e1_indicators, e2_indicators, label_arrays) = batch[0]

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        ep_masks = ep_masks.to(device)
        labels = label_arrays.to(device)

        # reset gradient
        optimizer.zero_grad()

        # predict the relationship types between entity pairs
        scores = model_forward(input_ids, attention_mask, ep_masks)

        # binary cross entropy loss
        loss = criterion(scores, labels)
        train_loss += loss.item()

        loss.backward() 

        optimizer.step()

        # track the loss of each training example for debugging
        wandb.log({"batch": batch_idx, "batch_loss": loss.item()})
    
    train_loss /= len(chemdisgene.train)
    wandb.log({"epoch": epoch, "loss": train_loss})

    with torch.no_grad():
        scores, labels = [], []
        # 
        for sample_idx, data in enumerate(chemdisgene.val):
            # 
            input_ids = torch.tensor(data["input"]).to(device)
            attention_mask = torch.tensor(data["pad"]).to(device)

            e1_indicators_ = np.array(data["e1_indicators"])
            e2_indicators_ = np.array(data["e2_indicators"])

            ep_masks_ = []
            for e1_indicator, e2_indicator in list(zip(list(e1_indicators_), list(e2_indicators_))):
                    ep_mask_ = np.full(
                            (512, 512), -1e20)
                    ep_outer = 1 - np.outer(e1_indicator, e2_indicator)
                    ep_mask_ = ep_mask_ * ep_outer
                    ep_masks_.append(ep_mask_)
            ep_masks_ = np.array(ep_masks_)

            ep_masks = torch.tensor(
                    np.array(ep_masks_), dtype=torch.float32).to(device)
            label_array = torch.tensor(
                    np.array(data["label_vectors"]), dtype=torch.float32)
            
            # 
            pairwise = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0))
            pairwise = pairwise + ep_masks.unsqueeze(0).unsqueeze(4)
            pairwise = torch.logsumexp(pairwise, dim=[2,3])
            score = pairwise[:, :, :-1]

            #
            score = score.detach().cpu().numpy().squeeze(axis=0)
            prediction = (score > np.zeros(14))

            for j in range(len(prediction)):
                predict_names = []
                for k in list(np.where(prediction[j] == 1)[0]):
                        predict_names.append(
                        chemdisgene.relation_name[k])
                label_names = []
                for k in list(np.where(label_array[j] == 1)[0]):
                        label_names.append(chemdisgene.relation_name[k])
                score_dict = {}
                for k, scr in enumerate(list(scores[j])):
                        if k not in chemdisgene.relation_name:
                                score_dict["NA"] = float(scr)
                        else:
                                score_dict[chemdisgene.relation_name[k]] = float(
                                    scr)
            # 
            wandb.log({"sample_idx": sample_idx, "docid": data['docid'], "e1s": data['e1s'], 
                       "e2s": data['e2s'], "label_names": label_names, "predictions": predict_names,
                       "scores": score_dict})    
        
        scores = np.concatenate(scores, axis=0)
        labels = np.concatenate(labels, axis=0)

        average_precision = metrics.average_precision_score(
                labels.flatten(), scores.flatten())
        wandb.log({"epoch": epoch, "average_precision": average_precision})
        
        

        results = self.calculate_metrics(
                predictions, predictions_categ, labels)

In [54]:
import numpy as np
from sklearn import metrics

scores = []
labels = []

for sample_idx, data in enumerate(chemdisgene.val[0:2]):
    # Input data tensors
    input_ids = torch.tensor(data["input"]).to(device)
    attention_mask = torch.tensor(data["pad"]).to(device)

    e1_indicators_ = np.array(data["e1_indicators"])
    e2_indicators_ = np.array(data["e2_indicators"])

    ep_masks_ = []
    for e1_indicator, e2_indicator in list(zip(list(e1_indicators_), list(e2_indicators_))):
        ep_mask_ = np.full((512, 512), -1e20)
        ep_outer = 1 - np.outer(e1_indicator, e2_indicator)
        ep_mask_ = ep_mask_ * ep_outer
        ep_masks_.append(ep_mask_)
    ep_masks_ = np.array(ep_masks_)

    ep_masks = torch.tensor(np.array(ep_masks_), dtype=torch.float32).to(device)
    label_array = torch.tensor(np.array(data["label_vectors"]), dtype=torch.float32)

    # Model prediction
    pairwise = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0))
    pairwise = pairwise + ep_masks.unsqueeze(0).unsqueeze(4)
    pairwise = torch.logsumexp(pairwise, dim=[2, 3])
    score = pairwise[:, :, :-1]

    # Process and extract predictions
    score = score.detach().cpu().numpy().squeeze(axis=0)
    label = label_array.cpu().numpy()
    prediction = (score > np.zeros(14))

    scores.append(score)
    labels.append(label)

    for j in range(len(prediction)):
        predict_names = []
        for k in list(np.where(prediction[j] == 1)[0]):
            predict_names.append(chemdisgene.relation_name[k])
        label_names = []
        for k in list(np.where(label_array[j] == 1)[0]):
            label_names.append(chemdisgene.relation_name[k])
        score_dict = {}
        for k, scr in enumerate(list(score[j])):
            if k not in chemdisgene.relation_name:
                score_dict["NA"] = float(scr)
            else:
                score_dict[chemdisgene.relation_name[k]] = float(scr)
    
    print({"sample_idx": sample_idx, "docid": data['docid'], "e1s": data['e1s'],
               "e2s": data['e2s'], "label_names": label_names, "predictions": predict_names,
               "scores": score_dict})
    
scores = np.concatenate(scores, axis=0)
labels = np.concatenate(labels, axis=0)

average_precision = metrics.average_precision_score(
    labels.flatten(), scores.flatten())
print({"epoch": 0, "average_precision": average_precision})



#results = self.calculate_metrics(
#    predictions, predictions_categ, labels)

{'sample_idx': 0, 'docid': '26583456', 'e1s': ['MESH:D005283'], 'e2s': ['MESH:D012131'], 'label_names': [], 'predictions': ['chem_disease:marker/mechanism', 'chem_disease:therapeutic', 'chem_gene:increases^expression', 'chem_gene:decreases^expression', 'gene_disease:marker/mechanism', 'chem_gene:increases^activity', 'chem_gene:decreases^activity', 'chem_gene:increases^metabolic_processing', 'chem_gene:affects^binding', 'chem_gene:increases^transport', 'chem_gene:decreases^metabolic_processing', 'chem_gene:affects^localization', 'chem_gene:affects^expression', 'gene_disease:therapeutic'], 'scores': {'chem_disease:marker/mechanism': 0.6742985248565674, 'chem_disease:therapeutic': 0.6742985248565674, 'chem_gene:increases^expression': 0.6742985248565674, 'chem_gene:decreases^expression': 0.6742985248565674, 'gene_disease:marker/mechanism': 0.6742985248565674, 'chem_gene:increases^activity': 0.6742985248565674, 'chem_gene:decreases^activity': 0.6742985248565674, 'chem_gene:increases^metabol

In [11]:
wandb.finish()



0,1
batch,▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█
batch_loss,▅▆▂▁▂▁▁▇▆▁█▄▂▆▃▁▇▂█▄▂▁▅▃▅▁▂▃▄▂▂▂▇▂▂▁▂▂▁▂
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch,1637.0
batch_loss,0.00664
epoch,99.0
loss,0.00124
