In [34]:
import torch
import pandas as pd
import os
import glob
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import densenet121
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np


In [58]:
device = torch.device("cuda:0")
print(device)

cuda:0


In [53]:
class PneumoniaDataset(Dataset):
    def __init__(self,data_dir, label_file, transform=None):

        self.data_dir = data_dir
        self.transform = transform
        self.labels_df = pd.read_excel(label_file)
        #if label contains Pneumothorax = 1 else 0 
        self.labels_df['Pneumothorax'] = self.labels_df['Finding Labels'].apply(lambda x: 1 if 'Pneumothorax' in x.split('|') else 0)
        self.image_paths = {os.path.basename(x): x for x in glob.glob(os.path.join(data_dir, '*', 'images', '*.png'))}
        self.labels_df['path'] = self.labels_df['Image Index'].map(self.image_paths.get)
        self.labels_df.dropna(subset=['path'], inplace=True)

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        img_name = self.labels_df.iloc[idx]['path']
        image = Image.open(img_name).convert('RGB')
        label = self.labels_df.iloc[idx]['Pneumothorax']
        if self.transform:
            image = self.transform(image)
        return image, label

def get_transforms():  
#resize the image size to 224 x 224 pixels
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform

def get_data_loaders(data_dir, label_file, batch_size=32, val_split=0.2,  test_split=0.1):

    transform = get_transforms()
    dataset = PneumoniaDataset(data_dir=data_dir, label_file=label_file, transform=transform)
    
    # Calculate split sizes
    val_size = int(val_split * len(dataset))
    test_size = int(test_split * len(dataset))
    train_size = len(dataset) - val_size - test_size
    
    # Split the dataset into training, validation, and test sets
    train_dataset, val_test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size + test_size])
    val_dataset, test_dataset = torch.utils.data.random_split(val_test_dataset, [val_size, test_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_dataset, val_dataset,train_loader, val_loader,test_dataset,test_loader




In [57]:

train_dataset, val_dataset,train_loader, val_loader,test_dataset, test_loader= get_data_loaders(data_dir='../../raw_data/archive/', label_file='../../raw_data/archive/CXR8-selected/micro.xlsx')
print(f"Training Dataset Size: {len(train_dataset)}")
print(f"Validation Dataset Size: {len(val_dataset)}")
print(f"Test Dataset Size: {len(test_dataset)}")

def count_batches(data_loader):
    return sum(1 for _ in data_loader)

print(f"Training DataLoader Batches: {count_batches(train_loader)}")
print(f"Validation DataLoader Batches: {count_batches(val_loader)}")
print(f"Test DataLoader Batches: {count_batches(test_loader)}")

def inspect_loader(data_loader):
    images, labels = next(iter(data_loader))
    print(f"Batch Images Shape: {images.shape}")
    print(f"Batch Labels Shape: {labels.shape}")
    print(f"Batch Labels: {labels}")

# Inspect each DataLoader
print("Training DataLoader:")
inspect_loader(train_loader)

print("\nValidation DataLoader:")
inspect_loader(val_loader)

print("\nTest DataLoader:")
inspect_loader(test_loader)

Training Dataset Size: 3501
Validation Dataset Size: 999
Test Dataset Size: 499
Training DataLoader Batches: 110
Validation DataLoader Batches: 32
Test DataLoader Batches: 16
Training DataLoader:
Batch Images Shape: torch.Size([32, 3, 224, 224])
Batch Labels Shape: torch.Size([32])
Batch Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

Validation DataLoader:
Batch Images Shape: torch.Size([32, 3, 224, 224])
Batch Labels Shape: torch.Size([32])
Batch Labels: tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

Test DataLoader:
Batch Images Shape: torch.Size([32, 3, 224, 224])
Batch Labels Shape: torch.Size([32])
Batch Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])


In [49]:
#DensenNet121
model = densenet121(weights='DenseNet121_Weights.DEFAULT')
num_ftrs = model.classifier.in_features
model.classifier = nn.Sequential(
    nn.Linear(num_ftrs, 500),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(500, 1),
    nn.Sigmoid())

In [51]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCELoss()

def train_model(model, criterion, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        model.train()  # Set model to training mode
        
        total_loss = 0.0
        total_correct = 0
        
        for inputs, labels in train_loader:
            # Clear previous gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(inputs)
            # Calculate loss
            loss = criterion(outputs, labels.unsqueeze(1).float())
            
            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * inputs.size(0)
            preds = torch.round(outputs)
            total_correct += torch.sum(preds == labels.unsqueeze(1).data)

        epoch_loss = total_loss / len(train_dataset)
        epoch_acc = total_correct.double() / len(train_dataset)

        print(f'Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss}, Acc: {epoch_acc}')

train_model(model, criterion, optimizer, num_epochs=10)


Epoch 1/10
----------
Epoch 0/9, Loss: 0.07990026066329259, Acc: 0.9885746929448729
Epoch 2/10
----------
Epoch 1/9, Loss: 0.043374007464017025, Acc: 0.9885746929448729
Epoch 3/10
----------
Epoch 2/9, Loss: 0.023513220813895754, Acc: 0.9885746929448729
Epoch 4/10
----------
Epoch 3/9, Loss: 0.007953736617718551, Acc: 0.9985718366181091
Epoch 5/10
----------
Epoch 4/9, Loss: 0.022046017459793493, Acc: 0.9925735504141674
Epoch 6/10
----------
Epoch 5/9, Loss: 0.013139387878574169, Acc: 0.9962867752070836
Epoch 7/10
----------
Epoch 6/9, Loss: 0.0018234844873002253, Acc: 0.9994287346472437
Epoch 8/10
----------
Epoch 7/9, Loss: 0.0006018527172530728, Acc: 1.0
Epoch 9/10
----------
Epoch 8/9, Loss: 0.0002889940288236307, Acc: 1.0
Epoch 10/10
----------
Epoch 9/9, Loss: 0.00014511467454472095, Acc: 1.0


In [60]:
def test_model(model, data_loader):
    model.eval()  # Set model to evaluation mode
    true_labels = []
    pred_labels = []
    
    # No need to track gradients for testing
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            pred_labels.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    # Calculate metrics
    accuracy = accuracy_score(true_labels, pred_labels)
    precision = precision_score(true_labels, pred_labels)
    recall = recall_score(true_labels, pred_labels)
    f1 = f1_score(true_labels, pred_labels)

    
    print(f'Test Accuracy: {accuracy * 100:.2f}%')
    print(f'Precision: {precision:.2f}')
    print(f'Recall: {recall:.2f}')
    print(f'F1 Score: {f1:.2f}')
    
    return accuracy, precision, recall, f1

# Call the evaluate function
test_model(model, test_loader)

Test Accuracy: 98.40%
Precision: 0.00
Recall: 0.00
F1 Score: 0.00


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


(0.9839679358717435, 0.0, 0.0, 0.0)