In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.applications.densenet import preprocess_input
import matplotlib.pyplot as plt
import os
from tensorflow.keras import layers, models

In [None]:
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter

# --- 1. Configuration ---

# This is your *original* folder with 91k+ images
SOURCE_DATA_DIR = r'C:\Users\SUBRAT\MAFSL PROJECT\processed_data\processed_data'

# This is the *new* folder where the 80/10/10 split will be created
TARGET_DATA_DIR = 'data_split/'

# Define the split ratios
VAL_SIZE = 0.1  # 10% for validation
TEST_SIZE = 0.1 # 10% for testing
# (The remaining 80% will be for training)

RANDOM_STATE = 42 # Ensures the split is repeatable

# --- 2. The Splitting Function ---

def create_stratified_split(source_dir, target_dir):
    """
    Finds all images in the source_dir, splits them into 80/10/10
    train/val/test sets with stratification, and copies them to the
    new target_dir.
    """
    
    print(f"--- Starting Stratified Split ---")
    
    if os.path.exists(target_dir):
        print(f"Removing existing directory: {target_dir}")
        shutil.rmtree(target_dir)

    # 1. Find all image paths and their corresponding labels
    all_filepaths = []
    all_labels = []
    
    # We must use sorted() to ensure class order is consistent
    # (e.g., 'COVID' is always 0, 'NORMAL' is 1, etc.)
    classes = sorted([d for d in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, d))])
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    
    if not classes:
        print(f"Error: No subdirectories found in {source_dir}. Aborting.")
        return False
        
    print(f"Found {len(classes)} classes: {classes}")

    for cls in classes:
        class_dir = os.path.join(source_dir, cls)
        # Check if it's a directory
        if not os.path.isdir(class_dir):
            continue
            
        for img_name in os.listdir(class_dir):
            img_path = os.path.join(class_dir, img_name)
            
            # Check if it's a file
            if os.path.isfile(img_path):
                all_filepaths.append(img_path)
                all_labels.append(class_to_idx[cls])

    print(f"\nTotal images found: {len(all_filepaths)}")
    if not all_filepaths:
        print("Error: No image files found. Check your 'processed_data' folder.")
        return False
        
    print(f"Original Class distribution: {Counter([classes[i] for i in all_labels])}")

    # 2. Create the first split (train_val vs. test)
    # We stratify on the labels to keep class ratios the same
    train_val_files, test_files, train_val_labels, test_labels = train_test_split(
        all_filepaths, all_labels, 
        test_size=TEST_SIZE, 
        stratify=all_labels, 
        random_state=RANDOM_STATE
    )

    # 3. Create the second split (train vs. val)
    # We adjust the validation size relative to the 90% that's left
    val_size_adjusted = VAL_SIZE / (1.0 - TEST_SIZE)
    
    train_files, val_files, train_labels, val_labels = train_test_split(
        train_val_files, train_val_labels, 
        test_size=val_size_adjusted, 
        stratify=train_val_labels, 
        random_state=RANDOM_STATE
    )
    
    print(f"\nSplitting data into:")
    print(f"  Train set: {len(train_files)} images")
    print(f"  Val set:   {len(val_files)} images")
    print(f"  Test set:  {len(test_files)} images")

    # 4. Copy files to new directories (train/, val/, test/)
    datasets = {
        'train': (train_files, train_labels),
        'val': (val_files, val_labels),
        'test': (test_files, test_labels)
    }

    for split_name, (files, labels) in datasets.items():
        split_path = os.path.join(target_dir, split_name)
        
        for i, filepath in enumerate(files):
            # Get the class name
            class_name = classes[labels[i]]
            
            # Create target class dir (e.g., data_split/train/COVID/)
            target_class_dir = os.path.join(split_path, class_name)
            os.makedirs(target_class_dir, exist_ok=True)
            
            # Copy file
            shutil.copy(filepath, target_class_dir)
            
    print(f"\nSuccessfully created stratified split at: {target_dir}")
    print("--- Data Splitting Complete ---")
    return True

# --- 3. Run the Function ---
if __name__ == "__main__":
    # Check if scikit-learn is installed
    try:
        from sklearn.model_selection import train_test_split
    except ImportError:
        print("Error: scikit-learn is not installed.")
        print("Please install it by running: pip install scikit-learn")
        exit()

    if not os.path.exists(SOURCE_DATA_DIR):
        print(f"Error: Source directory not found: {SOURCE_DATA_DIR}")
        print("Please make sure your data is in a folder named 'processed_data' in the same directory.")
    else:
        create_stratified_sp
        lit(SOURCE_DATA_DIR, TARGET_DATA_DIR)

--- Starting Stratified Split ---
Found 4 classes: ['COVID', 'Normal', 'Pneumonia', 'Tuberculosis']

Total images found: 104329
Original Class distribution: Counter({'Normal': 91225, 'Pneumonia': 8788, 'COVID': 3616, 'Tuberculosis': 700})

Splitting data into:
  Train set: 83463 images
  Val set:   10433 images
  Test set:  10433 images

Successfully created stratified split at: data_split/
--- Data Splitting Complete ---


In [6]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim/
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torchxrayvision as xrv
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from tqdm import tqdm
from collections import Counter
import warnings

# --- 0. Configuration ---
# This is the folder you just created with split_data.py
SPLIT_DATA_DIR = 'data_split/' 

# Training Hyperparameters
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-4
MODEL_SAVE_PATH = 'best_imbalanced_model.pth'

# ==============================================================================
# STAGE 2: IMBALANCED-AWARE TRAINING
# ==============================================================================

def get_xrv_transforms():
    """
    Get the specific transforms for torchxrayvision models (1-channel).
    """
    # This is the normalization specified by torchxrayvision
    XRV_MEAN = [0.5081]
    XRV_STD = [0.0893]
    
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1), # XRV models expect 1 channel
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=XRV_MEAN, std=XRV_STD)
    ])
    
    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=XRV_MEAN, std=XRV_STD)
    ])
    
    return train_transform, val_test_transform

class XRVTransferModel(nn.Module):
    """
    A wrapper for the torchxrayvision model to add a custom classifier head.
    """
    def __init__(self, num_classes):
        super(XRVTransferModel, self).__init__()
        # Load pre-trained backbone
        # This model is pre-trained on chest x-rays, which is perfect
        model = xrv.models.DenseNet(weights="densenet121-res224-all")
        self.backbone = model.features
        
        # --- Freeze the backbone ---
        # We will only train the final classifier layer
        for name, param in self.backbone.named_parameters():
              if 'denseblock4' not in name:
                      param.requires_grad = False

             
        # Add a new classifier head
        # DenseNet-121 output is 1024 features
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5), # Regularization
            nn.Linear(1024, num_classes) # Output layer
        )

    def forward(self, x):
        features = self.backbone(x)
        pooled = self.pooling(features).view(features.size(0), -1)
        output = self.classifier(pooled)
        return output

def create_weighted_sampler(dataset):
    """
    Creates a WeightedRandomSampler to handle class imbalance.
    """
    print("\nCalculating sampler weights for training set...")
    class_counts = Counter(dataset.targets)
    
    # Sort counts by class index (0, 1, 2, 3...)
    class_counts = [class_counts.get(i, 0) for i in range(len(dataset.classes))]
    print(f"  Class counts: {class_counts}")
    
    # Calculate weight per class (1 / num_samples)
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    print(f"  Class weights: {class_weights}")
    
    # Assign a weight to every single sample in the dataset
    sample_weights = [class_weights[label] for label in dataset.targets]
    
    # Create the sampler
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True # Oversampling: draw samples with replacement
    )
    print("WeightedRandomSampler created.")
    return sampler

def run_training():
    """
    Main function to run the entire training and evaluation pipeline.
    """
    print("--- Starting Imbalanced-Aware Training ---")
    
    # --- 1. Load Datasets ---
    train_transform, val_test_transform = get_xrv_transforms()
    
    train_dataset = ImageFolder(os.path.join(SPLIT_DATA_DIR, 'train'), transform=train_transform)
    val_dataset = ImageFolder(os.path.join(SPLIT_DATA_DIR, 'val'), transform=val_test_transform)
    test_dataset = ImageFolder(os.path.join(SPLIT_DATA_DIR, 'test'), transform=val_test_transform)
    
    # Get class names
    class_names = train_dataset.classes
    num_classes = len(class_names)
    print(f"\nTraining on {num_classes} classes: {class_names}")

    # --- 2. Create DataLoaders (with Sampler for train) ---
    train_sampler = create_weighted_sampler(train_dataset)
    
    # Note: shuffle=False because the sampler handles shuffling.
    train_loader = DataLoader(
        train_dataset, 
        batch_size=BATCH_SIZE, 
        sampler=train_sampler, 
        num_workers=2
    )
    
    # Val and Test loaders should NOT be balanced. We want to test on the
    # real, imbalanced distribution.
    val_loader = DataLoader(
        val_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=False, 
        num_workers=2
    )
    test_loader = DataLoader(
        test_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=False, 
        num_workers=2
    )

    # --- 3. Initialize Model, Loss, Optimizer ---
    model = XRVTransferModel(num_classes=num_classes).to(DEVICE)
    
    # We could also use a weighted loss, but the sampler is often enough.
    # If performance is still bad, we can add weighted loss as well.
    criterion = nn.CrossEntropyLoss()
    
    # We are only training the head, as we froze the backbone
    optimizer = optim.Adam(
       filter(lambda p: p.requires_grad, model.parameters()), 
       lr=LEARNING_RATE
)

    # --- 4. Training & Validation Loop ---
    best_val_f1 = 0.0
    
    # Suppress zero-division warnings from sklearn
    warnings.filterwarnings('ignore', category=UserWarning, message='F-score is ill-defined')

    print(f"\nStarting training for {EPOCHS} epochs...")
    for epoch in range(EPOCHS):
        # --- Training ---
        model.train()
        train_loss = 0.0
        
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        avg_train_loss = train_loss / len(train_loader)

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_targets = []
        
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(labels.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        
        # CRITICAL: Use a weighted F1-score, not accuracy
        # 'average='weighted'' accounts for class imbalance in the F1 score
        val_f1_weighted = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        
        print(f"\nEpoch {epoch+1}/{EPOCHS}:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss:   {avg_val_loss:.4f}")
        print(f"  Val F1 (Weighted): {val_f1_weighted:.4f}")
        
        # Save the best model based on F1 score
        if val_f1_weighted > best_val_f1:
            best_val_f1 = val_f1_weighted
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"  New best model saved to {MODEL_SAVE_PATH} (F1: {best_val_f1:.4f})")
            
    print("\nTraining complete.")

    # --- 5. Final Evaluation on Test Set ---
    print("\n--- FINAL EVALUATION ON TEST SET ---")
    
    # Load the *best* model we saved during training
    try:
        model.load_state_dict(torch.load(MODEL_SAVE_PATH))
        print(f"Loaded best model from {MODEL_SAVE_PATH}")
    except FileNotFoundError:
        print("Warning: No best model was saved. Evaluating last epoch model.")
        
    model.eval()
    
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    # Print the final report
    print("\nClassification Report (Test Set):")
    # This report is the most important output.
    # It shows the precision, recall, and f1-score for EACH class.
    print(classification_report(all_targets, all_preds, target_names=class_names, zero_division=0))
    
    print("\nConfusion Matrix (Test Set):")
    # Rows are (True Label), Columns are (Predicted Label)
    print(confusion_matrix(all_targets, all_preds))



if __name__ == "__main__":
    run_training()






--- Starting Imbalanced-Aware Training ---

Training on 4 classes: ['COVID', 'Normal', 'Pneumonia', 'Tuberculosis']

Calculating sampler weights for training set...
  Class counts: [2893, 72980, 7030, 560]
  Class weights: tensor([3.4566e-04, 1.3702e-05, 1.4225e-04, 1.7857e-03])
WeightedRandomSampler created.

Starting training for 10 epochs...


Epoch 1/10 [Train]: 100%|██████████| 2609/2609 [1:42:28<00:00,  2.36s/it]
Epoch 1/10 [Val]: 100%|██████████| 327/327 [10:14<00:00,  1.88s/it]



Epoch 1/10:
  Train Loss: 0.5707
  Val Loss:   0.4799
  Val F1 (Weighted): 0.8150
  New best model saved to best_imbalanced_model.pth (F1: 0.8150)


Epoch 2/10 [Train]: 100%|██████████| 2609/2609 [1:38:01<00:00,  2.25s/it]
Epoch 2/10 [Val]: 100%|██████████| 327/327 [10:19<00:00,  1.89s/it]



Epoch 2/10:
  Train Loss: 0.3428
  Val Loss:   0.4237
  Val F1 (Weighted): 0.8389
  New best model saved to best_imbalanced_model.pth (F1: 0.8389)


Epoch 3/10 [Train]: 100%|██████████| 2609/2609 [1:37:21<00:00,  2.24s/it]
Epoch 3/10 [Val]: 100%|██████████| 327/327 [10:07<00:00,  1.86s/it]



Epoch 3/10:
  Train Loss: 0.3025
  Val Loss:   0.4596
  Val F1 (Weighted): 0.8279


Epoch 4/10 [Train]: 100%|██████████| 2609/2609 [1:36:40<00:00,  2.22s/it]
Epoch 4/10 [Val]: 100%|██████████| 327/327 [10:10<00:00,  1.87s/it]



Epoch 4/10:
  Train Loss: 0.2833
  Val Loss:   0.3894
  Val F1 (Weighted): 0.8497
  New best model saved to best_imbalanced_model.pth (F1: 0.8497)


Epoch 5/10 [Train]: 100%|██████████| 2609/2609 [1:36:42<00:00,  2.22s/it]
Epoch 5/10 [Val]: 100%|██████████| 327/327 [10:00<00:00,  1.84s/it]



Epoch 5/10:
  Train Loss: 0.2699
  Val Loss:   0.3309
  Val F1 (Weighted): 0.8721
  New best model saved to best_imbalanced_model.pth (F1: 0.8721)


Epoch 6/10 [Train]: 100%|██████████| 2609/2609 [1:37:06<00:00,  2.23s/it]
Epoch 6/10 [Val]: 100%|██████████| 327/327 [11:34<00:00,  2.12s/it]



Epoch 6/10:
  Train Loss: 0.2601
  Val Loss:   0.3747
  Val F1 (Weighted): 0.8561


Epoch 7/10 [Train]: 100%|██████████| 2609/2609 [1:41:58<00:00,  2.34s/it]
Epoch 7/10 [Val]: 100%|██████████| 327/327 [10:19<00:00,  1.90s/it]



Epoch 7/10:
  Train Loss: 0.2475
  Val Loss:   0.4198
  Val F1 (Weighted): 0.8436


Epoch 8/10 [Train]: 100%|██████████| 2609/2609 [1:57:26<00:00,  2.70s/it]  
Epoch 8/10 [Val]: 100%|██████████| 327/327 [10:18<00:00,  1.89s/it]



Epoch 8/10:
  Train Loss: 0.2434
  Val Loss:   0.3403
  Val F1 (Weighted): 0.8690


Epoch 9/10 [Train]: 100%|██████████| 2609/2609 [1:59:04<00:00,  2.74s/it]    
Epoch 9/10 [Val]: 100%|██████████| 327/327 [11:37<00:00,  2.13s/it]



Epoch 9/10:
  Train Loss: 0.2383
  Val Loss:   0.4271
  Val F1 (Weighted): 0.8402


Epoch 10/10 [Train]: 100%|██████████| 2609/2609 [2:43:56<00:00,  3.77s/it]  
Epoch 10/10 [Val]: 100%|██████████| 327/327 [16:19<00:00,  3.00s/it]



Epoch 10/10:
  Train Loss: 0.2333
  Val Loss:   0.3539
  Val F1 (Weighted): 0.8652

Training complete.

--- FINAL EVALUATION ON TEST SET ---
Loaded best model from best_imbalanced_model.pth


Testing: 100%|██████████| 327/327 [16:22<00:00,  3.00s/it]



Classification Report (Test Set):
              precision    recall  f1-score   support

       COVID       0.84      0.91      0.88       362
      Normal       0.97      0.85      0.91      9122
   Pneumonia       0.35      0.78      0.48       879
Tuberculosis       0.88      0.96      0.92        70

    accuracy                           0.85     10433
   macro avg       0.76      0.88      0.80     10433
weighted avg       0.92      0.85      0.87     10433


Confusion Matrix (Test Set):
[[ 331   21    7    3]
 [  56 7797 1265    4]
 [   4  190  683    2]
 [   2    0    1   67]]
