In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
import pandas as pd
from collections import deque
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
import torch.optim as optim
import torch.nn.functional as F
import string
!pip install clean-text
from cleantext import clean

Collecting clean-text
  Downloading clean_text-0.6.0-py3-none-any.whl.metadata (6.6 kB)
Collecting emoji<2.0.0,>=1.0.0 (from clean-text)
  Downloading emoji-1.7.0.tar.gz (175 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.4/175.4 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting ftfy<7.0,>=6.0 (from clean-text)
  Downloading ftfy-6.2.0-py3-none-any.whl.metadata (7.3 kB)
Downloading clean_text-0.6.0-py3-none-any.whl (11 kB)
Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: emoji
  Building wheel for emoji (setup.py) ... [?25l- done
[?25h  Created wheel for emoji: filename=emoji-1.7.0-py3-none-any.whl size=171033 sha256=e71edff7d2c9469848970a1ebe5193730278ff9907b4121e244e92efb26e0daf
  Stored in directory: /root/.cache

In [2]:
# Load the dataset from Hugging Face
dataset = load_dataset("123rc/medical_text")

Downloading readme:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 14.3M/14.3M [00:00<00:00, 57.7MB/s]
Downloading data: 100%|██████████| 3.59M/3.59M [00:00<00:00, 10.0MB/s]


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [3]:
train_data = dataset["train"]
test_data = dataset["test"]

In [4]:
# Convert train and test splits to DataFrames
train_df = pd.DataFrame(train_data)
test_df = pd.DataFrame(test_data)

In [5]:
# Rename two columns 
train_df.rename(columns={'condition_label': 'label', 'medical_abstract': 'text'}, inplace=True)

In [6]:
# Rename two columns 
test_df.rename(columns={'condition_label': 'label', 'medical_abstract': 'text'}, inplace=True)

In [7]:
train_df['label'] = train_df['label'] - 1
test_df['label'] = test_df['label'] - 1

In [8]:
train_df['label'].value_counts()

label
4    3844
0    2530
3    2441
2    1540
1    1195
Name: count, dtype: int64

In [9]:
train_df['label'].nunique()

5

In [10]:
test_df.head()

Unnamed: 0,label,text
0,2,Obstructive sleep apnea following topical orop...
1,4,Neutrophil function and pyogenic infections in...
2,4,A phase II study of combined methotrexate and ...
3,0,Flow cytometric DNA analysis of parathyroid tu...
4,3,Paraneoplastic vasculitic neuropathy: a treata...


In [11]:
train_df.shape

(11550, 2)

In [12]:
test_df.shape

(2888, 2)

In [13]:
import nltk
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
import pandas as pd

# Download NLTK resources (run only once)
nltk.download('punkt')

# Initialize the PorterStemmer
porter = PorterStemmer()
# Function to stem text
def stem_text(text):
    words = word_tokenize(text)
    stemmed_words = [porter.stem(word) for word in words]
    stemmed_text = ' '.join(stemmed_words)
    return stemmed_text


[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [14]:
# Apply stemming to the 'text' column
train_df['text'] = train_df['text'].apply(stem_text)
test_df['text'] = test_df['text'].apply(stem_text)

In [15]:
# NLTK library to remove Stopwords.
from nltk.corpus import stopwords

In [16]:
stopword = stopwords.words('english')

In [17]:
# Store the length of each review before removing less important words
train_df['length_before'] = train_df['text'].apply(len)

In [18]:
# Function
def remove_stopwords(text):
    new_text = []
    
    for word in text.split():
        if word in stopword:
            new_text.append('')
        else:
            new_text.append(word)
    x = new_text[:]
    new_text.clear()
    return " ".join(x)

# Calling Function 
train_df['text'] = train_df['text'].apply(remove_stopwords)

train_df.head()

Unnamed: 0,label,text,length_before
0,4,tissu chang around loos prosthes . A canin mod...,907
1,0,neuropeptid Y neuron-specif enolas level ben...,1118
2,1,"sexual transmit diseas colon , rectum , anu...",1595
3,0,lipolyt factor associ murin human cancer cac...,908
4,2,doe carotid restenosi predict increas risk l...,1371


In [19]:
# Store the length of each review before removing less important words
train_df['length_after'] = train_df['text'].apply(len)
train_df.head()

Unnamed: 0,label,text,length_before,length_after
0,4,tissu chang around loos prosthes . A canin mod...,907,737
1,0,neuropeptid Y neuron-specif enolas level ben...,1118,912
2,1,"sexual transmit diseas colon , rectum , anu...",1595,1339
3,0,lipolyt factor associ murin human cancer cac...,908,782
4,2,doe carotid restenosi predict increas risk l...,1371,1177


In [20]:
# Calling Function 
test_df['text'] = test_df['text'].apply(remove_stopwords)

test_df.head()

Unnamed: 0,label,text
0,2,obstruct sleep apnea follow topic oropharyng a...
1,4,neutrophil function pyogen infect bone marro...
2,4,A phase II studi combin methotrex teniposid ...
3,0,flow cytometr dna analysi parathyroid tumor ....
4,3,paraneoplast vasculit neuropathi : treatabl n...


In [21]:
# Apply text cleaning function
train_df['text'] = train_df['text'].apply(lambda x: clean(x,
                                                        fix_unicode=True,
                                                        to_ascii=True,
                                                        lower=True,
                                                        no_line_breaks=False,
                                                        no_urls=True,
                                                        no_emails=True,
                                                        no_phone_numbers=True,
                                                        no_numbers=True,
                                                        no_currency_symbols=True,
                                                        no_punct=True,
                                                        replace_with_punct="",
                                                        replace_with_url="<URL>",
                                                        replace_with_email="<EMAIL>",
                                                        replace_with_phone_number="<PHONE>",
                                                        replace_with_number="<NUMBER>",
                                                        replace_with_currency_symbol="<CUR>",
                                                        lang="en"
                                                        ))

train_df.head()

Unnamed: 0,label,text,length_before,length_after
0,4,tissu chang around loos prosthes a canin model...,907,737
1,0,neuropeptid y neuronspecif enolas level benign...,1118,912
2,1,sexual transmit diseas colon rectum anu challe...,1595,1339
3,0,lipolyt factor associ murin human cancer cache...,908,782
4,2,doe carotid restenosi predict increas risk lat...,1371,1177


In [22]:
# Apply text cleaning function
test_df['text'] = test_df['text'].apply(lambda x: clean(x,
                                                        fix_unicode=True,
                                                        to_ascii=True,
                                                        lower=True,
                                                        no_line_breaks=False,
                                                        no_urls=True,
                                                        no_emails=True,
                                                        no_phone_numbers=True,
                                                        no_numbers=True,
                                                        no_currency_symbols=True,
                                                        no_punct=True,
                                                        replace_with_punct="",
                                                        replace_with_url="<URL>",
                                                        replace_with_email="<EMAIL>",
                                                        replace_with_phone_number="<PHONE>",
                                                        replace_with_number="<NUMBER>",
                                                        replace_with_currency_symbol="<CUR>",
                                                        lang="en"
                                                        ))

test_df.head()

Unnamed: 0,label,text
0,2,obstruct sleep apnea follow topic oropharyng a...
1,4,neutrophil function pyogen infect bone marrow ...
2,4,a phase ii studi combin methotrex teniposid in...
3,0,flow cytometr dna analysi parathyroid tumor im...
4,3,paraneoplast vasculit neuropathi treatabl neur...


In [23]:
texts = train_df['text'].tolist()
labels = train_df['label'].tolist()

In [24]:
train_texts, val_texts, train_labels, val_labelss = train_test_split(texts, labels, test_size=0.2, random_state=42)

In [25]:
test_texts = test_df['text'].tolist()
test_labels = test_df['label'].tolist()

In [26]:
class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = [str(text) for text in texts]
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        return {'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'label': torch.tensor(label)}

In [27]:
class BioBERTClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes):
        super(BioBERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits

In [28]:
# Set up parameters
bert_model_name = 'dmis-lab/biobert-base-cased-v1.2'
num_classes = 5
max_length = 512
batch_size = 8

In [29]:
tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labelss, tokenizer, max_length)
test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, max_length)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BioBERTClassifier(bert_model_name, num_classes).to(device)

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


In [31]:
epochs = 20
best_roc_auc = 0.0
min_delta = 0.0001
early_stopping_count = 0
early_stopping_patience = 3
gradient_accumulation_steps = 10

# Set the optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

# Set the scheduler
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=50, 
    num_training_steps=len(train_dataloader) * epochs // gradient_accumulation_steps
)


In [32]:
# Training
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad() if step % gradient_accumulation_steps == 0 else None
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        (loss / gradient_accumulation_steps).backward()
        train_loss += loss.item()
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
            optimizer.step()
            scheduler.step()
            
    model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids, attention_mask)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            val_loss += loss.item()
            val_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
            val_labels.append(labels.cpu().numpy())
            
    val_preds = np.concatenate(val_preds)
    val_labels = np.concatenate(val_labels)
    val_loss /= len(val_dataloader)
    train_loss /= len(train_dataloader)
    print(f'Epoch: {epoch+1}/{epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Calculate metrics
    val_preds_class = np.argmax(val_preds, axis=1)
    accuracy = accuracy_score(val_labels, val_preds_class)
    recall = recall_score(val_labels, val_preds_class, average='weighted')
    precision = precision_score(val_labels, val_preds_class, average='weighted')
    f1 = f1_score(val_labels, val_preds_class, average='weighted')
    micro_f1 = f1_score(val_labels, val_preds_class, average='micro')
    macro_roc_auc = roc_auc_score(val_labels, val_preds, multi_class='ovo', average='macro')

    print(f'Accuracy: {accuracy:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, F1: {f1}, Micro F1: {micro_f1:.4f}, Macro Roc Auc: {macro_roc_auc:.4f}')
            
   # Implement early stopping
    if epoch > 0 and macro_roc_auc - best_roc_auc < min_delta:
        early_stopping_count += 1
        print(f'EarlyStopping counter: {early_stopping_count} out of {early_stopping_patience}')
        if early_stopping_count >= early_stopping_patience:
            print('Early stopping')
            break
    else:
        best_roc_auc = macro_roc_auc
        early_stopping_count = 0


Epoch: 1/20, Training Loss: 1.3479, Validation Loss: 0.9739
Accuracy: 0.6221, Recall: 0.6221, Precision: 0.6160, F1: 0.5943614248216701, Micro F1: 0.6221, Macro Roc Auc: 0.8736
Epoch: 2/20, Training Loss: 0.8939, Validation Loss: 0.8269
Accuracy: 0.6481, Recall: 0.6481, Precision: 0.6445, F1: 0.6344018548760889, Micro F1: 0.6481, Macro Roc Auc: 0.8993
Epoch: 3/20, Training Loss: 0.7976, Validation Loss: 0.8380
Accuracy: 0.6446, Recall: 0.6446, Precision: 0.6628, F1: 0.6194734795027415, Micro F1: 0.6446, Macro Roc Auc: 0.9005
Epoch: 4/20, Training Loss: 0.7439, Validation Loss: 0.8170
Accuracy: 0.6463, Recall: 0.6463, Precision: 0.6408, F1: 0.6400140315104381, Micro F1: 0.6463, Macro Roc Auc: 0.9028
Epoch: 5/20, Training Loss: 0.6975, Validation Loss: 0.8253
Accuracy: 0.6385, Recall: 0.6385, Precision: 0.6449, F1: 0.623279685404031, Micro F1: 0.6385, Macro Roc Auc: 0.9028
EarlyStopping counter: 1 out of 3
Epoch: 6/20, Training Loss: 0.6611, Validation Loss: 0.8570
Accuracy: 0.6199, Reca

In [33]:
model.eval()

test_preds = []
test_labels = []

# Iterate over test data
with torch.no_grad():
    for batch in test_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask)
        test_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
        test_labels.append(labels.cpu().numpy())


In [34]:
test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)

test_preds_class = np.argmax(test_preds, axis=1)

report = classification_report(test_labels, test_preds_class, digits = 4)

print(report)

              precision    recall  f1-score   support

           0     0.6796    0.8009    0.7353       633
           1     0.5201    0.6488    0.5774       299
           2     0.5698    0.6364    0.6012       385
           3     0.6690    0.7852    0.7225       610
           4     0.5987    0.3881    0.4710       961

    accuracy                         0.6226      2888
   macro avg     0.6074    0.6519    0.6215      2888
weighted avg     0.6193    0.6226    0.6104      2888

