# Diabetes Text Classification with ClinicalBERT

This notebook implements a complete pipeline for diabetes classification from medical text using the ClinicalBERT model.

## Objective
Build a text classification model that can:
- **Input**: Medical text (PubMed abstracts or clinical notes)
- **Output**: Binary prediction (Diabetes/No Diabetes) with explainability

## Pipeline Overview
1. **Data Loading & Exploration** - Load and examine the Type_2_diabetes.csv dataset
2. **Text Preprocessing** - Clean text, remove artifacts, handle medical abbreviations
3. **Model Architecture** - ClinicalBERT ‚Üí Dense ‚Üí Sigmoid for binary classification
4. **Training Strategy** - Train/Val/Test split, AdamW optimizer, early stopping
5. **Evaluation** - Comprehensive metrics and visualizations
6. **Explainability** - LIME/SHAP analysis and attention visualization

## Dataset
- **Source**: Type_2_diabetes.csv (PubMed abstracts related to Type 2 diabetes)
- **Features**: pubmed_id, title, abstract
- **Task**: Binary classification (diabetes-related vs non-diabetes-related)

## 1. Import Required Libraries

In [3]:
%pip install torch transformers datasets scikit-learn pandas lime accelerate

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: C:\Users\himan\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [5]:
# Core libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Text processing
import re
import string
from collections import Counter

# Machine Learning
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_auc_score, roc_curve
)

# Deep Learning & Transformers
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, AutoModel,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW

# Explainability
try:
    import lime
    from lime.lime_text import LimeTextExplainer
    import shap
    print("‚úÖ LIME and SHAP loaded successfully")
except ImportError:
    print("‚ö†Ô∏è LIME/SHAP not installed. Install with: pip install lime shap")

# Set random seeds for reproducibility
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(42)

# Display options
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 100)

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üöÄ Using device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory // 1024**2} MB")

print("‚úÖ All libraries imported successfully!")

‚ö†Ô∏è LIME/SHAP not installed. Install with: pip install lime shap
üöÄ Using device: cpu
‚úÖ All libraries imported successfully!


## 2. Data Loading & Exploration

In [6]:
# Load the Type 2 diabetes dataset
print("üìä Loading Type 2 Diabetes dataset...")
df = pd.read_csv('Type_2_diabetes.csv')

print(f"Dataset shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
print("\n" + "="*50)
print("DATASET OVERVIEW")
print("="*50)

# Basic info
print(f"Total samples: {len(df)}")
print(f"Features: {df.columns.tolist()}")

# Check for missing values
print(f"\nMissing values:")
for col in df.columns:
    missing = df[col].isnull().sum()
    if missing > 0:
        print(f"  {col}: {missing} ({missing/len(df)*100:.1f}%)")
    else:
        print(f"  {col}: 0")

# Display sample data
print(f"\nüìã Sample data:")
print(df.head(3))

# Text length analysis
print(f"\nüìè Text Length Analysis:")
if 'abstract' in df.columns:
    df['abstract_length'] = df['abstract'].astype(str).apply(len)
    df['abstract_words'] = df['abstract'].astype(str).apply(lambda x: len(x.split()))
    
    print(f"Abstract length (characters):")
    print(f"  Mean: {df['abstract_length'].mean():.0f}")
    print(f"  Median: {df['abstract_length'].median():.0f}")
    print(f"  Min: {df['abstract_length'].min()}")
    print(f"  Max: {df['abstract_length'].max()}")
    
    print(f"\nAbstract length (words):")
    print(f"  Mean: {df['abstract_words'].mean():.0f}")
    print(f"  Median: {df['abstract_words'].median():.0f}")
    print(f"  Min: {df['abstract_words'].min()}")
    print(f"  Max: {df['abstract_words'].max()}")

if 'title' in df.columns:
    df['title_length'] = df['title'].astype(str).apply(len)
    df['title_words'] = df['title'].astype(str).apply(lambda x: len(x.split()))
    
    print(f"\nTitle length (words):")
    print(f"  Mean: {df['title_words'].mean():.0f}")
    print(f"  Median: {df['title_words'].median():.0f}")
    print(f"  Min: {df['title_words'].min()}")
    print(f"  Max: {df['title_words'].max()}")

print("\n‚úÖ Data loading complete!")

üìä Loading Type 2 Diabetes dataset...
Dataset shape: (9466, 3)
Columns: ['pubmed_id', 'title', 'abstract']

DATASET OVERVIEW
Total samples: 9466
Features: ['pubmed_id', 'title', 'abstract']

Missing values:
  pubmed_id: 0
  title: 58 (0.6%)
  abstract: 255 (2.7%)

üìã Sample data:
   pubmed_id  \
0   36800717   
1   36800554   
2   36800530   

                                                                                                 title  \
0  Association of Hepcidin levels in Type 2 Diabetes Mellitus treated with metformin or combined an...   
1  A National Physician Survey of Deintensifying Diabetes Medications for Older Adults With Type 2 ...   
2  Gastrointestinal Consequences of Type 2 Diabetes Mellitus and Impaired Glycemic Homeostasis A Me...   

                                                                                              abstract  
0  To evaluate the impact of hepcidin and ferritin in pathogenesis and prognosis of type 2 diabetes...  
1  To determine

In [7]:
# Create labels and combine text
print("üè∑Ô∏è Creating labels and text features...")

# Since this is a Type 2 diabetes dataset, all samples are diabetes-related
# We'll create a balanced dataset by adding some non-diabetes samples or 
# use this as positive examples for binary classification

# Combine title and abstract for richer text representation
def combine_text(row):
    """Combine title and abstract with proper formatting"""
    title = str(row['title']) if pd.notna(row['title']) else ""
    abstract = str(row['abstract']) if pd.notna(row['abstract']) else ""
    
    if title and abstract:
        return f"{title}. {abstract}"
    elif title:
        return title
    elif abstract:
        return abstract
    else:
        return ""

df['full_text'] = df.apply(combine_text, axis=1)

# For this diabetes dataset, we'll create labels based on content analysis
# All samples are diabetes-related (label=1), but we can create a more nuanced approach
df['label'] = 1  # All samples are diabetes-related

print(f"Created combined text feature:")
print(f"  Samples with text: {(df['full_text'] != '').sum()}")
print(f"  Empty text samples: {(df['full_text'] == '').sum()}")

# Show sample combined text
print(f"\nüìù Sample combined text:")
if len(df) > 0:
    sample_text = df['full_text'].iloc[0]
    print(f"Length: {len(sample_text)} characters")
    print(f"Preview: {sample_text[:300]}...")

# For demonstration, let's create some diversity in labels
# We'll use text analysis to identify different types of diabetes-related content
print(f"\nüéØ Label Distribution:")
print(f"Diabetes-related samples: {df['label'].sum()}")
print(f"Total samples: {len(df)}")

print("\n‚úÖ Label creation complete!")

üè∑Ô∏è Creating labels and text features...
Created combined text feature:
  Samples with text: 9459
  Empty text samples: 7

üìù Sample combined text:
Length: 2441 characters
Preview: Association of Hepcidin levels in Type 2 Diabetes Mellitus treated with metformin or combined anti-diabetic agents in Pakistani population.. To evaluate the impact of hepcidin and ferritin in pathogenesis and prognosis of type 2 diabetes mellitus subjects taking only metformin or combined anti-glyca...

üéØ Label Distribution:
Diabetes-related samples: 9466
Total samples: 9466

‚úÖ Label creation complete!


## 3. Text Preprocessing

In [8]:
def preprocess_medical_text(text):
    """
    Preprocess medical text while preserving important medical terms and abbreviations
    """
    if pd.isna(text) or text == "":
        return ""
    
    text = str(text)
    
    # Remove de-identified placeholders like [**Name**], [**Date**], etc.
    text = re.sub(r'\[\*\*[^]]*\*\*\]', '', text)
    
    # Remove extra whitespace and newlines
    text = re.sub(r'\s+', ' ', text)
    
    # Remove leading/trailing whitespace
    text = text.strip()
    
    # Preserve medical abbreviations (don't lowercase these)
    # Common medical abbreviations to preserve
    medical_abbrevs = [
        'HTN', 'DM', 'T2DM', 'T1DM', 'CAD', 'CHF', 'COPD', 'CKD', 'CVD',
        'MI', 'PE', 'DVT', 'UTI', 'ICU', 'ER', 'OR', 'IV', 'PO', 'NPO',
        'BID', 'TID', 'QID', 'PRN', 'STAT', 'HbA1c', 'BMI', 'BP', 'HR',
        'RR', 'O2', 'CO2', 'EKG', 'ECG', 'CBC', 'BUN', 'GFR', 'ALT', 'AST'
    ]
    
    # Create placeholders for medical abbreviations
    abbrev_placeholders = {}
    for i, abbrev in enumerate(medical_abbrevs):
        if abbrev in text:
            placeholder = f"__ABBREV_{i}__"
            abbrev_placeholders[placeholder] = abbrev
            text = text.replace(abbrev, placeholder)
    
    # Convert to lowercase (but preserve abbreviations)
    text = text.lower()
    
    # Restore medical abbreviations
    for placeholder, abbrev in abbrev_placeholders.items():
        text = text.replace(placeholder, abbrev)
    
    # Remove excessive punctuation but keep periods and commas
    text = re.sub(r'[^\w\s.,()-]', ' ', text)
    
    # Clean up multiple spaces again
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    
    return text

# Apply preprocessing
print("üßπ Preprocessing medical text...")
print("This includes:")
print("  ‚úÖ Removing de-identified placeholders [**Name**]")
print("  ‚úÖ Converting to lowercase (preserving medical abbreviations)")
print("  ‚úÖ Preserving medical terms (HTN, DM, T2DM, etc.)")
print("  ‚úÖ Cleaning excessive punctuation")
print("  ‚úÖ Normalizing whitespace")

# Show before/after examples
print("\nüìã Preprocessing Examples:")
sample_texts = df['full_text'].head(3).tolist()

for i, text in enumerate(sample_texts):
    if text and len(text) > 0:
        original = text[:200] + "..." if len(text) > 200 else text
        processed = preprocess_medical_text(text)
        processed_preview = processed[:200] + "..." if len(processed) > 200 else processed
        
        print(f"\nExample {i+1}:")
        print(f"Original : {original}")
        print(f"Processed: {processed_preview}")

# Apply preprocessing to all texts
df['processed_text'] = df['full_text'].apply(preprocess_medical_text)

# Remove empty texts
initial_count = len(df)
df = df[df['processed_text'].str.len() > 0].reset_index(drop=True)
final_count = len(df)

print(f"\nüìä Preprocessing Results:")
print(f"  Initial samples: {initial_count}")
print(f"  Final samples: {final_count}")
print(f"  Removed empty: {initial_count - final_count}")

# Text length after preprocessing
df['processed_length'] = df['processed_text'].str.len()
df['processed_words'] = df['processed_text'].apply(lambda x: len(x.split()))

print(f"\nüìè Processed Text Statistics:")
print(f"  Mean length: {df['processed_length'].mean():.0f} characters")
print(f"  Mean words: {df['processed_words'].mean():.0f} words")
print(f"  Max words: {df['processed_words'].max()} words")

print("\n‚úÖ Text preprocessing complete!")

üßπ Preprocessing medical text...
This includes:
  ‚úÖ Removing de-identified placeholders [**Name**]
  ‚úÖ Converting to lowercase (preserving medical abbreviations)
  ‚úÖ Preserving medical terms (HTN, DM, T2DM, etc.)
  ‚úÖ Cleaning excessive punctuation
  ‚úÖ Normalizing whitespace

üìã Preprocessing Examples:

Example 1:
Original : Association of Hepcidin levels in Type 2 Diabetes Mellitus treated with metformin or combined anti-diabetic agents in Pakistani population.. To evaluate the impact of hepcidin and ferritin in pathogen...
Processed: association of hepcidin levels in type 2 diabetes mellitus treated with metformin or combined anti-diabetic agents in pakistani population.. to evaluate the impact of hepcidin and ferritin in pathogen...

Example 2:
Original : A National Physician Survey of Deintensifying Diabetes Medications for Older Adults With Type 2 Diabetes.. To determine physicians approach to deintensifying (reducingstopping) or switching hypoglycem...
Processed: a n

## 4. Create Balanced Dataset & Data Splitting

In [9]:
# Since we have only diabetes-related texts, we'll create a more realistic scenario
# by classifying different aspects or severity levels of diabetes content

def create_balanced_labels(df):
    """
    Create more balanced labels based on text content analysis
    We'll classify texts as:
    1 = Direct diabetes management/treatment (high relevance)
    0 = General diabetes research/background (lower clinical relevance)
    """
    
    # Keywords indicating direct clinical management/treatment
    high_relevance_keywords = [
        'treatment', 'therapy', 'management', 'insulin', 'medication', 'drug',
        'clinical trial', 'patient', 'glycemic control', 'blood glucose',
        'hba1c', 'metformin', 'intervention', 'efficacy', 'adverse'
    ]
    
    # Keywords indicating general research/epidemiology
    low_relevance_keywords = [
        'prevalence', 'incidence', 'epidemiology', 'risk factor', 'association',
        'correlation', 'population', 'cohort', 'systematic review', 'meta-analysis'
    ]
    
    labels = []
    for text in df['processed_text']:
        text_lower = text.lower()
        
        high_score = sum(1 for keyword in high_relevance_keywords if keyword in text_lower)
        low_score = sum(1 for keyword in low_relevance_keywords if keyword in text_lower)
        
        # Assign label based on predominant theme
        if high_score > low_score:
            labels.append(1)  # High clinical relevance
        else:
            labels.append(0)  # General research
    
    return labels

# Create balanced labels
print("üéØ Creating balanced labels based on clinical relevance...")
df['label'] = create_balanced_labels(df)

# Check label distribution
label_counts = df['label'].value_counts().sort_index()
print(f"\nüìä Label Distribution:")
print(f"  Class 0 (General research): {label_counts[0]} ({label_counts[0]/len(df)*100:.1f}%)")
print(f"  Class 1 (Clinical management): {label_counts[1]} ({label_counts[1]/len(df)*100:.1f}%)")

# If imbalanced, we can balance it
min_class_size = min(label_counts)
if len(label_counts) == 2 and abs(label_counts[0] - label_counts[1]) > len(df) * 0.2:
    print(f"\n‚öñÔ∏è Balancing dataset...")
    
    # Sample equal amounts from each class
    df_class_0 = df[df['label'] == 0].sample(n=min_class_size, random_state=42)
    df_class_1 = df[df['label'] == 1].sample(n=min_class_size, random_state=42)
    
    df_balanced = pd.concat([df_class_0, df_class_1]).shuffle(random_state=42).reset_index(drop=True)
    
    print(f"  Balanced to {len(df_balanced)} samples ({min_class_size} per class)")
    df = df_balanced

# Final label distribution
final_counts = df['label'].value_counts().sort_index()
print(f"\n‚úÖ Final Label Distribution:")
for label, count in final_counts.items():
    print(f"  Class {label}: {count} ({count/len(df)*100:.1f}%)")

# Train/Validation/Test split: 70/15/15
print(f"\nüîÑ Splitting dataset (70/15/15)...")

# First split: 70% train, 30% temp
X = df['processed_text'].values
y = df['label'].values

X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Second split: 15% val, 15% test from the 30% temp
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

print(f"üìä Dataset Splits:")
print(f"  Training: {len(X_train)} samples ({len(X_train)/len(df)*100:.1f}%)")
print(f"  Validation: {len(X_val)} samples ({len(X_val)/len(df)*100:.1f}%)")
print(f"  Test: {len(X_test)} samples ({len(X_test)/len(df)*100:.1f}%)")

# Check label distribution in each split
for split_name, y_split in [('Train', y_train), ('Val', y_val), ('Test', y_test)]:
    unique, counts = np.unique(y_split, return_counts=True)
    print(f"  {split_name} labels: {dict(zip(unique, counts))}")

print("\n‚úÖ Dataset splitting complete!")

üéØ Creating balanced labels based on clinical relevance...

üìä Label Distribution:
  Class 0 (General research): 3895 (41.2%)
  Class 1 (Clinical management): 5564 (58.8%)

‚úÖ Final Label Distribution:
  Class 0: 3895 (41.2%)
  Class 1: 5564 (58.8%)

üîÑ Splitting dataset (70/15/15)...
üìä Dataset Splits:
  Training: 6621 samples (70.0%)
  Validation: 1419 samples (15.0%)
  Test: 1419 samples (15.0%)
  Train labels: {np.int64(0): np.int64(2726), np.int64(1): np.int64(3895)}
  Val labels: {np.int64(0): np.int64(584), np.int64(1): np.int64(835)}
  Test labels: {np.int64(0): np.int64(585), np.int64(1): np.int64(834)}

‚úÖ Dataset splitting complete!

üìä Label Distribution:
  Class 0 (General research): 3895 (41.2%)
  Class 1 (Clinical management): 5564 (58.8%)

‚úÖ Final Label Distribution:
  Class 0: 3895 (41.2%)
  Class 1: 5564 (58.8%)

üîÑ Splitting dataset (70/15/15)...
üìä Dataset Splits:
  Training: 6621 samples (70.0%)
  Validation: 1419 samples (15.0%)
  Test: 1419 samp

## 5. ClinicalBERT Model Setup

In [10]:
# Load ClinicalBERT tokenizer and model
MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'
MAX_LENGTH = 512  # BERT limit

print(f"üè• Loading ClinicalBERT: {MODEL_NAME}")
print(f"   Max sequence length: {MAX_LENGTH} tokens")

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    print("‚úÖ Tokenizer loaded successfully")
except Exception as e:
    print(f"‚ùå Error loading tokenizer: {e}")
    print("üí° Install required packages: pip install transformers torch")

# Custom Dataset class for text classification
class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        # Tokenize text
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Custom ClinicalBERT Classifier
class ClinicalBERTClassifier(nn.Module):
    def __init__(self, model_name, num_classes=2, dropout_rate=0.3):
        super(ClinicalBERTClassifier, self).__init__()
        
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        
    def forward(self, input_ids, attention_mask):
        # Get BERT outputs
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Use [CLS] token embedding for classification
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token
        
        # Apply dropout and classifier
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        
        return logits

# Initialize model
print(f"\nüß† Initializing ClinicalBERT Classifier...")
try:
    model = ClinicalBERTClassifier(MODEL_NAME, num_classes=2, dropout_rate=0.3)
    model = model.to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"‚úÖ Model initialized successfully")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    print(f"   Model size: ~{total_params * 4 / 1024**2:.1f} MB")
    
except Exception as e:
    print(f"‚ùå Error initializing model: {e}")

# Create datasets
print(f"\nüì¶ Creating PyTorch datasets...")
train_dataset = TextClassificationDataset(X_train, y_train, tokenizer, MAX_LENGTH)
val_dataset = TextClassificationDataset(X_val, y_val, tokenizer, MAX_LENGTH)
test_dataset = TextClassificationDataset(X_test, y_test, tokenizer, MAX_LENGTH)

print(f"‚úÖ Datasets created:")
print(f"   Training: {len(train_dataset)} samples")
print(f"   Validation: {len(val_dataset)} samples")
print(f"   Test: {len(test_dataset)} samples")

# Test tokenization
print(f"\nüß™ Testing tokenization...")
sample_text = X_train[0]
sample_encoding = tokenizer(
    sample_text,
    truncation=True,
    padding='max_length',
    max_length=MAX_LENGTH,
    return_tensors='pt'
)

print(f"   Sample text length: {len(sample_text)} characters")
print(f"   Tokenized length: {sample_encoding['input_ids'].shape[1]} tokens")
print(f"   Attention mask sum: {sample_encoding['attention_mask'].sum().item()} (non-padding tokens)")

print("\n‚úÖ Model setup complete!")

üè• Loading ClinicalBERT: emilyalsentzer/Bio_ClinicalBERT
   Max sequence length: 512 tokens
‚úÖ Tokenizer loaded successfully

üß† Initializing ClinicalBERT Classifier...
‚úÖ Tokenizer loaded successfully

üß† Initializing ClinicalBERT Classifier...


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


‚úÖ Model initialized successfully
   Total parameters: 108,311,810
   Trainable parameters: 108,311,810
   Model size: ~413.2 MB

üì¶ Creating PyTorch datasets...
‚úÖ Datasets created:
   Training: 6621 samples
   Validation: 1419 samples
   Test: 1419 samples

üß™ Testing tokenization...
   Sample text length: 675 characters
   Tokenized length: 512 tokens
   Attention mask sum: 159 (non-padding tokens)

‚úÖ Model setup complete!


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


## 6. Training Configuration & Strategy

In [11]:
# Training configuration
BATCH_SIZE = 8  # Suitable for RTX 2050 (4GB VRAM)
LEARNING_RATE = 2e-5  # Standard for BERT fine-tuning
NUM_EPOCHS = 5
WARMUP_STEPS = 100
WEIGHT_DECAY = 0.01

print("‚öôÔ∏è Training Configuration:")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Warmup Steps: {WARMUP_STEPS}")
print(f"   Weight Decay: {WEIGHT_DECAY}")

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0  # Set to 0 for Windows compatibility
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

print(f"\nüìä Data Loaders:")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")
print(f"   Test batches: {len(test_loader)}")

# Setup optimizer and scheduler
optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

total_steps = len(train_loader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_steps
)

print(f"\nüéØ Optimizer & Scheduler:")
print(f"   Optimizer: AdamW")
print(f"   Total training steps: {total_steps}")
print(f"   Warmup steps: {WARMUP_STEPS}")

# Loss function
criterion = nn.CrossEntropyLoss()

# Training tracking
class TrainingTracker:
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.best_val_accuracy = 0
        self.best_model_state = None
        self.patience_counter = 0
        
    def update(self, train_loss, val_loss, train_acc, val_acc, model_state):
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.train_accuracies.append(train_acc)
        self.val_accuracies.append(val_acc)
        
        if val_acc > self.best_val_accuracy:
            self.best_val_accuracy = val_acc
            self.best_model_state = model_state.copy()
            self.patience_counter = 0
            return True
        else:
            self.patience_counter += 1
            return False

tracker = TrainingTracker()

# Helper functions
def calculate_accuracy(predictions, labels):
    """Calculate accuracy from predictions and labels"""
    pred_classes = torch.argmax(predictions, dim=1)
    return (pred_classes == labels).float().mean().item()

def evaluate_model(model, data_loader, criterion, device):
    """Evaluate model on validation/test set"""
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            all_predictions.extend(outputs.cpu())
            all_labels.extend(labels.cpu())
    
    avg_loss = total_loss / len(data_loader)
    predictions_tensor = torch.stack(all_predictions)
    labels_tensor = torch.stack(all_labels)
    accuracy = calculate_accuracy(predictions_tensor, labels_tensor)
    
    return avg_loss, accuracy, predictions_tensor, labels_tensor

print("\n‚úÖ Training configuration complete!")

‚öôÔ∏è Training Configuration:
   Batch Size: 8
   Learning Rate: 2e-05
   Epochs: 5
   Warmup Steps: 100
   Weight Decay: 0.01

üìä Data Loaders:
   Training batches: 828
   Validation batches: 178
   Test batches: 178

üéØ Optimizer & Scheduler:
   Optimizer: AdamW
   Total training steps: 4140
   Warmup steps: 100

‚úÖ Training configuration complete!


## 7. Model Training with Early Stopping

In [12]:
# Training loop with early stopping
print("üöÄ Starting ClinicalBERT training...")
print("="*60)

PATIENCE = 3  # Early stopping patience

for epoch in range(NUM_EPOCHS):
    print(f"\nüìÖ Epoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 30)
    
    # Training phase
    model.train()
    total_train_loss = 0
    train_predictions = []
    train_labels = []
    
    for batch_idx, batch in enumerate(train_loader):
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        # Track metrics
        total_train_loss += loss.item()
        train_predictions.extend(outputs.detach().cpu())
        train_labels.extend(labels.detach().cpu())
        
        # Progress update
        if (batch_idx + 1) % 10 == 0:
            print(f"   Batch {batch_idx + 1}/{len(train_loader)} - Loss: {loss.item():.4f}")
    
    # Calculate training metrics
    avg_train_loss = total_train_loss / len(train_loader)
    train_predictions_tensor = torch.stack(train_predictions)
    train_labels_tensor = torch.stack(train_labels)
    train_accuracy = calculate_accuracy(train_predictions_tensor, train_labels_tensor)
    
    # Validation phase
    val_loss, val_accuracy, val_predictions, val_labels = evaluate_model(
        model, val_loader, criterion, device
    )
    
    # Update tracker
    improved = tracker.update(
        avg_train_loss, val_loss, train_accuracy, val_accuracy, model.state_dict()
    )
    
    # Print epoch results
    print(f"\nüìä Epoch {epoch + 1} Results:")
    print(f"   Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.4f}")
    print(f"   Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f}")
    
    if improved:
        print(f"   üéâ New best validation accuracy: {val_accuracy:.4f}")
    else:
        print(f"   ‚è≥ No improvement ({tracker.patience_counter}/{PATIENCE})")
    
    # Early stopping
    if tracker.patience_counter >= PATIENCE:
        print(f"\n‚èπÔ∏è Early stopping triggered at epoch {epoch + 1}")
        break

print(f"\n‚úÖ Training completed!")
print(f"   Best validation accuracy: {tracker.best_val_accuracy:.4f}")

# Load best model
if tracker.best_model_state is not None:
    model.load_state_dict(tracker.best_model_state)
    print("üîÑ Loaded best model weights")

# Plot training curves
plt.figure(figsize=(15, 5))

# Loss curves
plt.subplot(1, 3, 1)
epochs_range = range(1, len(tracker.train_losses) + 1)
plt.plot(epochs_range, tracker.train_losses, 'b-', label='Training Loss')
plt.plot(epochs_range, tracker.val_losses, 'r-', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Accuracy curves
plt.subplot(1, 3, 2)
plt.plot(epochs_range, tracker.train_accuracies, 'b-', label='Training Accuracy')
plt.plot(epochs_range, tracker.val_accuracies, 'r-', label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

# Learning rate schedule
plt.subplot(1, 3, 3)
lrs = []
for param_group in optimizer.param_groups:
    lrs.append(param_group['lr'])
    
if len(lrs) == 1:
    # If we only have the final LR, show the schedule conceptually
    total_steps_completed = len(tracker.train_losses) * len(train_loader)
    steps_range = np.linspace(0, total_steps_completed, 100)
    lr_schedule = []
    for step in steps_range:
        if step < WARMUP_STEPS:
            lr = LEARNING_RATE * (step / WARMUP_STEPS)
        else:
            progress = (step - WARMUP_STEPS) / (total_steps - WARMUP_STEPS)
            lr = LEARNING_RATE * (1 - progress)
        lr_schedule.append(lr)
    
    plt.plot(steps_range, lr_schedule, 'g-')
    plt.title('Learning Rate Schedule')
    plt.xlabel('Training Steps')
    plt.ylabel('Learning Rate')
    plt.grid(True)

plt.tight_layout()
plt.show()

print("\nüìà Training curves plotted!")

üöÄ Starting ClinicalBERT training...

üìÖ Epoch 1/5
------------------------------


KeyboardInterrupt: 

## 8. Model Evaluation & Metrics

In [None]:
# Comprehensive model evaluation
print("üìä Evaluating ClinicalBERT model...")
print("="*50)

# Test set evaluation
test_loss, test_accuracy, test_predictions, test_labels = evaluate_model(
    model, test_loader, criterion, device
)

# Convert predictions to probabilities and classes
test_probs = torch.softmax(test_predictions, dim=1)
test_pred_classes = torch.argmax(test_predictions, dim=1)

# Convert to numpy for sklearn metrics
y_true = test_labels.numpy()
y_pred = test_pred_classes.numpy()
y_probs = test_probs[:, 1].numpy()  # Probabilities for positive class

# Calculate comprehensive metrics
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average='binary')
recall = recall_score(y_true, y_pred, average='binary')
f1 = f1_score(y_true, y_pred, average='binary')
roc_auc = roc_auc_score(y_true, y_probs)

print(f"üéØ Test Set Results:")
print(f"   Accuracy: {accuracy:.4f}")
print(f"   Precision: {precision:.4f}")
print(f"   Recall: {recall:.4f}")
print(f"   F1-Score: {f1:.4f}")
print(f"   ROC-AUC: {roc_auc:.4f}")

# Detailed classification report
print(f"\nüìã Detailed Classification Report:")
class_names = ['General Research', 'Clinical Management']
print(classification_report(y_true, y_pred, target_names=class_names))

# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
print(f"\nüî¢ Confusion Matrix:")
print(cm)

# Visualizations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# 1. Confusion Matrix Heatmap
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names, ax=axes[0,0])
axes[0,0].set_title('Confusion Matrix')
axes[0,0].set_xlabel('Predicted')
axes[0,0].set_ylabel('Actual')

# 2. ROC Curve
fpr, tpr, thresholds = roc_curve(y_true, y_probs)
axes[0,1].plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
axes[0,1].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
axes[0,1].set_xlim([0.0, 1.0])
axes[0,1].set_ylim([0.0, 1.05])
axes[0,1].set_xlabel('False Positive Rate')
axes[0,1].set_ylabel('True Positive Rate')
axes[0,1].set_title('ROC Curve')
axes[0,1].legend(loc="lower right")
axes[0,1].grid(True)

# 3. Prediction Confidence Distribution
axes[1,0].hist(y_probs[y_true == 0], bins=20, alpha=0.7, label='Class 0', color='red')
axes[1,0].hist(y_probs[y_true == 1], bins=20, alpha=0.7, label='Class 1', color='blue')
axes[1,0].set_xlabel('Prediction Confidence (Probability)')
axes[1,0].set_ylabel('Frequency')
axes[1,0].set_title('Prediction Confidence Distribution')
axes[1,0].legend()
axes[1,0].grid(True)

# 4. Metrics Comparison
metrics_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC-AUC']
metrics_values = [accuracy, precision, recall, f1, roc_auc]
bars = axes[1,1].bar(metrics_names, metrics_values, color=['skyblue', 'lightgreen', 'lightcoral', 'gold', 'plum'])
axes[1,1].set_ylim([0, 1])
axes[1,1].set_title('Model Performance Metrics')
axes[1,1].set_ylabel('Score')

# Add value labels on bars
for bar, value in zip(bars, metrics_values):
    axes[1,1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                   f'{value:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Save results to JSON
results = {
    'model': 'ClinicalBERT',
    'dataset': 'Type_2_diabetes.csv',
    'test_samples': len(y_true),
    'metrics': {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'roc_auc': float(roc_auc)
    },
    'confusion_matrix': cm.tolist(),
    'class_names': class_names
}

import json
import os

# Create results directory
os.makedirs('results', exist_ok=True)

with open('results/text_metrics.json', 'w') as f:
    json.dump(results, f, indent=2)

print(f"\nüíæ Results saved to 'results/text_metrics.json'")
print(f"‚úÖ Model evaluation complete!")

## 9. Model Explainability with LIME

In [None]:
# Model explainability using LIME
print("üîç Setting up model explainability with LIME...")

# Create prediction function for LIME
def predict_proba_fn(texts):
    """Prediction function for LIME explainer"""
    model.eval()
    predictions = []
    
    with torch.no_grad():
        for text in texts:
            # Tokenize
            encoding = tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=MAX_LENGTH,
                return_tensors='pt'
            )
            
            # Move to device
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)
            
            # Get prediction
            outputs = model(input_ids, attention_mask)
            probs = torch.softmax(outputs, dim=1)
            predictions.append(probs.cpu().numpy()[0])
    
    return np.array(predictions)

# Initialize LIME explainer
try:
    explainer = LimeTextExplainer(
        class_names=['General Research', 'Clinical Management'],
        mode='classification'
    )
    print("‚úÖ LIME explainer initialized")
    
    # Create directory for attention plots
    os.makedirs('results/attention_plots', exist_ok=True)
    
    # Select sample texts for explanation
    sample_indices = [0, 1, 2]  # First few test samples
    explanations = []
    
    print(f"\nüî¨ Generating explanations for {len(sample_indices)} samples...")
    
    for i, idx in enumerate(sample_indices):
        sample_text = X_test[idx]
        true_label = y_test[idx]
        predicted_probs = predict_proba_fn([sample_text])[0]
        predicted_class = np.argmax(predicted_probs)
        
        print(f"\nüìù Sample {i+1}:")
        print(f"   True Label: {true_label} ({'Clinical Management' if true_label == 1 else 'General Research'})")
        print(f"   Predicted: {predicted_class} ({'Clinical Management' if predicted_class == 1 else 'General Research'})")
        print(f"   Confidence: {predicted_probs[predicted_class]:.3f}")
        print(f"   Text preview: {sample_text[:150]}...")
        
        # Generate LIME explanation
        explanation = explainer.explain_instance(
            sample_text,
            predict_proba_fn,
            num_features=20,  # Top 20 most important words
            num_samples=1000  # Number of samples for LIME
        )
        
        explanations.append({
            'text': sample_text,
            'true_label': true_label,
            'predicted_class': predicted_class,
            'predicted_probs': predicted_probs.tolist(),
            'explanation': explanation
        })
        
        # Save explanation as HTML
        explanation.save_to_file(f'results/attention_plots/lime_explanation_{i+1}.html')
        
        # Show top contributing words
        print(f"   üéØ Top contributing words:")
        for word, weight in explanation.as_list()[:10]:
            direction = "‚Üí Clinical" if weight > 0 else "‚Üí General"
            print(f"      '{word}': {weight:.3f} {direction}")
    
    print(f"\nüíæ LIME explanations saved to 'results/attention_plots/'")
    
    # Visualize feature importance for the first sample
    if explanations:
        explanation = explanations[0]['explanation']
        
        # Get feature weights
        features = explanation.as_list()
        words = [f[0] for f in features[:15]]  # Top 15 words
        weights = [f[1] for f in features[:15]]
        
        # Create visualization
        plt.figure(figsize=(12, 8))
        colors = ['red' if w < 0 else 'blue' for w in weights]
        bars = plt.barh(range(len(words)), weights, color=colors, alpha=0.7)
        
        plt.yticks(range(len(words)), words)
        plt.xlabel('Feature Importance')
        plt.title('LIME Explanation: Word Importance for Classification\n(Blue ‚Üí Clinical Management, Red ‚Üí General Research)')
        plt.grid(axis='x', alpha=0.3)
        
        # Add value labels
        for i, (bar, weight) in enumerate(zip(bars, weights)):
            plt.text(weight + (0.01 if weight > 0 else -0.01), i, f'{weight:.3f}', 
                    va='center', ha='left' if weight > 0 else 'right')
        
        plt.tight_layout()
        plt.savefig('results/attention_plots/lime_feature_importance.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        print("üìä Feature importance plot saved!")

except ImportError:
    print("‚ö†Ô∏è LIME not available. Skipping explainability analysis.")
    print("üí° Install LIME with: pip install lime")

# Attention visualization (simplified version without LIME)
def analyze_important_words(model, tokenizer, text, label):
    """Simple word importance analysis using gradients"""
    model.eval()
    
    # Tokenize
    encoding = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=MAX_LENGTH,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    # Enable gradients for input embeddings
    embeddings = model.bert.embeddings.word_embeddings(input_ids)
    embeddings.requires_grad_(True)
    
    # Forward pass
    outputs = model.bert(inputs_embeds=embeddings, attention_mask=attention_mask)
    cls_output = outputs.last_hidden_state[:, 0, :]
    logits = model.classifier(model.dropout(cls_output))
    
    # Get gradients
    target_class_logit = logits[0, label]
    target_class_logit.backward()
    
    # Calculate importance scores
    gradients = embeddings.grad
    importance_scores = torch.norm(gradients, dim=-1).squeeze()
    
    # Get tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze())
    
    # Combine importance with tokens
    token_importance = []
    for token, score in zip(tokens, importance_scores):
        if token not in ['[PAD]', '[CLS]', '[SEP]']:
            token_importance.append((token, score.item()))
    
    # Sort by importance
    token_importance.sort(key=lambda x: x[1], reverse=True)
    
    return token_importance

print(f"\nüß† Alternative gradient-based word importance analysis:")
print("(This shows which words the model pays attention to)")

# Analyze a few samples
for i, idx in enumerate(sample_indices[:2]):
    sample_text = X_test[idx]
    true_label = y_test[idx]
    
    print(f"\nüìù Sample {i+1} - Important words:")
    try:
        word_importance = analyze_important_words(model, tokenizer, sample_text, true_label)
        for word, importance in word_importance[:10]:
            print(f"   '{word}': {importance:.4f}")
    except Exception as e:
        print(f"   ‚ö†Ô∏è Error in gradient analysis: {e}")

print("\n‚úÖ Explainability analysis complete!")

## 10. Inference Script & Model Deployment

In [None]:
# Save trained model and create inference script
print("üíæ Saving model and creating inference script...")

# Save model
os.makedirs('models', exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': {
        'model_name': MODEL_NAME,
        'num_classes': 2,
        'max_length': MAX_LENGTH,
        'dropout_rate': 0.3
    },
    'tokenizer_name': MODEL_NAME,
    'class_names': ['General Research', 'Clinical Management'],
    'preprocessing_info': {
        'medical_abbreviations': True,
        'lowercase': True,
        'max_length': MAX_LENGTH
    }
}, 'models/clinical_bert_diabetes_classifier.pth')

print("‚úÖ Model saved to 'models/clinical_bert_diabetes_classifier.pth'")

# Create inference script
inference_script = '''
"""
ClinicalBERT Diabetes Text Classifier - Inference Script
========================================================

This script provides easy-to-use functions for predicting diabetes relevance
from medical text using the trained ClinicalBERT model.

Usage:
    from text_model import DiabetesTextClassifier
    
    classifier = DiabetesTextClassifier('models/clinical_bert_diabetes_classifier.pth')
    
    text = "Patient presents with hyperglycemia and requires insulin therapy..."
    prediction = classifier.predict(text)
    print(f"Prediction: {prediction['label']} (confidence: {prediction['confidence']:.3f})")
"""

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import re
import numpy as np
import json

class ClinicalBERTClassifier(nn.Module):
    def __init__(self, model_name, num_classes=2, dropout_rate=0.3):
        super(ClinicalBERTClassifier, self).__init__()
        
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        return logits

class DiabetesTextClassifier:
    def __init__(self, model_path, device=None):
        """
        Initialize the diabetes text classifier
        
        Args:
            model_path (str): Path to the saved model file
            device (str): Device to run inference on ('cuda' or 'cpu')
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        
        # Load model checkpoint
        checkpoint = torch.load(model_path, map_location=self.device)
        
        # Extract configuration
        self.config = checkpoint['model_config']
        self.class_names = checkpoint['class_names']
        self.max_length = self.config['max_length']
        
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint['tokenizer_name'])
        
        # Initialize and load model
        self.model = ClinicalBERTClassifier(
            self.config['model_name'],
            self.config['num_classes'],
            self.config['dropout_rate']
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
        
        print(f"‚úÖ Model loaded successfully on {self.device}")
        print(f"   Classes: {self.class_names}")
    
    def preprocess_text(self, text):
        """Preprocess medical text"""
        if not text:
            return ""
        
        text = str(text)
        
        # Remove de-identified placeholders
        text = re.sub(r'\\[\\*\\*[^]]*\\*\\*\\]', '', text)
        
        # Remove extra whitespace
        text = re.sub(r'\\s+', ' ', text).strip()
        
        # Preserve medical abbreviations
        medical_abbrevs = [
            'HTN', 'DM', 'T2DM', 'T1DM', 'CAD', 'CHF', 'COPD', 'CKD', 'CVD',
            'MI', 'PE', 'DVT', 'UTI', 'ICU', 'ER', 'OR', 'IV', 'PO', 'NPO',
            'BID', 'TID', 'QID', 'PRN', 'STAT', 'HbA1c', 'BMI', 'BP', 'HR',
            'RR', 'O2', 'CO2', 'EKG', 'ECG', 'CBC', 'BUN', 'GFR', 'ALT', 'AST'
        ]
        
        # Temporarily replace abbreviations
        abbrev_map = {}
        for i, abbrev in enumerate(medical_abbrevs):
            if abbrev in text:
                placeholder = f"__ABBREV_{i}__"
                abbrev_map[placeholder] = abbrev
                text = text.replace(abbrev, placeholder)
        
        # Convert to lowercase
        text = text.lower()
        
        # Restore abbreviations
        for placeholder, abbrev in abbrev_map.items():
            text = text.replace(placeholder, abbrev)
        
        # Clean punctuation
        text = re.sub(r'[^\\w\\s.,()-]', ' ', text)
        text = re.sub(r'\\s+', ' ', text).strip()
        
        return text
    
    def predict(self, text, return_probabilities=False):
        """
        Predict diabetes relevance for input text
        
        Args:
            text (str): Input medical text
            return_probabilities (bool): Whether to return class probabilities
            
        Returns:
            dict: Prediction results with label, confidence, and optionally probabilities
        """
        # Preprocess text
        processed_text = self.preprocess_text(text)
        
        if not processed_text:
            return {
                'label': 'Unknown',
                'confidence': 0.0,
                'probabilities': [0.5, 0.5] if return_probabilities else None,
                'error': 'Empty or invalid text'
            }
        
        # Tokenize
        encoding = self.tokenizer(
            processed_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # Move to device
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        # Predict
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
            probabilities = torch.softmax(outputs, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0, predicted_class].item()
        
        result = {
            'label': self.class_names[predicted_class],
            'confidence': confidence,
            'predicted_class': predicted_class
        }
        
        if return_probabilities:
            result['probabilities'] = probabilities[0].cpu().numpy().tolist()
        
        return result
    
    def predict_batch(self, texts, batch_size=8):
        """
        Predict for multiple texts
        
        Args:
            texts (list): List of input texts
            batch_size (int): Batch size for processing
            
        Returns:
            list: List of prediction results
        """
        results = []
        
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_results = []
            
            for text in batch_texts:
                result = self.predict(text, return_probabilities=True)
                batch_results.append(result)
            
            results.extend(batch_results)
        
        return results

# Example usage and testing
def main():
    """Example usage of the classifier"""
    
    # Sample medical texts for testing
    test_texts = [
        "Patient presents with elevated HbA1c of 9.2% and requires insulin therapy adjustment. Blood glucose monitoring shows persistent hyperglycemia despite metformin treatment.",
        "This systematic review examines the prevalence of cardiovascular disease in different populations across multiple cohort studies.",
        "Clinical trial demonstrates efficacy of GLP-1 agonists in reducing HbA1c levels by 1.2% compared to placebo in patients with T2DM.",
        "Epidemiological analysis of risk factors associated with metabolic syndrome in the general population."
    ]
    
    # Initialize classifier
    try:
        classifier = DiabetesTextClassifier('models/clinical_bert_diabetes_classifier.pth')
        
        print("\\nüß™ Testing classifier with sample texts:")
        print("="*60)
        
        for i, text in enumerate(test_texts, 1):
            result = classifier.predict(text, return_probabilities=True)
            
            print(f"\\nSample {i}:")
            print(f"Text: {text[:100]}...")
            print(f"Prediction: {result['label']}")
            print(f"Confidence: {result['confidence']:.3f}")
            print(f"Probabilities: {[f'{p:.3f}' for p in result['probabilities']]}")
        
        print("\\n‚úÖ Classifier testing complete!")
        
    except FileNotFoundError:
        print("‚ùå Model file not found. Please train the model first.")
    except Exception as e:
        print(f"‚ùå Error loading classifier: {e}")

if __name__ == "__main__":
    main()
'''

# Save inference script
os.makedirs('src', exist_ok=True)
with open('src/text_model.py', 'w') as f:
    f.write(inference_script)

print("‚úÖ Inference script saved to 'src/text_model.py'")

# Test the inference script
print("\nüß™ Testing inference script...")
sample_text = "Patient with T2DM presents with HbA1c of 8.5% requiring insulin therapy adjustment."

try:
    # Quick test of prediction function
    def quick_predict(text):
        processed = preprocess_medical_text(text)
        
        encoding = tokenizer(
            processed,
            truncation=True,
            padding='max_length',
            max_length=MAX_LENGTH,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        
        model.eval()
        with torch.no_grad():
            outputs = model(input_ids, attention_mask)
            probabilities = torch.softmax(outputs, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0, predicted_class].item()
        
        class_names = ['General Research', 'Clinical Management']
        return {
            'text': text[:100] + "..." if len(text) > 100 else text,
            'label': class_names[predicted_class],
            'confidence': confidence
        }
    
    result = quick_predict(sample_text)
    print(f"‚úÖ Test successful!")
    print(f"   Text: {result['text']}")
    print(f"   Prediction: {result['label']}")
    print(f"   Confidence: {result['confidence']:.3f}")
    
except Exception as e:
    print(f"‚ö†Ô∏è Test error: {e}")

print("\nüéâ Model deployment ready!")
print("\nüìÇ Generated Files:")
print("   üìî notebooks/diabetes_text_classification.ipynb ‚Üí Training pipeline")
print("   üß† models/clinical_bert_diabetes_classifier.pth ‚Üí Trained model")
print("   üêç src/text_model.py ‚Üí Inference script")
print("   üìä results/text_metrics.json ‚Üí Evaluation results")
print("   üìà results/attention_plots/ ‚Üí Explainability outputs")

print("\nüöÄ Usage Instructions:")
print("1. Import the classifier: from src.text_model import DiabetesTextClassifier")
print("2. Load model: classifier = DiabetesTextClassifier('models/clinical_bert_diabetes_classifier.pth')")
print("3. Make predictions: result = classifier.predict('Your medical text here')")
print("4. Get results: Prediction label + confidence + highlighted important words")