In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split

def prepare_medical_data(csv_file, test_size=0.2, random_state=42, stratify=False):
    """Load the medical QA CSV, encode ``med_field`` as integer labels, and split it.

    Rows whose ``med_field`` is missing are dropped, the remaining rows are
    shuffled, and a new integer ``label`` column is added from the categorical
    codes of ``med_field`` (categories in pandas' inferred order).

    Args:
        csv_file: Path or buffer accepted by ``pd.read_csv``; must contain a
            ``med_field`` column.
        test_size: Fraction (or absolute count) of rows for the test split,
            forwarded to ``train_test_split``.
        random_state: Seed used for both the shuffle and the split.
        stratify: If True, split so every ``label`` keeps its proportion in both
            splits (each class then needs at least two rows). Defaults to False,
            the original unstratified split.

    Returns:
        ``(train_df, test_df, label2id, id2label, num_labels)`` with both frames
        re-indexed from 0.

        NOTE: the two mappings are named the wrong way round and are kept that
        way because later cells index them accordingly:
        ``label2id`` maps integer label -> field name (``{0: 'Alergist', ...}``),
        ``id2label`` maps field name -> integer label.
    """
    df = pd.read_csv(csv_file)
    # cat.codes encodes missing values as -1, which is not a valid class index
    # for the classifier's loss; such rows cannot be trained on, so drop them.
    df = df.dropna(subset=["med_field"])
    df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # DATASET LABELING: a single categorical conversion drives codes and mappings.
    med_field = df["med_field"].astype("category")
    df["label"] = med_field.cat.codes
    label2id = dict(enumerate(med_field.cat.categories))  # int -> name (see NOTE above)
    id2label = {name: code for code, name in label2id.items()}  # name -> int
    num_labels = len(id2label)

    train_df, test_df = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state,
        stratify=df["label"] if stratify else None,
    )

    train_df = train_df.reset_index(drop=True)
    test_df = test_df.reset_index(drop=True)

    return train_df, test_df, label2id, id2label, num_labels

# Source CSV: every medical-field question-answer pair compiled into one file
# (25,737 rows in the recorded run: 20,589 train + 5,148 test).
DATA_CSV = "/kaggle/input/dataset-new/preprocessed_medical_qa (2).csv"
train_df, test_df, label2id, id2label, num_labels = prepare_medical_data(DATA_CSV)

In [2]:
# Despite its name, `label2id` maps integer class id -> field name
# (it is built with dict(enumerate(categories)) in prepare_medical_data).
label2id

{0: 'Alergist', 1: 'Cardiologist', 2: 'Dermatologist', 3: 'Endocrinologist', 4: 'Gastroenterologist', 5: 'Genetics', 6: 'Geriatrics', 7: 'Neurologist', 8: 'Oncologist/Hematologist', 9: 'Orthopedics', 10: 'OtherQA'}


In [3]:
# Despite its name, `id2label` maps field name -> integer class id.
id2label

{'Alergist': 0, 'Cardiologist': 1, 'Dermatologist': 2, 'Endocrinologist': 3, 'Gastroenterologist': 4, 'Genetics': 5, 'Geriatrics': 6, 'Neurologist': 7, 'Oncologist/Hematologist': 8, 'Orthopedics': 9, 'OtherQA': 10}


In [4]:
from datasets import Dataset


# Wrap the pandas splits as Hugging Face Datasets (used with .map() and the Trainer below).
train_dataset, test_dataset = (Dataset.from_pandas(split_df) for split_df in (train_df, test_df))

In [5]:
from transformers import DistilBertTokenizerFast 

# Checkpoint shared by the tokenizer here and the classifier created later.
model_name = "distilbert-base-uncased"
# Every example is truncated/padded to exactly this many tokens.
max_length = 256
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

def tokenize_function(examples):
    """Tokenize a batch's ``"text"`` column, truncating/padding each entry to ``max_length`` tokens."""
    texts = examples["text"]
    return tokenizer(texts, max_length=max_length, truncation=True, padding="max_length")

# Tokenize both splits batch-wise, then format only the model inputs and the
# target column as torch tensors.
train_dataset, test_dataset = (
    split.map(tokenize_function, batched=True) for split in (train_dataset, test_dataset)
)
for split in (train_dataset, test_dataset):
    split.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Map:   0%|          | 0/20589 [00:00<?, ? examples/s]

Map:   0%|          | 0/5148 [00:00<?, ? examples/s]

In [6]:
from transformers import DistilBertForSequenceClassification

# Hugging Face expects config.id2label = {int: name} and config.label2id = {name: int}.
# The dicts returned by prepare_medical_data are named the other way round
# (`label2id` prints as {0: 'Alergist', ...}), and the previous call passed both
# mappings inverted into the model config. Build the config mappings explicitly
# from the int -> name dict.
config_id2label = {int(code): name for code, name in label2id.items()}
config_label2id = {name: code for code, name in config_id2label.items()}

model = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=config_id2label,
    label2id=config_label2id,
)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
import logging
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.3


In [8]:
import torch
from transformers import TrainingArguments, Trainer, TrainerCallback

# Prefer the GPU when one is visible, report the choice, and move the model there.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    print("Using GPU:", torch.cuda.get_device_name(device))
else:
    print("Using CPU")

model = model.to(device)
print("Model device:", next(model.parameters()).device)

# Training configuration: evaluate and checkpoint once per epoch, and reload the
# best checkpoint at the end (HF ranks checkpoints by eval loss when
# metric_for_best_model is unset).
# NOTE(review): the Trainer's eval_dataset is the held-out test split, so it is
# used both for checkpoint selection and for the final reported metrics —
# consider carving out a separate validation split.
batch_size = 8
logging_steps = 100

training_args = TrainingArguments(
    output_dir="/kaggle/working/results_distilbert",
    num_train_epochs=6,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    logging_steps=logging_steps,
    disable_tqdm=True,
    report_to=["none"],
)

class PrintCallback(TrainerCallback):
    """Echo every Trainer log dict and an end-of-epoch summary to stdout, flushing immediately."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # `logs` is the metrics dict being logged at this step; skip empty events.
        if logs is None:
            return
        print(f"Step: {state.global_step} | Logs: {logs}", flush=True)

    def on_epoch_end(self, args, state, control, **kwargs):
        # Nothing to report until at least one entry has been logged.
        history = state.log_history
        if not history:
            return
        print(f"Epoch {state.epoch} ended. Latest log: {history[-1]}", flush=True)

import functools


@functools.lru_cache(maxsize=None)
def _load_classification_metrics():
    """Load the four `evaluate` metrics once; later calls return the same objects."""
    import evaluate

    return (
        evaluate.load("accuracy"),
        evaluate.load("precision"),
        evaluate.load("recall"),
        evaluate.load("f1"),
    )


def compute_metrics(eval_preds):
    """Compute accuracy and macro-averaged precision/recall/F1 for the Trainer.

    The metric objects are loaded once (see ``_load_classification_metrics``)
    instead of being re-created with ``evaluate.load`` on every evaluation pass.

    Args:
        eval_preds: ``(logits, labels)`` pair handed over by the Trainer; logits
            are assumed to be ``(num_examples, num_labels)`` — TODO confirm if the
            model head changes.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``; the
        Trainer logs them with an ``eval_`` prefix.
    """
    import numpy as np

    accuracy_metric, precision_metric, recall_metric, f1_metric = _load_classification_metrics()

    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=1)

    accuracy_result = accuracy_metric.compute(predictions=predictions, references=labels)
    # Macro averaging weights every medical field equally, regardless of its frequency.
    precision_result = precision_metric.compute(predictions=predictions, references=labels, average='macro')
    recall_result = recall_metric.compute(predictions=predictions, references=labels, average='macro')
    f1_result = f1_metric.compute(predictions=predictions, references=labels, average='macro')

    return {
        "accuracy": accuracy_result["accuracy"],
        "precision": precision_result["precision"],
        "recall": recall_result["recall"],
        "f1": f1_result["f1"],
    }

from transformers.trainer_callback import PrinterCallback

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # NOTE(review): the test split doubles as the model-selection set here.
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[PrintCallback],
)
# With disable_tqdm=True the Trainer also installs its built-in PrinterCallback,
# which prints every log dict a second time after PrintCallback has already
# printed it. Keep only the custom, step-tagged output.
trainer.remove_callback(PrinterCallback)

trainer.train()

Using GPU: Tesla T4
Model device: cuda:0


  trainer = Trainer(


Step: 100 | Logs: {'loss': 1.5529, 'grad_norm': 5.017885208129883, 'learning_rate': 4.935249935249936e-05, 'epoch': 0.0777000777000777}
{'loss': 1.5529, 'grad_norm': 5.017885208129883, 'learning_rate': 4.935249935249936e-05, 'epoch': 0.0777000777000777}
Step: 200 | Logs: {'loss': 0.8407, 'grad_norm': 6.759465217590332, 'learning_rate': 4.8704998704998705e-05, 'epoch': 0.1554001554001554}
{'loss': 0.8407, 'grad_norm': 6.759465217590332, 'learning_rate': 4.8704998704998705e-05, 'epoch': 0.1554001554001554}
Step: 300 | Logs: {'loss': 0.5337, 'grad_norm': 5.941342830657959, 'learning_rate': 4.805749805749806e-05, 'epoch': 0.2331002331002331}
{'loss': 0.5337, 'grad_norm': 5.941342830657959, 'learning_rate': 4.805749805749806e-05, 'epoch': 0.2331002331002331}
Step: 400 | Logs: {'loss': 0.454, 'grad_norm': 5.074390411376953, 'learning_rate': 4.7409997409997415e-05, 'epoch': 0.3108003108003108}
{'loss': 0.454, 'grad_norm': 5.074390411376953, 'learning_rate': 4.7409997409997415e-05, 'epoch': 0.

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.56k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.38k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.79k [00:00<?, ?B/s]

Step: 1287 | Logs: {'eval_loss': 0.23569965362548828, 'eval_accuracy': 0.9085081585081585, 'eval_precision': 0.9554672134815528, 'eval_recall': 0.9238466703591698, 'eval_f1': 0.9325129266468789, 'eval_runtime': 26.4706, 'eval_samples_per_second': 194.48, 'eval_steps_per_second': 12.164, 'epoch': 1.0}
{'eval_loss': 0.23569965362548828, 'eval_accuracy': 0.9085081585081585, 'eval_precision': 0.9554672134815528, 'eval_recall': 0.9238466703591698, 'eval_f1': 0.9325129266468789, 'eval_runtime': 26.4706, 'eval_samples_per_second': 194.48, 'eval_steps_per_second': 12.164, 'epoch': 1.0}




Step: 1300 | Logs: {'loss': 0.241, 'grad_norm': 3.7618138790130615, 'learning_rate': 4.158249158249159e-05, 'epoch': 1.0101010101010102}
{'loss': 0.241, 'grad_norm': 3.7618138790130615, 'learning_rate': 4.158249158249159e-05, 'epoch': 1.0101010101010102}
Step: 1400 | Logs: {'loss': 0.2192, 'grad_norm': 1.0282785892486572, 'learning_rate': 4.0934990934990935e-05, 'epoch': 1.0878010878010878}
{'loss': 0.2192, 'grad_norm': 1.0282785892486572, 'learning_rate': 4.0934990934990935e-05, 'epoch': 1.0878010878010878}
Step: 1500 | Logs: {'loss': 0.2018, 'grad_norm': 1.6009060144424438, 'learning_rate': 4.028749028749029e-05, 'epoch': 1.1655011655011656}
{'loss': 0.2018, 'grad_norm': 1.6009060144424438, 'learning_rate': 4.028749028749029e-05, 'epoch': 1.1655011655011656}
Step: 1600 | Logs: {'loss': 0.2072, 'grad_norm': 1.3980671167373657, 'learning_rate': 3.9639989639989645e-05, 'epoch': 1.2432012432012431}
{'loss': 0.2072, 'grad_norm': 1.3980671167373657, 'learning_rate': 3.9639989639989645e-05,



Step: 2600 | Logs: {'loss': 0.2067, 'grad_norm': 4.561190605163574, 'learning_rate': 3.3164983164983165e-05, 'epoch': 2.0202020202020203}
{'loss': 0.2067, 'grad_norm': 4.561190605163574, 'learning_rate': 3.3164983164983165e-05, 'epoch': 2.0202020202020203}
Step: 2700 | Logs: {'loss': 0.1753, 'grad_norm': 2.4652364253997803, 'learning_rate': 3.251748251748252e-05, 'epoch': 2.097902097902098}
{'loss': 0.1753, 'grad_norm': 2.4652364253997803, 'learning_rate': 3.251748251748252e-05, 'epoch': 2.097902097902098}
Step: 2800 | Logs: {'loss': 0.1886, 'grad_norm': 3.8966193199157715, 'learning_rate': 3.1869981869981875e-05, 'epoch': 2.1756021756021755}
{'loss': 0.1886, 'grad_norm': 3.8966193199157715, 'learning_rate': 3.1869981869981875e-05, 'epoch': 2.1756021756021755}
Step: 2900 | Logs: {'loss': 0.1684, 'grad_norm': 1.177351713180542, 'learning_rate': 3.122248122248122e-05, 'epoch': 2.2533022533022535}
{'loss': 0.1684, 'grad_norm': 1.177351713180542, 'learning_rate': 3.122248122248122e-05, 'ep



Step: 3900 | Logs: {'loss': 0.1891, 'grad_norm': 4.002619743347168, 'learning_rate': 2.474747474747475e-05, 'epoch': 3.0303030303030303}
{'loss': 0.1891, 'grad_norm': 4.002619743347168, 'learning_rate': 2.474747474747475e-05, 'epoch': 3.0303030303030303}
Step: 4000 | Logs: {'loss': 0.1726, 'grad_norm': 0.9375525712966919, 'learning_rate': 2.40999740999741e-05, 'epoch': 3.108003108003108}
{'loss': 0.1726, 'grad_norm': 0.9375525712966919, 'learning_rate': 2.40999740999741e-05, 'epoch': 3.108003108003108}
Step: 4100 | Logs: {'loss': 0.1652, 'grad_norm': 1.4255621433258057, 'learning_rate': 2.3452473452473453e-05, 'epoch': 3.185703185703186}
{'loss': 0.1652, 'grad_norm': 1.4255621433258057, 'learning_rate': 2.3452473452473453e-05, 'epoch': 3.185703185703186}
Step: 4200 | Logs: {'loss': 0.1778, 'grad_norm': 2.211275815963745, 'learning_rate': 2.2804972804972807e-05, 'epoch': 3.2634032634032635}
{'loss': 0.1778, 'grad_norm': 2.211275815963745, 'learning_rate': 2.2804972804972807e-05, 'epoch'



Step: 5200 | Logs: {'loss': 0.1496, 'grad_norm': 1.5915040969848633, 'learning_rate': 1.632996632996633e-05, 'epoch': 4.040404040404041}
{'loss': 0.1496, 'grad_norm': 1.5915040969848633, 'learning_rate': 1.632996632996633e-05, 'epoch': 4.040404040404041}
Step: 5300 | Logs: {'loss': 0.1575, 'grad_norm': 2.696770191192627, 'learning_rate': 1.5682465682465683e-05, 'epoch': 4.118104118104118}
{'loss': 0.1575, 'grad_norm': 2.696770191192627, 'learning_rate': 1.5682465682465683e-05, 'epoch': 4.118104118104118}
Step: 5400 | Logs: {'loss': 0.1745, 'grad_norm': 2.1959807872772217, 'learning_rate': 1.5034965034965034e-05, 'epoch': 4.195804195804196}
{'loss': 0.1745, 'grad_norm': 2.1959807872772217, 'learning_rate': 1.5034965034965034e-05, 'epoch': 4.195804195804196}
Step: 5500 | Logs: {'loss': 0.1801, 'grad_norm': 0.2732304334640503, 'learning_rate': 1.4387464387464389e-05, 'epoch': 4.273504273504273}
{'loss': 0.1801, 'grad_norm': 0.2732304334640503, 'learning_rate': 1.4387464387464389e-05, 'epo



Step: 6500 | Logs: {'loss': 0.1653, 'grad_norm': 3.4265902042388916, 'learning_rate': 7.912457912457913e-06, 'epoch': 5.05050505050505}
{'loss': 0.1653, 'grad_norm': 3.4265902042388916, 'learning_rate': 7.912457912457913e-06, 'epoch': 5.05050505050505}
Step: 6600 | Logs: {'loss': 0.1563, 'grad_norm': 4.222074031829834, 'learning_rate': 7.264957264957266e-06, 'epoch': 5.128205128205128}
{'loss': 0.1563, 'grad_norm': 4.222074031829834, 'learning_rate': 7.264957264957266e-06, 'epoch': 5.128205128205128}
Step: 6700 | Logs: {'loss': 0.1648, 'grad_norm': 2.350196361541748, 'learning_rate': 6.617456617456617e-06, 'epoch': 5.205905205905206}
{'loss': 0.1648, 'grad_norm': 2.350196361541748, 'learning_rate': 6.617456617456617e-06, 'epoch': 5.205905205905206}
Step: 6800 | Logs: {'loss': 0.1581, 'grad_norm': 1.6796776056289673, 'learning_rate': 5.96995596995597e-06, 'epoch': 5.283605283605284}
{'loss': 0.1581, 'grad_norm': 1.6796776056289673, 'learning_rate': 5.96995596995597e-06, 'epoch': 5.28360



Step: 7722 | Logs: {'eval_loss': 0.25585126876831055, 'eval_accuracy': 0.918026418026418, 'eval_precision': 0.9515629845164707, 'eval_recall': 0.9535538531497736, 'eval_f1': 0.9473877938855221, 'eval_runtime': 25.656, 'eval_samples_per_second': 200.655, 'eval_steps_per_second': 12.551, 'epoch': 6.0}
{'eval_loss': 0.25585126876831055, 'eval_accuracy': 0.918026418026418, 'eval_precision': 0.9515629845164707, 'eval_recall': 0.9535538531497736, 'eval_f1': 0.9473877938855221, 'eval_runtime': 25.656, 'eval_samples_per_second': 200.655, 'eval_steps_per_second': 12.551, 'epoch': 6.0}
Step: 7722 | Logs: {'train_runtime': 2014.0643, 'train_samples_per_second': 61.336, 'train_steps_per_second': 3.834, 'total_flos': 8183427060243456.0, 'train_loss': 0.2247829725013067, 'epoch': 6.0}
{'train_runtime': 2014.0643, 'train_samples_per_second': 61.336, 'train_steps_per_second': 3.834, 'train_loss': 0.2247829725013067, 'epoch': 6.0}


TrainOutput(global_step=7722, training_loss=0.2247829725013067, metrics={'train_runtime': 2014.0643, 'train_samples_per_second': 61.336, 'train_steps_per_second': 3.834, 'train_loss': 0.2247829725013067, 'epoch': 6.0})

In [10]:
# Evaluate the model the Trainer holds after training (with
# load_best_model_at_end=True, the best checkpoint by eval loss).
# NOTE(review): this is the same split used to pick that checkpoint, so the
# scores are optimistically biased.
metrics = trainer.evaluate()
metrics



Step: 7722 | Logs: {'eval_loss': 0.22451472282409668, 'eval_accuracy': 0.9184149184149184, 'eval_precision': 0.9520161110523737, 'eval_recall': 0.9547565244822548, 'eval_f1': 0.9481891240500308, 'eval_runtime': 25.6662, 'eval_samples_per_second': 200.575, 'eval_steps_per_second': 12.546, 'epoch': 6.0}
{'eval_loss': 0.22451472282409668, 'eval_accuracy': 0.9184149184149184, 'eval_precision': 0.9520161110523737, 'eval_recall': 0.9547565244822548, 'eval_f1': 0.9481891240500308, 'eval_runtime': 25.6662, 'eval_samples_per_second': 200.575, 'eval_steps_per_second': 12.546, 'epoch': 6.0}
{'eval_loss': 0.22451472282409668, 'eval_accuracy': 0.9184149184149184, 'eval_precision': 0.9520161110523737, 'eval_recall': 0.9547565244822548, 'eval_f1': 0.9481891240500308, 'eval_runtime': 25.6662, 'eval_samples_per_second': 200.575, 'eval_steps_per_second': 12.546, 'epoch': 6.0}


In [11]:
test_text = "Patient has chronic glossitis and possible environmental inhalant allergies..."
# Reuse the training-time `max_length` so inference truncates exactly like training.
inputs = tokenizer(test_text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)

inputs = {key: value.to(device) for key, value in inputs.items()}

model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    # A single input gives one row of logits; take the arg-max over the class dimension.
    predicted_class_id = logits.argmax(dim=-1).item()

# `label2id` (from prepare_medical_data) maps int class id -> field name despite
# its name. Index with [] so an unexpected id fails loudly instead of printing None.
predicted_label = label2id[predicted_class_id]
print("Predicted medical field:", predicted_label)

Predicted medical field: Dermatologist


In [20]:
# Both mappings side by side: `id2label` is name -> id and `label2id` is
# id -> name (the variable names are swapped relative to their contents).
display(id2label, label2id)

{'Alergist': 0, 'Cardiologist': 1, 'Dermatologist': 2, 'Endocrinologist': 3, 'Gastroenterologist': 4, 'Genetics': 5, 'Geriatrics': 6, 'Neurologist': 7, 'Oncologist/Hematologist': 8, 'Orthopedics': 9, 'OtherQA': 10}
{0: 'Alergist', 1: 'Cardiologist', 2: 'Dermatologist', 3: 'Endocrinologist', 4: 'Gastroenterologist', 5: 'Genetics', 6: 'Geriatrics', 7: 'Neurologist', 8: 'Oncologist/Hematologist', 9: 'Orthopedics', 10: 'OtherQA'}


In [30]:
import torch

input_text = "A 58‑year‑old patient presents with intermittent chest pain, shortness of breath on exertion, and palpitations. Which specialist should they be referred to?"
# Same preprocessing as training: pad/truncate to the shared `max_length`.
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
inputs = {k: v.to(device) for k, v in inputs.items()}

model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=1).item()

# `id2label` from prepare_medical_data maps field name -> int (it prints as
# {'Alergist': 0, ...}), so indexing it with an int raises KeyError on a fresh
# kernel. The int -> name mapping is the dict named `label2id`.
predicted_label = label2id[predicted_class_id]
print("Predicted medical field:", predicted_label)

Predicted medical field: Cardiologist
