In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
import pandas as pd
from collections import deque
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
import torch.optim as optim
import torch.nn.functional as F
import string
!pip install clean-text
from cleantext import clean

Collecting clean-text
  Downloading clean_text-0.6.0-py3-none-any.whl.metadata (6.6 kB)
Collecting emoji<2.0.0,>=1.0.0 (from clean-text)
  Downloading emoji-1.7.0.tar.gz (175 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.4/175.4 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting ftfy<7.0,>=6.0 (from clean-text)
  Downloading ftfy-6.2.0-py3-none-any.whl.metadata (7.3 kB)
Downloading clean_text-0.6.0-py3-none-any.whl (11 kB)
Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: emoji
  Building wheel for emoji (setup.py) ... [?25l- done
[?25h  Created wheel for emoji: filename=emoji-1.7.0-py3-none-any.whl size=171033 sha256=18ec7afdac0a6aac3fdb8b97e03df2db1f2be7e375a8c152989e2a9693383f0d
  Stored in directory: /root/.cache

In [2]:
# Load the dataset from Hugging Face
dataset = load_dataset("123rc/medical_text")

Downloading readme:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 14.3M/14.3M [00:00<00:00, 43.0MB/s]
Downloading data: 100%|██████████| 3.59M/3.59M [00:00<00:00, 19.5MB/s]


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [3]:
train_data = dataset["train"]
test_data = dataset["test"]

In [4]:
# Convert train and test splits to DataFrames
train_df = pd.DataFrame(train_data)
test_df = pd.DataFrame(test_data)

In [5]:
# Rename two columns 
train_df.rename(columns={'condition_label': 'label', 'medical_abstract': 'text'}, inplace=True)

In [6]:
# Rename two columns 
test_df.rename(columns={'condition_label': 'label', 'medical_abstract': 'text'}, inplace=True)

In [7]:
train_df['label'] = train_df['label'] - 1
test_df['label'] = test_df['label'] - 1

In [8]:
train_df['label'].value_counts()

label
4    3844
0    2530
3    2441
2    1540
1    1195
Name: count, dtype: int64

In [9]:
train_df['label'].nunique()

5

In [10]:
test_df.head()

Unnamed: 0,label,text
0,2,Obstructive sleep apnea following topical orop...
1,4,Neutrophil function and pyogenic infections in...
2,4,A phase II study of combined methotrexate and ...
3,0,Flow cytometric DNA analysis of parathyroid tu...
4,3,Paraneoplastic vasculitic neuropathy: a treata...


In [11]:
train_df.shape

(11550, 2)

In [12]:
test_df.shape

(2888, 2)

In [13]:
import string
# Removing Repeated Punctuations
def remove_repeated_punctuation(text):
    punctuations = set(string.punctuation)
    cleaned_text = []
    for char in text:
        if char in punctuations:
            punctuations.remove(char)
            cleaned_text.append(char)
        elif char not in punctuations:
            punctuations = set(string.punctuation)
            cleaned_text.append(char)
    return ''.join(cleaned_text)

# Apply the remove_repeated_punctuation function to the 'review' column
train_df['text'] = train_df['text'].apply(remove_repeated_punctuation)
test_df['text'] = test_df['text'].apply(remove_repeated_punctuation)

In [14]:
pip install clean-text

Note: you may need to restart the kernel to use updated packages.


In [15]:
from cleantext import clean
# Define the cleaning function
def clean_text(text):
    return clean(text,
        fix_unicode=True,
        to_ascii=True,
        lower=True,
        no_line_breaks=False,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=True,
        no_currency_symbols=True,
        no_punct=True,
        replace_with_punct="",
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        replace_with_number="<NUMBER>",
        replace_with_currency_symbol="<CUR>",
        lang="en"
    )

# Apply the cleaning function to the 'text' column
train_df['text'] = train_df['text'].apply(clean_text)
test_df['text'] = test_df['text'].apply(clean_text)

In [16]:
from sklearn.feature_extraction.text import TfidfVectorizer
# Initialize TF-IDF Vectorizer
tfidf = TfidfVectorizer()

# Fit and transform the text data
tfidf_matrix = tfidf.fit_transform(train_df['text'])

# Get feature names (words)
feature_names = tfidf.get_feature_names_out()

# Create a DataFrame of TF-IDF scores
tfidf_df = pd.DataFrame(tfidf_matrix.toarray(), columns=feature_names)

# Find mean TF-IDF score for each word across all documents
word_scores = tfidf_df.mean(axis=0)

# Set a threshold to identify less important words (adjust threshold as needed)
threshold = 0.00004 # For example, you can set a threshold value

# Get less important words based on threshold
less_important_words = word_scores[word_scores < threshold]

# Display words that will be removed
print("Less important words:")
print(less_important_words)

Less important words:
000g                 0.000008
0c                   0.000008
0wcm2                0.000014
0x                   0.000006
10                   0.000009
                       ...   
zymogens             0.000014
zymograms            0.000012
zymosanactivated     0.000021
zymosanstimulated    0.000006
zzygos               0.000015
Length: 25623, dtype: float64


In [17]:
texts = train_df['text'].tolist()
labels = train_df['label'].tolist()

In [18]:
train_texts, val_texts, train_labels, val_labelss = train_test_split(texts, labels, test_size=0.2, random_state=42)

In [19]:
test_texts = test_df['text'].tolist()
test_labels = test_df['label'].tolist()

In [20]:
class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = [str(text) for text in texts]
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        return {'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'label': torch.tensor(label)}

In [21]:
class BERTClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits

In [22]:
# Set up parameters
bert_model_name = 'bert-base-uncased'
num_classes = 5
max_length = 512
batch_size = 8

In [23]:
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labelss, tokenizer, max_length)
test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, max_length)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTClassifier(bert_model_name, num_classes).to(device)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [25]:
epochs = 20
best_roc_auc = 0.0
min_delta = 0.0001
early_stopping_count = 0
early_stopping_patience = 3
gradient_accumulation_steps = 10

# Set the optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

# Set the scheduler
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=50, 
    num_training_steps=len(train_dataloader) * epochs // gradient_accumulation_steps
)


In [26]:
import time  # Import the time module

In [27]:
# Start training
start_time = time.time()
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad() if step % gradient_accumulation_steps == 0 else None
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        (loss / gradient_accumulation_steps).backward()
        train_loss += loss.item()
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
            optimizer.step()
            scheduler.step()

    model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids, attention_mask)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            val_loss += loss.item()
            val_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
            val_labels.append(labels.cpu().numpy())
            

    val_preds = np.concatenate(val_preds)
    val_labels = np.concatenate(val_labels)
    val_loss /= len(val_dataloader)
    train_loss /= len(train_dataloader)
    print(f'Epoch: {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

    # Calculate metrics
    val_preds_class = np.argmax(val_preds, axis=1)
    accuracy = accuracy_score(val_labels, val_preds_class)
    recall = recall_score(val_labels, val_preds_class, average='weighted')
    precision = precision_score(val_labels, val_preds_class, average='weighted')
    f1 = f1_score(val_labels, val_preds_class, average='weighted')
    micro_f1 = f1_score(val_labels, val_preds_class, average='micro')
    macro_roc_auc = roc_auc_score(val_labels, val_preds, multi_class='ovo', average='macro')

    print(f'Accuracy: {accuracy:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, F1: {f1}, Micro F1: {micro_f1:.4f}, Macro Roc Auc: {macro_roc_auc:.4f}')
         
    # Implement early stopping
    if epoch > 0 and macro_roc_auc - best_roc_auc < min_delta:
        early_stopping_count += 1
        print(f'EarlyStopping counter: {early_stopping_count} out of {early_stopping_patience}')
        if early_stopping_count >= early_stopping_patience:
            print('Early stopping')
            break
    else:
        best_roc_auc = macro_roc_auc
        early_stopping_count = 0
        
training_time = time.time() - start_time
print(f"\\Total training time: {training_time} seconds")

# Save the model to the current working directory
torch.save(model.state_dict(), "ctc_bert_pipeline.pth")     

Epoch: 1/20, Training Loss: 1.5036603306795095, Validation Loss: 1.188955439828259
Accuracy: 0.5532, Recall: 0.5532, Precision: 0.4354, F1: 0.48659115667036484, Micro F1: 0.5532, Macro Roc Auc: 0.8096


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 2/20, Training Loss: 1.0335721549533663, Validation Loss: 0.918292583994387
Accuracy: 0.6268, Recall: 0.6268, Precision: 0.6156, F1: 0.6106219175580945, Micro F1: 0.6268, Macro Roc Auc: 0.8738
Epoch: 3/20, Training Loss: 0.8803366941032987, Validation Loss: 0.893856958843845
Accuracy: 0.6247, Recall: 0.6247, Precision: 0.6314, F1: 0.6024243046810754, Micro F1: 0.6247, Macro Roc Auc: 0.8876
Epoch: 4/20, Training Loss: 0.8077230805700476, Validation Loss: 0.8742973239364096
Accuracy: 0.6377, Recall: 0.6377, Precision: 0.6354, F1: 0.6227666789530658, Micro F1: 0.6377, Macro Roc Auc: 0.8927
Epoch: 5/20, Training Loss: 0.7568199379232539, Validation Loss: 0.866938861802375
Accuracy: 0.6351, Recall: 0.6351, Precision: 0.6383, F1: 0.6195206172955718, Micro F1: 0.6351, Macro Roc Auc: 0.8921
EarlyStopping counter: 1 out of 3
Epoch: 6/20, Training Loss: 0.701800302638636, Validation Loss: 0.8892616790471193
Accuracy: 0.6195, Recall: 0.6195, Precision: 0.6147, F1: 0.6089497367381006, Micro

In [28]:
model.eval()

test_preds = []
test_labels = []

# Iterate over test data
with torch.no_grad():
    for batch in test_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask)
        test_preds.append(F.softmax(outputs, dim=1).cpu().numpy())
        test_labels.append(labels.cpu().numpy())


In [29]:
test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)

test_preds_class = np.argmax(test_preds, axis=1)

report = classification_report(test_labels, test_preds_class, digits = 4)

print(report)

              precision    recall  f1-score   support

           0     0.6904    0.7820    0.7333       633
           1     0.5256    0.6522    0.5821       299
           2     0.5714    0.6130    0.5915       385
           3     0.6578    0.8098    0.7259       610
           4     0.6022    0.3985    0.4796       961

    accuracy                         0.6243      2888
   macro avg     0.6095    0.6511    0.6225      2888
weighted avg     0.6212    0.6243    0.6128      2888

