## Example Baseline CNN for Binary Classification on MURA

### Library Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm import tqdm

### Project Utility Imports

In [2]:
from utils.mura_dataset import MURADataset
from utils.transforms import get_train_transforms, get_val_transforms

### Model Definition

In [3]:
class BaselineCNN(nn.Module):
    """Small two-conv-layer CNN that outputs one probability per image.

    Layout: two 3x3 convolutions (16, then 32 feature maps), each followed by
    ReLU and 2x2 max pooling, then a 128-unit hidden layer and a single
    sigmoid output unit.

    Args:
        in_channels: Number of channels in the input images. Defaults to 1.
        input_size: Spatial size of the input images; either one int for
            square images or a ``(height, width)`` pair. Defaults to 224,
            which yields the original ``32 * 56 * 56`` flattened feature size.

    Raises:
        ValueError: If a spatial dimension is too small to survive the two
            2x2 pooling steps (i.e. smaller than 4).
    """

    def __init__(self, in_channels=1, input_size=224):
        super().__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # The padded 3x3 convs keep the spatial size; each MaxPool2d(2)
        # halves it, rounding down.
        pooled_h = height // 2 // 2
        pooled_w = width // 2 // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for two 2x2 pooling layers"
            )

        self.conv1 = nn.Conv2d(in_channels, 16, 3, padding=1)  # in_channels -> 16 feature maps
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # 16 -> 32 feature maps
        self.pool = nn.MaxPool2d(2)

        # Fully connected head: 32 channels * pooled height * pooled width inputs.
        self.fc1 = nn.Linear(32 * pooled_h * pooled_w, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return sigmoid probabilities of shape ``[B, 1]`` for a ``[B, C, H, W]`` batch."""
        x = self.pool(F.relu(self.conv1(x)))  # [B, 16, H/2, W/2]
        x = self.pool(F.relu(self.conv2(x)))  # [B, 32, H/4, W/4]
        x = torch.flatten(x, 1)  # flatten everything but the batch dim for the fc layers
        x = F.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))  # outputs a probability between 0 and 1

### Function to Build the Train, Validation, and Test Data Loaders

In [4]:
def get_loaders(batch_size=32, num_workers=0, root_dir="data/raw"):
    """Build the train, validation, and test ``DataLoader`` objects.

    Args:
        batch_size: Number of samples per batch for all three loaders.
        num_workers: Worker processes each ``DataLoader`` uses for loading.
            The default of 0 loads in the main process, as before.
        root_dir: Directory passed to ``MURADataset`` as ``root_dir``.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the training
        loader shuffles; the validation and test loaders keep dataset order.
    """

    def _make_dataset(csv_file, transform):
        # All three splits share the same root directory.
        return MURADataset(csv_file=csv_file, transform=transform, root_dir=root_dir)

    def _make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_dataset = _make_dataset(
        "data/splits/train_labeled_studies_split.csv", get_train_transforms()
    )
    val_dataset = _make_dataset(
        "data/splits/val_labeled_studies_split.csv", get_val_transforms()
    )
    # NOTE(review): the test split reads "valid_labeled_studies.csv" -- presumably
    # the dataset's original validation list reused as a held-out test set; confirm.
    test_dataset = _make_dataset(
        "data/splits/valid_labeled_studies.csv", get_val_transforms()
    )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader

### Training Function for One Epoch

In [5]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Args:
        model: Module to train; switched to training mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
            Assumes it returns the batch mean, since the value is re-weighted
            by batch size below -- TODO confirm for non-default reductions.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The loss averaged over the samples actually processed this epoch,
        or ``0.0`` if the loader yielded no samples.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for images, labels in tqdm(loader):
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)  # [B] -> [B, 1]

        optimizer.zero_grad()
        outputs = model(images)  # forward pass
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Divide by the samples actually seen rather than len(loader.dataset):
    # the two differ when the loader drops the last batch or uses a sampler,
    # and the dataset may not define __len__ at all.
    if samples_seen == 0:
        return 0.0
    return running_loss / samples_seen

### Evaluation

In [6]:
def evaluate(model, loader, device, threshold=0.5):
    """Compute accuracy and ROC AUC of ``model`` over ``loader``.

    Args:
        model: Module producing one probability per sample; switched to eval mode here.
        loader: Iterable yielding ``(images, labels)`` batches of tensors.
        device: Device the images are moved to before the forward pass.
        threshold: Probability at or above which a sample is predicted as
            class 1. Defaults to 0.5.

    Returns:
        An ``(accuracy, auc)`` tuple.

    Raises:
        ValueError: If the loader yields no batches. ``roc_auc_score`` may
            also raise if the labels contain only one class.
    """
    model.eval()
    prob_batches, label_batches = [], []

    with torch.no_grad():
        for images, labels in loader:
            probs = model(images.to(device))
            # Flatten both so [B] and [B, 1] shaped tensors are handled alike.
            prob_batches.append(probs.detach().cpu().reshape(-1))
            label_batches.append(labels.detach().cpu().reshape(-1))

    if not prob_batches:
        raise ValueError("evaluate() received a loader that yielded no batches")

    all_probs = torch.cat(prob_batches).numpy()
    all_labels = torch.cat(label_batches).numpy()

    preds = (all_probs >= threshold).astype(int)  # binary predictions at the chosen threshold
    acc = accuracy_score(all_labels, preds)
    auc = roc_auc_score(all_labels, all_probs)
    return acc, auc

In [7]:
def main(num_epochs=5, lr=1e-4, batch_size=32):
    """Train the baseline CNN, report validation metrics per epoch, then test.

    Args:
        num_epochs: Number of training epochs. Defaults to 5.
        lr: Learning rate for the Adam optimizer. Defaults to 1e-4.
        batch_size: Batch size passed to ``get_loaders``. Defaults to 32.
    """
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, test_loader = get_loaders(batch_size=batch_size)

    model = BaselineCNN().to(device)
    criterion = nn.BCELoss()  # the model already applies a sigmoid, so plain BCE is used
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_acc, val_auc = evaluate(model, val_loader, device)

        print(f"Epoch {epoch+1}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Acc:    {val_acc:.4f}, AUC: {val_auc:.4f}")

    # Final evaluation on the test loader, run once after training finishes.
    test_acc, test_auc = evaluate(model, test_loader, device)
    print(f"\nTest Accuracy: {test_acc:.4f}, AUC: {test_auc:.4f}")

# Run the training/evaluation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()


100%|█████████████████████████████████████████| 921/921 [13:28<00:00,  1.14it/s]


Epoch 1
Train Loss: 0.6631
Val Acc:    0.6131, AUC: 0.6230


100%|█████████████████████████████████████████| 921/921 [17:50<00:00,  1.16s/it]


Epoch 2
Train Loss: 0.6432
Val Acc:    0.6188, AUC: 0.6457


100%|█████████████████████████████████████████| 921/921 [11:02<00:00,  1.39it/s]


Epoch 3
Train Loss: 0.6317
Val Acc:    0.6215, AUC: 0.6564


100%|█████████████████████████████████████████| 921/921 [11:45<00:00,  1.31it/s]


Epoch 4
Train Loss: 0.6221
Val Acc:    0.6347, AUC: 0.6655


100%|█████████████████████████████████████████| 921/921 [18:25<00:00,  1.20s/it]


Epoch 5
Train Loss: 0.6132
Val Acc:    0.6414, AUC: 0.6879

Test Accuracy: 0.6084, AUC: 0.6847
