In [2]:
import os 
import pandas as pd
import numpy as np

from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support,classification_report

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizerFast, BertForSequenceClassification, AdamW

2024-03-16 20:06:54.269459: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-16 20:06:54.269505: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-16 20:06:54.270332: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-16 20:06:54.276555: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
DATA_PATH = '/data/02_training_data/'
MODEL_OUT_PATH = '/data/03_models/relation_classifier/'

## Load Data

In [4]:
try:
    train_df = pd.read_parquet(os.path.join(DATA_PATH,'temp_train_df.pq'))
    test_df = pd.read_parquet(os.path.join(DATA_PATH,'temp_test_df.pq'))
except FileNotFoundError as e:
    print(e)

### Filter to relevant columns

In [5]:
train_df = train_df[['type_of_regulation', 'text_prep', 'srna_name_mentioned', 'gene_name_mentioned']].drop_duplicates()

In [6]:
test_df = test_df[['type_of_regulation', 'text_prep', 'srna_name_mentioned', 'gene_name_mentioned']].drop_duplicates()

### Encode Labels

In [7]:
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
train_df['labels']=le.fit_transform(train_df['type_of_regulation'])
test_df['labels'] = le.transform(test_df['type_of_regulation'])

### Load Tokenizer

In [8]:
tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.2")

### Prepare Datasets for Training

In [9]:
def preprocess_texts(dataframe, tokenizer, max_len=128):
    input_ids = []
    attention_masks = []
    labels = []

    for _, row in dataframe.iterrows():
        # Mark entities in the text
        marked_text = f"{row['text_prep']} [SEP] {row['srna_name_mentioned']} [SEP] {row['gene_name_mentioned']}"
        encoding = tokenizer.encode_plus(
            marked_text,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids.append(encoding['input_ids'])
        attention_masks.append(encoding['attention_mask'])
        labels.append(row['labels'])

    # Convert lists to tensors
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    labels = torch.tensor(labels)

    return input_ids, attention_masks, labels

In [10]:
# Preprocess the dataset
input_ids, attention_masks, labels = preprocess_texts(train_df, tokenizer)

In [11]:
validation_input_ids, validation_attention_masks, validation_labels = preprocess_texts(test_df, tokenizer)

validation_data = TensorDataset(validation_input_ids, validation_attention_masks, validation_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=32)

### Load Model

In [12]:
num_labels = len(set(train_df["type_of_regulation"]))

model = BertForSequenceClassification.from_pretrained("dmis-lab/biobert-base-cased-v1.2", num_labels=num_labels) # Set num_labels to your number of relation types

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Train Model

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs=20
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)

# Convert your dataset to a DataLoader
# Assume `input_ids`, `attention_masks`, and `labels` are your full dataset tensors
dataset = TensorDataset(input_ids, attention_masks, labels)
train_dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=4)

# Training loop
model.train()
for epoch in range(epochs):  # Number of epochs
    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        optimizer.zero_grad()
        
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        print(f"Step {step} of epoch {epoch} completed.")



Step 0 of epoch 0 completed.
Step 1 of epoch 0 completed.
Step 2 of epoch 0 completed.
Step 3 of epoch 0 completed.
Step 4 of epoch 0 completed.
Step 5 of epoch 0 completed.
Step 6 of epoch 0 completed.
Step 7 of epoch 0 completed.
Step 8 of epoch 0 completed.
Step 9 of epoch 0 completed.
Step 10 of epoch 0 completed.
Step 11 of epoch 0 completed.
Step 12 of epoch 0 completed.
Step 13 of epoch 0 completed.
Step 14 of epoch 0 completed.
Step 15 of epoch 0 completed.
Step 16 of epoch 0 completed.
Step 17 of epoch 0 completed.
Step 0 of epoch 1 completed.
Step 1 of epoch 1 completed.
Step 2 of epoch 1 completed.
Step 3 of epoch 1 completed.
Step 4 of epoch 1 completed.
Step 5 of epoch 1 completed.
Step 6 of epoch 1 completed.
Step 7 of epoch 1 completed.
Step 8 of epoch 1 completed.
Step 9 of epoch 1 completed.
Step 10 of epoch 1 completed.
Step 11 of epoch 1 completed.
Step 12 of epoch 1 completed.
Step 13 of epoch 1 completed.
Step 14 of epoch 1 completed.
Step 15 of epoch 1 completed.


### Evaluate Model

In [17]:
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [18]:
predictions, true_labels = [], []

for batch in validation_dataloader:
    # Unpack the batch data and move to the device
    b_input_ids, b_attention_mask, b_labels = [t.to(device) for t in batch]

    with torch.no_grad():
        # Forward pass
        outputs = model(input_ids=b_input_ids, attention_mask=b_attention_mask)

    # Move logits to CPU
    logits = outputs.logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()

    # Store predictions and true labels
    batch_predictions = np.argmax(logits, axis=1)
    predictions.extend(batch_predictions)
    true_labels.extend(label_ids)

# Generate the classification report
report = classification_report(true_labels, predictions, target_names=le.classes_)

print(report)

                               precision    recall  f1-score   support

                 activator of       0.00      0.00      0.00         1
       antisense inhibitor of       1.00      0.88      0.94        17
regulates (molecular biology)       0.73      1.00      0.84         8

                     accuracy                           0.88        26
                    macro avg       0.58      0.63      0.59        26
                 weighted avg       0.88      0.88      0.87        26



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
