In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
import pandas as pd
from collections import deque
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
import torch.optim as optim
import torch.nn.functional as F
import string
!pip install clean-text
from cleantext import clean

Collecting clean-text
  Downloading clean_text-0.6.0-py3-none-any.whl.metadata (6.6 kB)
Collecting emoji<2.0.0,>=1.0.0 (from clean-text)
  Downloading emoji-1.7.0.tar.gz (175 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.4/175.4 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting ftfy<7.0,>=6.0 (from clean-text)
  Downloading ftfy-6.2.0-py3-none-any.whl.metadata (7.3 kB)
Downloading clean_text-0.6.0-py3-none-any.whl (11 kB)
Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: emoji
  Building wheel for emoji (setup.py) ... [?25l- done
[?25h  Created wheel for emoji: filename=emoji-1.7.0-py3-none-any.whl size=171033 sha256=891ba842770d96d851e3f23f267115e6ecbc0943cc17d969d873aad5e0a8c4f5
  Stored in directory: /root/.cache

In [2]:
# Load the dataset from Hugging Face
dataset = load_dataset("123rc/medical_text")

Downloading readme:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 14.3M/14.3M [00:00<00:00, 61.2MB/s]
Downloading data: 100%|██████████| 3.59M/3.59M [00:00<00:00, 5.88MB/s]


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [3]:
train_data = dataset["train"]
test_data = dataset["test"]

In [4]:
# Convert train and test splits to DataFrames
train_df = pd.DataFrame(train_data)
test_df = pd.DataFrame(test_data)

In [5]:
# Rename two columns 
train_df.rename(columns={'condition_label': 'label', 'medical_abstract': 'text'}, inplace=True)

In [6]:
# Rename two columns 
test_df.rename(columns={'condition_label': 'label', 'medical_abstract': 'text'}, inplace=True)

In [7]:
train_df['label'] = train_df['label'] - 1
test_df['label'] = test_df['label'] - 1

In [8]:
train_df['label'].value_counts()

label
4    3844
0    2530
3    2441
2    1540
1    1195
Name: count, dtype: int64

In [9]:
train_df['label'].nunique()

5

In [10]:
test_df.head()

Unnamed: 0,label,text
0,2,Obstructive sleep apnea following topical orop...
1,4,Neutrophil function and pyogenic infections in...
2,4,A phase II study of combined methotrexate and ...
3,0,Flow cytometric DNA analysis of parathyroid tu...
4,3,Paraneoplastic vasculitic neuropathy: a treata...


In [11]:
train_df.shape

(11550, 2)

In [12]:
test_df.shape

(2888, 2)

In [13]:
# Removing Repeated Punctuations
def remove_repeated_punctuation(text):
    punctuations = set(string.punctuation)
    cleaned_text = []
    for char in text:
        if char in punctuations:
            punctuations.remove(char)
            cleaned_text.append(char)
        elif char not in punctuations:
            punctuations = set(string.punctuation)
            cleaned_text.append(char)
    return ''.join(cleaned_text)

# Apply the remove_repeated_punctuation function to the 'review' column
train_df['text'] = train_df['text'].apply(remove_repeated_punctuation)

train_df.head()

Unnamed: 0,label,text
0,4,Tissue changes around loose prostheses. A cani...
1,0,Neuropeptide Y and neuron-specific enolase lev...
2,1,"Sexually transmitted diseases of the colon, re..."
3,0,Lipolytic factors associated with murine and h...
4,2,Does carotid restenosis predict an increased r...


In [14]:
test_df['text'] = test_df['text'].apply(remove_repeated_punctuation)

test_df.head()

Unnamed: 0,label,text
0,2,Obstructive sleep apnea following topical orop...
1,4,Neutrophil function and pyogenic infections in...
2,4,A phase II study of combined methotrexate and ...
3,0,Flow cytometric DNA analysis of parathyroid tu...
4,3,Paraneoplastic vasculitic neuropathy: a treata...


In [15]:
import nltk
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
import pandas as pd

# Download NLTK resources (run only once)
nltk.download('punkt')

# Initialize the PorterStemmer
porter = PorterStemmer()
# Function to stem text
def stem_text(text):
    words = word_tokenize(text)
    stemmed_words = [porter.stem(word) for word in words]
    stemmed_text = ' '.join(stemmed_words)
    return stemmed_text

# Apply stemming to the 'text' column
train_df['text'] = train_df['text'].apply(stem_text)
test_df['text'] = test_df['text'].apply(stem_text)

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [16]:
# NLTK library to remove Stopwords.
from nltk.corpus import stopwords

In [17]:
stopword = stopwords.words('english')

In [18]:
# Store the length of each review before removing less important words
train_df['length_before'] = train_df['text'].apply(len)

In [19]:
# Function
def remove_stopwords(text):
    new_text = []
    
    for word in text.split():
        if word in stopword:
            new_text.append('')
        else:
            new_text.append(word)
    x = new_text[:]
    new_text.clear()
    return " ".join(x)

# Calling Function 
train_df['text'] = train_df['text'].apply(remove_stopwords)

train_df.head()

Unnamed: 0,label,text,length_before
0,4,tissu chang around loos prosthes . A canin mod...,907
1,0,neuropeptid Y neuron-specif enolas level ben...,1118
2,1,"sexual transmit diseas colon , rectum , anu...",1595
3,0,lipolyt factor associ murin human cancer cac...,908
4,2,doe carotid restenosi predict increas risk l...,1371


In [20]:
# Store the length of each review before removing less important words
train_df['length_after'] = train_df['text'].apply(len)
train_df.head()

Unnamed: 0,label,text,length_before,length_after
0,4,tissu chang around loos prosthes . A canin mod...,907,737
1,0,neuropeptid Y neuron-specif enolas level ben...,1118,912
2,1,"sexual transmit diseas colon , rectum , anu...",1595,1339
3,0,lipolyt factor associ murin human cancer cac...,908,782
4,2,doe carotid restenosi predict increas risk l...,1371,1177


In [21]:
# Calling Function 
test_df['text'] = test_df['text'].apply(remove_stopwords)

test_df.head()

Unnamed: 0,label,text
0,2,obstruct sleep apnea follow topic oropharyng a...
1,4,neutrophil function pyogen infect bone marro...
2,4,A phase II studi combin methotrex teniposid ...
3,0,flow cytometr dna analysi parathyroid tumor ....
4,3,paraneoplast vasculit neuropathi : treatabl n...


In [22]:
texts = train_df['text'].tolist()
labels = train_df['label'].tolist()

In [23]:
train_texts, val_texts, train_labels, val_labelss = train_test_split(texts, labels, test_size=0.2, random_state=42)

In [24]:
test_texts = test_df['text'].tolist()
test_labels = test_df['label'].tolist()

In [25]:
class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = [str(text) for text in texts]
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        return {'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'label': torch.tensor(label)}

In [26]:
class BERTClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits

In [27]:
# Set up parameters
bert_model_name = 'bert-base-uncased'
num_classes = 5
max_length = 512
batch_size = 8

In [28]:
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labelss, tokenizer, max_length)
test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, max_length)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTClassifier(bert_model_name, num_classes).to(device)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [30]:
epochs = 20
best_roc_auc = 0.0
min_delta = 0.0001
early_stopping_count = 0
early_stopping_patience = 3
gradient_accumulation_steps = 10

# Set the optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

# Set the scheduler
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=50, 
    num_training_steps=len(train_dataloader) * epochs // gradient_accumulation_steps
)


In [31]:
# Training
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad() if step % gradient_accumulation_steps == 0 else None
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        (loss / gradient_accumulation_steps).backward()
        train_loss += loss.item()
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
            optimizer.step()
            scheduler.step()
            
    model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids, attention_mask)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            val_loss += loss.item()
            val_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
            val_labels.append(labels.cpu().numpy())
            
    val_preds = np.concatenate(val_preds)
    val_labels = np.concatenate(val_labels)
    val_loss /= len(val_dataloader)
    train_loss /= len(train_dataloader)
    print(f'Epoch: {epoch+1}/{epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Calculate metrics
    val_preds_class = np.argmax(val_preds, axis=1)
    accuracy = accuracy_score(val_labels, val_preds_class)
    recall = recall_score(val_labels, val_preds_class, average='weighted')
    precision = precision_score(val_labels, val_preds_class, average='weighted')
    f1 = f1_score(val_labels, val_preds_class, average='weighted')
    micro_f1 = f1_score(val_labels, val_preds_class, average='micro')
    macro_roc_auc = roc_auc_score(val_labels, val_preds, multi_class='ovo', average='macro')

    print(f'Accuracy: {accuracy:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, F1: {f1}, Micro F1: {micro_f1:.4f}, Macro Roc Auc: {macro_roc_auc:.4f}')
            
   # Implement early stopping
    if epoch > 0 and macro_roc_auc - best_roc_auc < min_delta:
        early_stopping_count += 1
        print(f'EarlyStopping counter: {early_stopping_count} out of {early_stopping_patience}')
        if early_stopping_count >= early_stopping_patience:
            print('Early stopping')
            break
    else:
        best_roc_auc = macro_roc_auc
        early_stopping_count = 0


Epoch: 1/20, Training Loss: 1.4456, Validation Loss: 1.1377
Accuracy: 0.5576, Recall: 0.5576, Precision: 0.5098, F1: 0.5046102433757993, Micro F1: 0.5576, Macro Roc Auc: 0.8083


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 2/20, Training Loss: 1.0339, Validation Loss: 0.9380
Accuracy: 0.6104, Recall: 0.6104, Precision: 0.5990, F1: 0.5970489826339762, Micro F1: 0.6104, Macro Roc Auc: 0.8681
Epoch: 3/20, Training Loss: 0.9013, Validation Loss: 0.9094
Accuracy: 0.6212, Recall: 0.6212, Precision: 0.6152, F1: 0.6072267832008091, Micro F1: 0.6212, Macro Roc Auc: 0.8804
Epoch: 4/20, Training Loss: 0.8320, Validation Loss: 0.9017
Accuracy: 0.6221, Recall: 0.6221, Precision: 0.6238, F1: 0.6085338498884247, Micro F1: 0.6221, Macro Roc Auc: 0.8854
Epoch: 5/20, Training Loss: 0.7808, Validation Loss: 0.8852
Accuracy: 0.6242, Recall: 0.6242, Precision: 0.6267, F1: 0.6091505393010244, Micro F1: 0.6242, Macro Roc Auc: 0.8890
Epoch: 6/20, Training Loss: 0.7357, Validation Loss: 0.9155
Accuracy: 0.6234, Recall: 0.6234, Precision: 0.6343, F1: 0.6061019588876609, Micro F1: 0.6234, Macro Roc Auc: 0.8826
EarlyStopping counter: 1 out of 3
Epoch: 7/20, Training Loss: 0.6985, Validation Loss: 0.9153
Accuracy: 0.6130, Rec

In [32]:
model.eval()

test_preds = []
test_labels = []

# Iterate over test data
with torch.no_grad():
    for batch in test_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask)
        test_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
        test_labels.append(labels.cpu().numpy())


In [33]:
test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)

test_preds_class = np.argmax(test_preds, axis=1)

report = classification_report(test_labels, test_preds_class, digits = 4)

print(report)

              precision    recall  f1-score   support

           0     0.6844    0.7536    0.7173       633
           1     0.5180    0.5786    0.5466       299
           2     0.5738    0.6364    0.6034       385
           3     0.6574    0.7328    0.6930       610
           4     0.5453    0.4256    0.4781       961

    accuracy                         0.6063      2888
   macro avg     0.5958    0.6254    0.6077      2888
weighted avg     0.6004    0.6063    0.5997      2888

