# Binary Classification of Chest X-ray Images (Apple Silicon Optimized)

This notebook supports both TensorFlow/Keras and PyTorch training. Set the `PYTORCH_TRAINING` flag to choose the backend.

- Dataset: ChestX-ray14 (CXR8)
- Binary classification: Normal vs Abnormal


In [None]:
# --- Flags ---
PYTORCH_TRAINING = True  # Set to True to use PyTorch instead of Keras


In [14]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, f1_score
import warnings
np.random.seed(42)
warnings.filterwarnings("ignore", message=".*iCCP: profile 'ICC Profile': 'GRAY': Gray color space not permitted on RGB PNG*")


In [15]:
# --- Data Loading and Exploration ---
data_path = '../datasets/CXR8/'
labels_df = pd.read_csv(os.path.join(data_path, 'Data_Entry_2017_v2020.csv'))
labels_df['binary_label'] = (labels_df['Finding Labels'] != 'No Finding').astype(int)

# Create a balanced dataset
normal_samples = labels_df[labels_df['binary_label'] == 0]
abnormal_samples = labels_df[labels_df['binary_label'] == 1]
min_samples = min(len(normal_samples), len(abnormal_samples))
balanced_normal = normal_samples.sample(n=min_samples, random_state=42)
balanced_abnormal = abnormal_samples.sample(n=min_samples, random_state=42)
balanced_df = pd.concat([balanced_normal, balanced_abnormal]).reset_index(drop=True)

# Split the data
train_df, temp_df = train_test_split(balanced_df, test_size=0.3, random_state=42, stratify=balanced_df['binary_label'])
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df['binary_label'])

image_dir = os.path.join(data_path, 'images/extracted_images')


## --- PyTorch Training ---

This block runs if `PYTORCH_TRAINING = True`.


In [16]:
if PYTORCH_TRAINING:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    from tqdm import tqdm

    class CXRDataset(Dataset):
        def __init__(self, df, img_dir, transform=None):
            self.df = df.reset_index(drop=True)
            self.img_dir = img_dir
            self.transform = transform
        def __len__(self):
            return len(self.df)
        def __getitem__(self, idx):
            img_name = self.df.loc[idx, 'Image Index']
            label = self.df.loc[idx, 'binary_label']
            img_path = os.path.join(self.img_dir, img_name)
            image = Image.open(img_path).convert('L')
            if self.transform:
                image = self.transform(image)
            return image, torch.tensor(label, dtype=torch.float32)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    train_dataset = CXRDataset(train_df, image_dir, transform)
    val_dataset = CXRDataset(val_df, image_dir, transform)
    test_dataset = CXRDataset(test_df, image_dir, transform)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    class SimpleCNN(nn.Module):
        def __init__(self):
            super(SimpleCNN, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
            self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
            self.fc1 = nn.Linear(64 * 28 * 28, 64)
            self.dropout1 = nn.Dropout(0.5)
            self.dropout2 = nn.Dropout(0.3)
            self.fc2 = nn.Linear(64, 1)
        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = self.pool(torch.relu(self.conv2(x)))
            x = self.pool(torch.relu(self.conv3(x)))
            x = x.view(x.size(0), -1)
            x = torch.relu(self.fc1(x))
            x = self.dropout1(x)
            x = self.dropout2(x)
            x = torch.sigmoid(self.fc2(x))
            return x.squeeze(1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleCNN().to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    best_val_loss = float('inf')
    epochs = 20
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} - Training'):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} - Validation'):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
        val_loss /= len(val_loader.dataset)
        print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}')
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print('Saved best model to best_model.pth')
    print('PyTorch training complete.')

    # Test evaluation (optional)
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f'Test Accuracy: {correct/total:.4f}')


Epoch 1/20 - Training:   0%|          | 0/2265 [00:00<?, ?it/s]

Epoch 1/20 - Training:  39%|███▊      | 874/2265 [12:54<20:32,  1.13it/s]


KeyboardInterrupt: 