# Biomedical Relation Extraction from Scientific Literature

Baseline BERT model to extract relationships from PubMed articles.

In [1]:
import sys, torch, logging

# fix random seed
torch.manual_seed(0)

# CUDA device if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "mps")

# log level for experiment
logger = logging.getLogger("BioRE")

# code for the baseline model
sys.path.append("./baseline/src")

In [2]:
import wandb

# experiment tracking
wandb.login()

run = wandb.init(
    # Set the project where this run will be logged
    project="biomed-bert-re",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": 1e-05,
        'weight_decay': 0.0001,
        'dropout_rate': 0.1,
        "architecture": "BRAN",
        "dataset": "ChemDisGene",
        "epochs": 100,
    }
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mramonreszat[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Batch processing of sequences and relations

In [2]:
from module.data_loader import Dataloader
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract', use_fast=True)
chemdisgene = Dataloader('./baseline/data', tokenizer, training=False, logger=logger, lowercase=True)

100%|██████████| 1521/1521 [00:05<00:00, 298.27it/s]
100%|██████████| 1939/1939 [00:07<00:00, 260.66it/s]
100%|██████████| 523/523 [00:02<00:00, 255.22it/s]
100%|██████████| 523/523 [00:02<00:00, 252.65it/s]


In [3]:
chemdisgene.val[0].keys()

dict_keys(['input', 'pad', 'docid', 'input_length', 'label_vectors', 'label_names', 'e1_indicators', 'e2_indicators', 'e1s', 'e2s', 'e1_types', 'e2_types'])

## Constructing a baseline BERT model

In [4]:
from torchinfo import summary
from module.model import Model

config = {'data_path': './baseline/data', 'learning_rate': 1e-05, 'mode': 'train', 'encoder_type': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
          'model': 'biaffine', 'output_path': '', 'load_path': '', 'multi_label': True, 'grad_accumulation_steps': 16, 'max_text_length': 512, 
          'dim': 128, 'weight_decay': 0.0001, 'dropout_rate': 0.1, 'max_grad_norm': 10.0, 'epochs': 10, 'patience': 5, 'log_interval': 0.25, 
          'warmup': -1.0, 'cuda': True}

model = Model(config)

summary(model, input_size=[(2, 512), (2, 512)], dtypes=['torch.IntTensor', 'torch.IntTensor'], device="cpu")

Orthogonal pretrainer loss: 8.07e-10


Layer (type:depth-idx)                                  Output Shape              Param #
Model                                                   [2, 1, 512, 512, 15]      245,760
├─BertModel: 1-1                                        [2, 768]                  --
│    └─BertEmbeddings: 2-1                              [2, 512, 768]             --
│    │    └─Embedding: 3-1                              [2, 512, 768]             23,440,896
│    │    └─Embedding: 3-2                              [2, 512, 768]             1,536
│    │    └─Embedding: 3-3                              [1, 512, 768]             393,216
│    │    └─LayerNorm: 3-4                              [2, 512, 768]             1,536
│    │    └─Dropout: 3-5                                [2, 512, 768]             --
│    └─BertEncoder: 2-2                                 [2, 512, 768]             --
│    │    └─ModuleList: 3-6                             --                        85,054,464
│    └─BertPooler: 2-3      

In [None]:
#
pubmedbert = AutoModelForMaskedLM.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")

## Training one epoch on biochemical relations

Preload training data to send them to GPU

In [5]:
# Adam with integrated weight decay regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-05,
                  weight_decay=0.0001, eps=1e-8)

# y is 1 or 0, x is 1-d logit
criterion = torch.nn.BCEWithLogitsLoss() 

In [9]:
model.encoder.to(device)

model.head_layer0.to(device)
model.head_layer1.to(device)
model.tail_layer0.to(device)
model.tail_layer1.to(device)

model.biaffine_mat = torch.nn.Parameter(model.biaffine_mat.to(device))

In [7]:
import psutil

train_dataset = []
for batch_num, return_data in enumerate(chemdisgene):

    # Get the virtual memory status
    memory_info = psutil.virtual_memory()

    # Convert used memory from bytes to GB
    used_memory_gb = memory_info.used / (1024 ** 3)

    train_dataset.append(return_data[1])

    if used_memory_gb>=24: # Break if more than 24 GB is collected
        break

    if batch_num>=10000: # Break if more than 1000 batches are collected
        break

In [7]:
def model_forward(input_ids, attention_mask, ep_masks):
    pairwise_scores = model(input_ids, attention_mask)
    ep_masks = ep_masks.unsqueeze(4)
    pairwise_scores = pairwise_scores + ep_masks
    pairwise_scores = torch.logsumexp(pairwise_scores, dim=[2,3])
    outputs = pairwise_scores[:, :, :-1]
    return outputs

In [None]:
import numpy as np
from tqdm import tqdm
from sklearn import metrics

for epoch in tqdm(range(wandb.config.epochs), desc="Training"):
    model.train()
    train_loss = 0.0

    # training the model
    for batch_idx, batch in tqdm(enumerate(chemdisgene)):
        (input_ids, attention_mask, ep_masks, e1_indicators, e2_indicators, label_arrays) = batch[0]

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        ep_masks = ep_masks.to(device)
        labels = label_arrays.to(device)

        # reset gradient
        optimizer.zero_grad()

        # predict the relationship types between entity pairs
        scores = model_forward(input_ids, attention_mask, ep_masks)

        # binary cross entropy loss
        loss = criterion(scores, labels)
        train_loss += loss.item()

        loss.backward() 

        optimizer.step()

        # track the loss of each training example for debugging
        wandb.log({"batch": batch_idx, "batch_loss": loss.item()})
    
    train_loss /= len(chemdisgene.train)
    wandb.log({"epoch": epoch, "loss": train_loss})

    with torch.no_grad():
        scores, labels = [], []
        # 
        for sample_idx, data in enumerate(chemdisgene.val):
            # 
            input_ids = torch.tensor(data["input"]).to(device)
            attention_mask = torch.tensor(data["pad"]).to(device)

            e1_indicators_ = np.array(data["e1_indicators"])
            e2_indicators_ = np.array(data["e2_indicators"])

            ep_masks_ = []
            for e1_indicator, e2_indicator in list(zip(list(e1_indicators_), list(e2_indicators_))):
                    ep_mask_ = np.full(
                            (512, 512), -1e20)
                    ep_outer = 1 - np.outer(e1_indicator, e2_indicator)
                    ep_mask_ = ep_mask_ * ep_outer
                    ep_masks_.append(ep_mask_)
            ep_masks_ = np.array(ep_masks_)

            ep_masks = torch.tensor(
                    np.array(ep_masks_), dtype=torch.float32).to(device)
            label_array = torch.tensor(
                    np.array(data["label_vectors"]), dtype=torch.float32)
            
            # 
            pairwise = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0))
            pairwise = pairwise + ep_masks.unsqueeze(0).unsqueeze(4)
            pairwise = torch.logsumexp(pairwise, dim=[2,3])
            score = pairwise[:, :, :-1]

            #
            score = score.detach().cpu().numpy().squeeze(axis=0)
            prediction = (score > np.zeros(14))

            for j in range(len(prediction)):
                predict_names = []
                for k in list(np.where(prediction[j] == 1)[0]):
                        predict_names.append(
                        chemdisgene.relation_name[k])
                label_names = []
                for k in list(np.where(label_array[j] == 1)[0]):
                        label_names.append(chemdisgene.relation_name[k])
                score_dict = {}
                for k, scr in enumerate(list(scores[j])):
                        if k not in chemdisgene.relation_name:
                                score_dict["NA"] = float(scr)
                        else:
                                score_dict[chemdisgene.relation_name[k]] = float(
                                    scr)
            # 
            wandb.log({"epoch": epoch, "sample_idx": sample_idx, "docid": data['docid'],
                       "e1s": data['e1s'], "e2s": data['e2s'], "label_names": label_names,
                       "predictions": predict_names, "scores": score_dict})    
        
        scores = np.concatenate(scores, axis=0)
        labels = np.concatenate(labels, axis=0)

        average_precision = metrics.average_precision_score(
                labels.flatten(), scores.flatten())
        
        predictions = (scores > np.zeros(14))
        predictions_categ = predictions

        results = calculate_metrics(
        predictions, predictions_categ, labels)
        summary_metrics = {
              "average_precision": average_precision,
              "micro_f1":results['micro_f']
        }
        wandb.log({"epoch": 0} | summary_metrics | categ_metrics(results))

In [10]:
import numpy as np
from sklearn import metrics

scores = []
labels = []

for sample_idx, data in enumerate(chemdisgene.val[0:2]):
    # Input data tensors
    input_ids = torch.tensor(data["input"]).to(device)
    attention_mask = torch.tensor(data["pad"]).to(device)

    e1_indicators_ = np.array(data["e1_indicators"])
    e2_indicators_ = np.array(data["e2_indicators"])

    ep_masks_ = []
    for e1_indicator, e2_indicator in list(zip(list(e1_indicators_), list(e2_indicators_))):
        ep_mask_ = np.full((512, 512), -1e20)
        ep_outer = 1 - np.outer(e1_indicator, e2_indicator)
        ep_mask_ = ep_mask_ * ep_outer
        ep_masks_.append(ep_mask_)
    ep_masks_ = np.array(ep_masks_)

    ep_masks = torch.tensor(np.array(ep_masks_), dtype=torch.float32).to(device)
    label_array = torch.tensor(np.array(data["label_vectors"]), dtype=torch.float32)

    # Model prediction
    pairwise = model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0))
    pairwise = pairwise + ep_masks.unsqueeze(0).unsqueeze(4)
    pairwise = torch.logsumexp(pairwise, dim=[2, 3])
    score = pairwise[:, :, :-1]

    # Process and extract predictions
    score = score.detach().cpu().numpy().squeeze(axis=0)
    label = label_array.cpu().numpy()
    prediction = (score > np.zeros(14))

    scores.append(score)
    labels.append(label)

    label_names = []
    predict_names = []
    for j in range(len(prediction)):
        pass

    label_names = []
    predict_name = []
    #for j in range(len(prediction)):
    #    predict_names = []
    #    for k in list(np.where(prediction[j] == 1)[0]):
    #        predict_names.append(chemdisgene.relation_name[k])
    #    label_names = []
    #    for k in list(np.where(label_array[j] == 1)[0]):
    #        label_names.append(chemdisgene.relation_name[k])
    #    score_dict = {}
    #    for k, scr in enumerate(list(score[j])):
    #        if k not in chemdisgene.relation_name:
    #            score_dict["NA"] = float(scr)
    #        else:
    #            score_dict[chemdisgene.relation_name[k]] = float(scr)
    
    # get probabilities from sigmoid 
    probs = 1/(1 + np.exp(-score[0]))

    print({"epoch": 0, "sample_idx": sample_idx, "docid": data['docid'], "e1s": data['e1s'],
            "e2s": data['e2s'], "label_name": label_names, "prediction": predict_names,
            "probabilities": []})
    
scores = np.concatenate(scores, axis=0)
labels = np.concatenate(labels, axis=0)

average_precision = metrics.average_precision_score(
    labels.flatten(), scores.flatten())

predictions = (scores > np.zeros(14))
predictions_categ = predictions

results = calculate_metrics(
    predictions, predictions_categ, labels)
summary_metrics = {
    "average_precision": average_precision,
    "micro_f1":results['micro_f']
}
print({"epoch": 0} | summary_metrics | categ_metrics(results))

{'epoch': 0, 'sample_idx': 0, 'docid': '26583456', 'e1s': ['MESH:D005283'], 'e2s': ['MESH:D012131'], 'label_name': [], 'prediction': [], 'probabilities': []}
{'epoch': 0, 'sample_idx': 1, 'docid': '26766292', 'e1s': ['MESH:D000077185', 'MESH:D000077185', 'MESH:D008315', 'MESH:D008315', 'MESH:D012967', 'MESH:D012967', 'MESH:D006861', 'MESH:D006861', 'MESH:D011794', 'MESH:D011794', 'MESH:D059808', 'MESH:D059808'], 'e2s': ['MESH:D006984', 'MESH:D007249', 'MESH:D006984', 'MESH:D007249', 'MESH:D006984', 'MESH:D007249', 'MESH:D006984', 'MESH:D007249', 'MESH:D006984', 'MESH:D007249', 'MESH:D006984', 'MESH:D007249'], 'label_name': [], 'prediction': [], 'probabilities': []}
{'epoch': 0, 'average_precision': 0.023809523809523808, 'micro_f1': 0.010928961748633882, 'precision_chem_disease:marker/mechanism': 0.07692307692307693, 'precision_chem_disease:therapeutic': 0.0, 'precision_chem_gene:increases^expression': 0.0, 'precision_chem_gene:decreases^expression': 0.0, 'precision_gene_disease:marker/

In [11]:
results

{'micro_p': 0.005494505494505495,
 'micro_r': 1.0,
 'micro_f': 0.010928961748633882,
 'macro_p': 0.005494505494505495,
 'macro_r': 0.07142857142857142,
 'macro_f': 0.010204081632653062,
 'categ_acc': 1.0,
 'categ_macro_p': 0.07142857142857142,
 'categ_macro_r': 0.07142857142857142,
 'categ_macro_f': 0.07142857142857142,
 'na_acc': 0.07692307692307693,
 'not_na_p': 0.07692307692307693,
 'not_na_r': 1.0,
 'not_na_f': 0.14285714285714288,
 'na_p': 0,
 'na_r': 0.0,
 'na_f': 0,
 'per_rel_p': array([0.07692308, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        ]),
 'per_rel_r': array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'per_rel_f': array([0.14285714, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        ]),
 'categ_per_rel_p': array([1.

In [11]:
wandb.finish()



0,1
batch,▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█▁█
batch_loss,▅▆▂▁▂▁▁▇▆▁█▄▂▆▃▁▇▂█▄▂▁▅▃▅▁▂▃▄▂▂▂▇▂▂▁▂▂▁▂
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch,1637.0
batch_loss,0.00664
epoch,99.0
loss,0.00124


In [5]:
def calculate_metrics(predictions, predictions_categ, labels):
    # Calculate metrics given prediction and labels
    # predictions: (N, R), does not include NA in R
    # labels: (N, R), one and zeros, does not include NA in R
    # predictions_categ: (N, R), contains predictions for calculating performance of categorical classifier (exclude NA)
    
    TPs = predictions * labels  # (N, R)
    TP = TPs.sum()
    P = predictions.sum()
    T = labels.sum()

    micro_p = TP / P if P != 0 else 0
    micro_r = TP / T if T != 0 else 0
    micro_f = 2 * micro_p * micro_r / (micro_p + micro_r) if micro_p + micro_r > 0 else 0

    categ_TPs = predictions_categ * labels
    categ_TP = categ_TPs.sum()
    # Excludes instances whose label is NA
    categ_Ps = (predictions_categ * (labels.sum(1) > 0)[:, None])

    categ_acc = categ_TP / T if T != 0 else 0

    not_NA_Ps = (predictions.sum(1) > 0)
    not_NA_Ts = (labels.sum(1) > 0)
    not_NA_TPs = not_NA_Ps * not_NA_Ts
    not_NA_P = not_NA_Ps.sum()
    not_NA_T = not_NA_Ts.sum()
    not_NA_TP = not_NA_TPs.sum()
    not_NA_prec = not_NA_TP / not_NA_P if not_NA_P != 0 else 0
    not_NA_recall = not_NA_TP / not_NA_T if not_NA_T != 0 else 0
    not_NA_f = 2 * not_NA_prec * not_NA_recall / (not_NA_prec + not_NA_recall) if not_NA_prec + not_NA_recall > 0 else 0

    not_NA_acc = (not_NA_Ps == not_NA_Ts).mean()

    NA_Ps = (predictions.sum(1) == 0)
    NA_Ts = (labels.sum(1) == 0)
    NA_TPs = NA_Ps * NA_Ts
    NA_P = NA_Ps.sum()
    NA_T = NA_Ts.sum()
    NA_TP = NA_TPs.sum()
    NA_prec = NA_TP / NA_P if NA_P != 0 else 0
    NA_recall = NA_TP / NA_T if NA_T != 0 else 0
    NA_f = 2 * NA_prec * NA_recall / (NA_prec + NA_recall) if NA_prec + NA_recall > 0 else 0

    per_rel_p = np.zeros(predictions.shape[1])
    per_rel_r = np.zeros(predictions.shape[1])
    per_rel_f = np.zeros(predictions.shape[1])
    categ_per_rel_p = np.zeros(predictions.shape[1])
    categ_per_rel_r = np.zeros(predictions.shape[1])
    categ_per_rel_f = np.zeros(predictions.shape[1])
    # Per-relation metrics:
    for i in range(predictions.shape[1]):
        TP_ = TPs[:, i].sum()
        P_ = predictions[:, i].sum()
        T_ = labels[:, i].sum()
        categ_TP_ = categ_TPs[:, i].sum()
        categ_P_ = categ_Ps[:, i].sum()

        # If no such relation in the test data, recall = 0
        per_rel_r[i] = TP_ / T_ if T_ != 0 else 0
        categ_per_rel_r[i] = categ_TP_ / T_ if T_ != 0 else 0

        # If no such relation in the prediction, precision = 0
        per_rel_p[i] = TP_ / P_ if P_ != 0 else 0

        # If no such relation in the prediction, precision = 0
        categ_per_rel_p[i] = categ_TP_ / categ_P_ if categ_P_ != 0 else 0

        per_rel_f[i] = 2 * per_rel_p[i] * per_rel_r[i] / (per_rel_p[i] + per_rel_r[i]) if per_rel_p[i] + per_rel_r[i] > 0 else 0

        categ_per_rel_f[i] = 2 * categ_per_rel_p[i] * categ_per_rel_r[i] / (categ_per_rel_p[i] + categ_per_rel_r[i]) if categ_per_rel_p[i] + categ_per_rel_r[i] > 0 else 0

    macro_p = per_rel_p.mean()
    macro_r = per_rel_r.mean()
    macro_f = per_rel_f.mean()

    categ_macro_p = categ_per_rel_p.mean()
    categ_macro_r = categ_per_rel_r.mean()
    categ_macro_f = categ_per_rel_f.mean()

    results = {
        "micro_p": micro_p,
        "micro_r": micro_r,
        "micro_f": micro_f,
        "macro_p": macro_p,
        "macro_r": macro_r,
        "macro_f": macro_f,
        "categ_acc": categ_acc,
        "categ_macro_p": categ_macro_p,
        "categ_macro_r": categ_macro_r,
        "categ_macro_f": categ_macro_f,
        "na_acc": not_NA_acc,
        "not_na_p": not_NA_prec,
        "not_na_r": not_NA_recall,
        "not_na_f": not_NA_f,
        "na_p": NA_prec,
        "na_r": NA_recall,
        "na_f": NA_f,
        "per_rel_p": per_rel_p,
        "per_rel_r": per_rel_r,
        "per_rel_f": per_rel_f,
        "categ_per_rel_p": categ_per_rel_p,
        "categ_per_rel_r": categ_per_rel_r,
        "categ_per_rel_f": categ_per_rel_f,
    }

    return results


In [6]:
def categ_metrics(results): 
    return {
    "precision_chem_disease:marker/mechanism": results['per_rel_p'][0],
    "precision_chem_disease:therapeutic": results['per_rel_p'][1],
    "precision_chem_gene:increases^expression": results['per_rel_p'][2],
    "precision_chem_gene:decreases^expression": results['per_rel_p'][3],
    "precision_gene_disease:marker/mechanism": results['per_rel_p'][4],
    "precision_chem_gene:increases^activity": results['per_rel_p'][5],
    "precision_chem_gene:decreases^activity": results['per_rel_p'][6],
    "precision_chem_gene:increases^metabolic_processing": results['per_rel_p'][7],
    "precision_chem_gene:affects^binding": results['per_rel_p'][8],
    "precision_chem_gene:increases^transport": results['per_rel_p'][9],
    "precision_chem_gene:decreases^metabolic_processing": results['per_rel_p'][10],
    "precision_chem_gene:affects^localization": results['per_rel_p'][11],
    "precision_chem_gene:affects^expression": results['per_rel_p'][12],
    "precision_gene_disease:therapeutic": results['per_rel_p'][13],
    "recall_chem_disease:marker/mechanism": results['per_rel_r'][0],
    "recall_chem_disease:therapeutic": results['per_rel_r'][1],
    "recall_chem_gene:increases^expression": results['per_rel_r'][2],
    "recall_chem_gene:decreases^expression": results['per_rel_r'][3],
    "recall_gene_disease:marker/mechanism": results['per_rel_r'][4],
    "recall_chem_gene:increases^activity": results['per_rel_r'][5],
    "recall_chem_gene:decreases^activity": results['per_rel_r'][6],
    "recall_chem_gene:increases^metabolic_processing": results['per_rel_r'][7],
    "recall_chem_gene:affects^binding": results['per_rel_r'][8],
    "recall_chem_gene:increases^transport": results['per_rel_r'][9],
    "recall_chem_gene:decreases^metabolic_processing": results['per_rel_r'][10],
    "recall_chem_gene:affects^localization": results['per_rel_r'][11],
    "recall_chem_gene:affects^expression": results['per_rel_r'][12],
    "recall_gene_disease:therapeutic": results['per_rel_r'][13]}