In [14]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import os
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict
from torch import nn, optim
from torchvision import datasets, transforms, utils, models
import torchvision
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F

In [15]:
# Use the MPS backend when this PyTorch build reports it as available;
# otherwise fall back to the CPU. The trailing bare expression displays
# the selected device as the cell output.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
device

device(type='mps')

In [16]:
import pandas as pd
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Load the metadata CSV. The columns used below are "Image Index",
# "Patient Age" and "Patient Gender".
metadata = pd.read_csv("/Users/yichi/Desktop/datathon/Data_Entry_2017.csv")

# Drop rows with missing age or gender (if any)
metadata = metadata.dropna(subset=["Patient Age", "Patient Gender"])

# Standardize age (zero mean, unit variance).
# NOTE(review): the scaler is fitted on every row of the CSV, i.e. on train
# and test images alike -- confirm that this is intended.
scaler = StandardScaler()
metadata["age_scaled"] = scaler.fit_transform(metadata[["Patient Age"]])

# Encode gender as an integer code. LabelEncoder numbers the distinct values
# in sorted order, so this is Female=0 / Male=1 only if the column holds
# exactly two values that sort that way -- check `label_encoder.classes_`.
label_encoder = LabelEncoder()
metadata["gender_encoded"] = label_encoder.fit_transform(metadata["Patient Gender"])

# Build an {image file name: (age_scaled, gender_encoded)} dictionary for
# fast per-item access. Zipping whole columns avoids constructing one Series
# per row, which is what DataFrame.iterrows() does; `.tolist()` yields plain
# Python floats/ints. If an image name occurs more than once, the last row
# wins.
patient_info = dict(
    zip(
        metadata["Image Index"].tolist(),
        zip(
            metadata["age_scaled"].tolist(),
            metadata["gender_encoded"].tolist(),
        ),
    )
)
print(f"Metadata loaded for {len(patient_info)} images.")


Metadata loaded for 112120 images.


In [17]:
# Custom dataset class with early fusion
from torch.utils.data import Dataset
from PIL import Image
import torch
import os

class ChestXrayDataset(Dataset):
    """Image dataset that pairs every image with age/gender side features.

    Each item is a tuple ``(image, extra_features, label)`` where

    * ``image`` is the file loaded as RGB, passed through ``transform`` when
      one is given;
    * ``extra_features`` is a float32 tensor ``[age, gender]`` looked up by
      image file name;
    * ``label`` is ``labels[idx]`` converted to a float32 tensor.

    Args:
        image_dir: Directory that contains the image files.
        image_list: Sequence of image file names relative to ``image_dir``.
        labels: Per-image labels, indexed in parallel with ``image_list``.
        transform: Optional callable applied to the loaded PIL image.
        metadata_lookup: Optional mapping ``file name -> (age, gender)``.
            When omitted, the module-level ``patient_info`` dictionary is
            consulted at item-access time, so existing callers behave as
            before.
    """

    def __init__(self, image_dir, image_list, labels, transform=None,
                 metadata_lookup=None):
        self.image_dir = image_dir
        self.image_list = image_list
        self.labels = labels
        self.transform = transform
        self.metadata_lookup = metadata_lookup

    def __len__(self):
        """Return the number of images in ``image_list``."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, extra_features, label)``."""
        img_name = self.image_list[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # The context manager releases the file handle deterministically;
        # convert() returns a new image object that outlives the `with` block.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        # Early fusion: fetch age and gender for this image. Resolving the
        # global lazily keeps the class usable when it is defined before
        # `patient_info` exists.
        lookup = self.metadata_lookup
        if lookup is None:
            lookup = patient_info
        # Images without a metadata entry get (0.0, 0) as fallback values.
        age, gender = lookup.get(img_name, (0.0, 0))

        extra_features = torch.tensor([age, gender], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, extra_features, label


In [18]:
class DenseNet121(nn.Module):
    """DenseNet-121 with a sigmoid classification head.

    The architecture is the same as torchvision's standard DenseNet121
    except for the classifier, which is replaced by
    ``Linear(num_ftrs, out_size)`` followed by a sigmoid. ``forward``
    therefore returns one value in [0, 1] per output unit.

    Args:
        out_size: Number of output units of the classifier.
        pretrained: When True (the default, matching the previous
            behaviour) the backbone starts from torchvision's ImageNet
            weights. Pass False to skip that download, e.g. when a full
            checkpoint is loaded right after construction.
    """
    def __init__(self, out_size, pretrained=True):
        super().__init__()
        # `pretrained=True` is the deprecated spelling in torchvision; the
        # explicit weights enum below is what it maps to.
        weights = (
            torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
            if pretrained
            else None
        )
        self.densenet121 = torchvision.models.densenet121(weights=weights)
        num_ftrs = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the backbone and head on ``x`` and return the sigmoid outputs."""
        return self.densenet121(x)

In [19]:
def compute_AUCs(gt, pred):
    """Computes the per-class Area Under the ROC Curve from prediction scores.

    Args:
        gt: PyTorch tensor (any device), shape = [n_samples, n_classes],
          true binary labels.
        pred: PyTorch tensor (any device), shape = [n_samples, n_classes],
          can either be probability estimates of the positive class,
          confidence values, or binary decisions.

    Returns:
        List of AUROCs, one per column of ``gt``.

    Raises:
        ValueError: propagated from ``roc_auc_score``, for example when a
          column of ``gt`` contains only one class.
    """
    gt_np = gt.detach().cpu().numpy()
    pred_np = pred.detach().cpu().numpy()
    # Take the class count from the data itself instead of the N_CLASSES
    # global, which is only defined in a later cell.
    n_classes = gt_np.shape[1]
    return [
        roc_auc_score(gt_np[:, class_idx], pred_np[:, class_idx])
        for class_idx in range(n_classes)
    ]

In [20]:
import torchvision.transforms as transforms

# Locations of the training image list and of the image directory.
TRAIN_LIST = "/Users/yichi/Desktop/datathon/train_list.txt"
IMAGE_DIR = "/Users/yichi/Desktop/datathon/images"

# Preprocessing: resize, take the central 224x224 crop, convert to a tensor
# and normalize each channel with the mean/std values listed below.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
data_transforms = transforms.Compose(preprocessing_steps)

# NOTE(review): `ChestXrayDataSet` (capital "S", built from a list *file*) is
# not the `ChestXrayDataset` class defined in this notebook, whose constructor
# takes (image_dir, image_list, labels, transform). No definition of
# `ChestXrayDataSet` is visible here, so this cell presumably relies on a
# class left in the kernel by a removed cell -- confirm it survives
# Restart & Run All.
train_dataset = ChestXrayDataSet(IMAGE_DIR, TRAIN_LIST, transform=data_transforms)
trainloader = DataLoader(train_dataset, batch_size=64, shuffle=False)

In [10]:
# Pull one batch from the training loader and print its tensor shapes as a
# sanity check.
# NOTE(review): this cell's execution count (10) is lower than that of the
# cells around it, so it was run in an earlier session state. It unpacks two
# values per batch, whereas the `ChestXrayDataset` defined in this notebook
# returns three (image, extra_features, label).
first_batch = next(iter(trainloader))
images, labels = first_batch
print("Images shape:", images.shape)
print("Labels shape:", labels.shape)

Images shape: torch.Size([64, 3, 224, 224])
Labels shape: torch.Size([64, 14])


In [21]:
# Evaluation configuration: checkpoint file, test image directory, test image
# list and batch size.
CKPT_PATH = '/Users/yichi/Desktop/datathon/model.pth.tar'
DATA_DIR = '/Users/yichi/Desktop/datathon/images'
TEST_IMAGE_LIST = '/Users/yichi/Desktop/datathon/test_list.txt'
BATCH_SIZE = 64

# Class names in the order used when reporting per-class AUROC.
CLASS_NAMES = [
    'Atelectasis',
    'Cardiomegaly',
    'Effusion',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pneumonia',
    'Pneumothorax',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Pleural_Thickening',
    'Hernia',
]
# One model output per class name (14).
N_CLASSES = len(CLASS_NAMES)

In [30]:
def _rename_legacy_densenet_keys(state_dict):
    """Return a plain-dict copy of ``state_dict`` with legacy layer names merged.

    A key whose last three dotted components are ``norm|relu|conv``, then
    ``1`` or ``2``, then a parameter name (e.g. ``...norm.1.weight``) is
    rewritten to the merged form (``...norm1.weight``). Every other key is
    copied unchanged, so a state dict that already uses the merged names is
    left intact instead of having a character chopped off.
    """
    renamed = {}
    for key, value in state_dict.items():
        parts = key.split('.')
        if (len(parts) >= 3
                and parts[-2] in ('1', '2')
                and parts[-3] in ('norm', 'relu', 'conv')):
            # Merge "<layer>.<digit>" into "<layer><digit>".
            parts[-3:-1] = [parts[-3] + parts[-2]]
        renamed['.'.join(parts)] = value
    return renamed


model = DenseNet121(N_CLASSES).to(device)
# DataParallel prefixes parameter names with "module."; presumably the
# checkpoint was saved from a DataParallel-wrapped model, so the wrapper
# keeps the key names aligned -- confirm against the checkpoint.
model = torch.nn.DataParallel(model).to(device)

if os.path.isfile(CKPT_PATH):
    print("=> loading checkpoint")
    # Map tensors onto the device selected earlier rather than a hard-coded
    # "mps", so the cell also works after the CPU fallback.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (or pass weights_only=True if the file allows it).
    checkpoint = torch.load(CKPT_PATH, map_location=device)

    # If it's a full checkpoint with 'state_dict', extract it
    if 'state_dict' in checkpoint:
        modelCheckpoint = checkpoint['state_dict']
    else:
        modelCheckpoint = checkpoint

    # Fix legacy key names (only where needed)
    new_state_dict = _rename_legacy_densenet_keys(modelCheckpoint)

    model.load_state_dict(new_state_dict)
    print("=> loaded checkpoint")
else:
    # The model keeps its freshly constructed weights in this case.
    print("=> no checkpoint found")

normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])

# Test-time preprocessing: TenCrop yields ten 224x224 crops per image; each
# crop is converted to a tensor and normalized, and the crops are stacked
# along a new leading dimension.
# NOTE(review): `ChestXrayDataSet` (capital "S") is not defined in the
# visible notebook -- the class defined here is `ChestXrayDataset` with a
# different constructor. Confirm where this name comes from.
test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=TEST_IMAGE_LIST,
                                transform=transforms.Compose([
                                    transforms.Resize(256),
                                    transforms.TenCrop(224),
                                    transforms.Lambda(
                                        lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                                    transforms.Lambda(
                                        lambda crops: torch.stack([normalize(crop) for crop in crops])),
                                ]))
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                         shuffle=False, num_workers=0, pin_memory=True)


=> loading checkpoint


  checkpoint = torch.load(CKPT_PATH, map_location=torch.device("mps"))


=> loaded checkpoint


In [33]:
# gt = torch.FloatTensor()
# gt = gt.to(device)
# pred = torch.FloatTensor()
# pred = pred.to(device)

# # switch to evaluate mode
# model.eval()

# for i, (inp, target) in enumerate(test_loader):
#     target = target.to(device)
#     gt = torch.cat((gt, target), 0)
#     bs, n_crops, c, h, w = inp.size()
#     input_var = torch.autograd.Variable(inp.view(-1, c, h, w).to(device), volatile=True)
#     output = model(input_var)
#     output_mean = output.view(bs, n_crops, -1).mean(1)
#     pred = torch.cat((pred, output_mean.data), 0)

# AUROCs = compute_AUCs(gt, pred)
# AUROC_avg = np.array(AUROCs).mean()
# print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
# for i in range(N_CLASSES):
#     print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))

# Switch to evaluation mode
model.eval()

# Collect per-batch tensors and concatenate once after the loop; calling
# torch.cat inside the loop re-copies everything accumulated so far on every
# iteration.
gt_batches = []
pred_batches = []

with torch.no_grad():
    for inp, target in test_loader:
        gt_batches.append(target.to(device))

        # Each sample arrives as a stack of crops: fold the crop dimension
        # into the batch dimension for the forward pass ...
        bs, n_crops, c, h, w = inp.size()
        output = model(inp.view(-1, c, h, w).to(device))

        # ... then average the predictions over the crops of each sample.
        pred_batches.append(output.view(bs, n_crops, -1).mean(1))

gt = torch.cat(gt_batches, dim=0)
pred = torch.cat(pred_batches, dim=0)

# Evaluate AUROC
AUROCs = compute_AUCs(gt, pred)
AUROC_avg = np.array(AUROCs).mean()

print(f'\n✅ Average AUROC: {AUROC_avg:.3f}')
for i in range(N_CLASSES):
    print(f'AUROC for {CLASS_NAMES[i]}: {AUROCs[i]:.3f}')



✅ Average AUROC: 0.883
AUROC for Atelectasis: 0.861
AUROC for Cardiomegaly: 0.964
AUROC for Effusion: 0.920
AUROC for Infiltration: 0.759
AUROC for Mass: 0.915
AUROC for Nodule: 0.797
AUROC for Pneumonia: 0.769
AUROC for Pneumothorax: 0.927
AUROC for Consolidation: 0.865
AUROC for Edema: 0.948
AUROC for Emphysema: 0.955
AUROC for Fibrosis: 0.869
AUROC for Pleural_Thickening: 0.826
AUROC for Hernia: 0.988
