In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import os
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict
from torch import nn, optim
from torchvision import datasets, transforms, utils, models
import torchvision
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
from torchvision.models import densenet121

In [2]:
# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Seed torch's global RNG so DataLoader shuffling and freshly initialised layers are
# repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device

device(type='mps')

In [3]:
# Per-image patient metadata (path is machine-specific).
metadata = pd.read_csv("/Users/yichi/Desktop/datathon/Data_Entry_2017.csv")

# Drop rows without age/gender, then encode gender as F=0, M=1. A gender value other than
# "F"/"M" would map to NaN, so such rows are dropped too instead of leaking NaN into the
# auxiliary tensor. Built as a chain so re-running the cell never sees a half-mutated frame.
metadata = (
    metadata
    .dropna(subset=["Patient Age", "Patient Gender"])
    .assign(gender_encoded=lambda df: df["Patient Gender"].map({"F": 0, "M": 1}))
    .dropna(subset=["gender_encoded"])
    .astype({"gender_encoded": int})
)

# z-score the age column.
# NOTE(review): mean/std are computed over every row of the CSV, presumably including
# images from the test split, so a little test information leaks into the scaling.
age_mean = metadata["Patient Age"].mean()
age_std = metadata["Patient Age"].std()
metadata = metadata.assign(age_scaled=(metadata["Patient Age"] - age_mean) / age_std)

# The lookup below is keyed by image name; duplicates would silently overwrite each other.
assert metadata["Image Index"].is_unique, "duplicate 'Image Index' values in metadata"

# image name -> [scaled age, encoded gender]; zip over columns avoids slow row-wise iterrows.
age_gender_dict = {
    name: [age, gender]
    for name, age, gender in zip(
        metadata["Image Index"], metadata["age_scaled"], metadata["gender_encoded"]
    )
}

In [4]:
class ChestXrayDataSet(Dataset):
    """Chest X-ray images paired with integer label vectors read from a list file.

    Each non-blank line of ``image_list_file`` is whitespace-separated:
    ``<image_name> <label_0> <label_1> ...`` with integer labels.
    Items are ``(image, image_name, label_tensor)`` where ``label_tensor`` is float32.

    Args:
        data_dir: Directory joined with each image name to locate the image file.
        image_list_file: Path to the list file described above.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it is returned.
    """

    def __init__(self, data_dir, image_list_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_names = []
        self.labels = []

        with open(image_list_file, "r") as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank / whitespace-only lines (e.g. trailing empty lines) used to raise IndexError.
                    continue
                self.image_names.append(parts[0])
                self.labels.append([int(x) for x in parts[1:]])

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        img_path = os.path.join(self.data_dir, img_name)
        # The context manager closes the file handle; convert() returns a new, loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, img_name, label

    def __len__(self):
        return len(self.image_names)

In [5]:
class ChestXrayWithAux(Dataset):
    """Chest X-ray images with an auxiliary feature vector and integer label vectors.

    Each non-blank line of ``image_list_file`` is whitespace-separated:
    ``<image_name> <label_0> <label_1> ...`` with integer labels.
    Items are ``(image, aux_tensor, label_tensor)``, both tensors float32.

    Args:
        data_dir: Directory joined with each image name to locate the image file.
        image_list_file: Path to the list file described above.
        age_gender_dict: Mapping ``image_name -> [aux values]`` (in this notebook,
            [scaled age, encoded gender]).
        transform: Optional callable applied to the RGB ``PIL.Image`` before it is returned.
    """

    def __init__(self, data_dir, image_list_file, age_gender_dict, transform=None):
        self.data_dir = data_dir
        self.age_gender_dict = age_gender_dict
        self.transform = transform

        self.image_names = []
        self.labels = []

        with open(image_list_file, "r") as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank / whitespace-only lines used to raise IndexError on parts[0].
                    continue
                self.image_names.append(parts[0])
                self.labels.append([int(x) for x in parts[1:]])

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        img_path = os.path.join(self.data_dir, img_name)

        # The context manager closes the file handle; convert() returns a new, loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Images missing from the lookup fall back to [0.0, 0] (0.0 is the mean if ages were
        # z-scored). NOTE(review): gender code 0 is also a real category, so missing rows are
        # indistinguishable from that category.
        aux = self.age_gender_dict.get(img_name, [0.0, 0])
        aux_tensor = torch.tensor(aux, dtype=torch.float32)

        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, aux_tensor, label

    def __len__(self):
        return len(self.image_names)

In [6]:
class LatentFusionModel(nn.Module):
    """DenseNet image encoder fused with an MLP embedding of auxiliary features.

    The DenseNet ``features`` trunk is shared with ``base_densenet_model`` (same module
    object), so its weights are trained whenever this model's parameters are optimised.
    The globally pooled image features are concatenated with an embedding of the
    auxiliary vector and passed through a small MLP that emits one logit per class.

    Args:
        base_densenet_model: A torchvision-style DenseNet exposing ``.features`` and a
            linear ``.classifier`` whose ``in_features`` is the encoder output width.
        aux_input_dim: Length of each auxiliary feature vector.
        aux_hidden_dim: Width of the auxiliary embedding.
        num_classes: Number of output logits.
    """

    def __init__(self, base_densenet_model, aux_input_dim=2, aux_hidden_dim=64, num_classes=14):
        super(LatentFusionModel, self).__init__()

        # Pretrained DenseNet (CheXNet) trunk used as the image encoder.
        self.features = base_densenet_model.features
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.cnn_output_dim = base_densenet_model.classifier.in_features  # 1024 for densenet121

        # MLP embedding for the auxiliary vector.
        self.aux_net = nn.Sequential(
            nn.Linear(aux_input_dim, aux_hidden_dim),
            nn.ReLU(),
            nn.Linear(aux_hidden_dim, aux_hidden_dim),
            nn.ReLU()
        )

        # Classifier head over the concatenated [image, aux] representation.
        self.classifier = nn.Sequential(
            nn.Linear(self.cnn_output_dim + aux_hidden_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x, aux_features):
        """Return raw logits of shape (batch, num_classes).

        Args:
            x: Image batch for the DenseNet trunk (assumed (B, 3, H, W) — TODO confirm).
            aux_features: Auxiliary batch of shape (B, aux_input_dim).
        """
        feats = self.features(x)
        # torchvision's DenseNet.forward applies a ReLU after the trunk's final norm layer
        # before pooling; do the same so pooled features match the pretrained encoder.
        # Not in-place: the trunk's output tensor may be referenced elsewhere.
        feats = F.relu(feats)
        pooled = torch.flatten(self.global_pool(feats), 1)

        aux_emb = self.aux_net(aux_features)
        fused = torch.cat([pooled, aux_emb], dim=1)
        return self.classifier(fused)

In [7]:
TRAIN_LIST = "/Users/yichi/Desktop/datathon/train_list.txt"
IMAGE_DIR = "/Users/yichi/Desktop/datathon/images"

# Resize + centre-crop to 224x224, then normalise with the standard ImageNet channel statistics.
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = ChestXrayWithAux(
    data_dir=IMAGE_DIR,
    image_list_file=TRAIN_LIST,
    age_gender_dict=age_gender_dict,
    transform=data_transforms
)

# Pinned host memory only helps host->CUDA copies; on MPS/CPU PyTorch warns and ignores it.
trainloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=(device.type == "cuda"),
)

In [8]:
def normalize_chexnet_key(key: str) -> str:
    """Map a CheXNet checkpoint key onto torchvision ``densenet121`` state-dict names.

    - Strips a leading ``module.densenet121.`` or ``module.`` prefix (wrapper / DataParallel
      prefixes in the saved checkpoint).
    - Maps ``classifier.0.*`` to ``classifier.*``.
    - Collapses legacy names such as ``...norm.1.weight`` / ``...conv.2.weight`` to
      ``...norm1.weight`` / ``...conv2.weight``.

    Keys already in torchvision format are returned unchanged.
    """
    for prefix in ("module.densenet121.", "module."):
        if key.startswith(prefix):
            key = key[len(prefix):]
            break

    if key.startswith("classifier.0."):
        key = "classifier." + key[len("classifier.0."):]

    head, sep, param = key.rpartition(".")
    # Only collapse a standalone "1"/"2" path component. The previous check looked only at
    # the character before the last dot, which also rewrote already-correct keys such as
    # "denselayer12.norm1.weight" into "denselayer12.nor1.weight".
    if sep and len(head) >= 2 and head[-1] in ("1", "2") and head[-2] == ".":
        key = head[:-2] + head[-1] + sep + param
    return key


# weights=None replaces the deprecated pretrained=False; ImageNet weights are not needed
# because the CheXNet checkpoint below overwrites them.
model = densenet121(weights=None)
model.classifier = nn.Linear(1024, 14)

CKPT_PATH = '/Users/yichi/Desktop/datathon/datathon_team6/model.pth.tar'

if os.path.isfile(CKPT_PATH):
    print("=> loading checkpoint")
    # Load to CPU: `model` was just built on CPU, and a hard-coded "mps" map_location
    # fails on machines without MPS.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(CKPT_PATH, map_location="cpu")

    modelCheckpoint = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    new_state_dict = {normalize_chexnet_key(k): v for k, v in modelCheckpoint.items()}

    model.load_state_dict(new_state_dict)
    print("=> loaded checkpoint")
else:
    print("=> no checkpoint found")

  checkpoint = torch.load(CKPT_PATH, map_location=torch.device("mps"))


=> loading checkpoint
=> loaded checkpoint


In [None]:
fusion_model = LatentFusionModel(base_densenet_model=model).to(device)
# nn.DataParallel only splits work across multiple CUDA GPUs; elsewhere it just adds a
# "module." prefix to state-dict keys, so wrap only when it actually helps.
if device.type == "cuda" and torch.cuda.device_count() > 1:
    fusion_model = nn.DataParallel(fusion_model)

# Labels are per-class 0/1 columns, scored independently (sigmoid + per-class AUC below).
# Multi-label targets need a per-class binary loss; CrossEntropyLoss would treat each row
# as one softmax distribution, which does not fit multi-hot or all-zero rows.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(fusion_model.parameters(), lr=1e-4)

num_epochs = 5
EVAL_TRAIN_AUC = False  # full pass over the training set each epoch; slow, so off by default


def evaluate_mean_auc(net, loader):
    """Return the mean per-class ROC AUC of `net` over `loader`, ignoring undefined classes.

    A class whose targets contain only one value has no defined AUC; it is recorded as NaN
    and skipped by the NaN-aware mean. Leaves `net` in eval mode.
    """
    net.eval()
    preds_chunks, target_chunks = [], []
    with torch.no_grad():
        for images, aux_features, labels in loader:
            outputs = net(images.to(device), aux_features.to(device))
            preds_chunks.append(torch.sigmoid(outputs).cpu())
            target_chunks.append(labels.cpu())

    all_preds = torch.cat(preds_chunks).numpy()
    all_targets = torch.cat(target_chunks).numpy()

    aucs = []
    for i in range(all_targets.shape[1]):
        try:
            aucs.append(roc_auc_score(all_targets[:, i], all_preds[:, i]))
        except ValueError:
            aucs.append(float('nan'))
    return float(np.nanmean(aucs))


for epoch in range(num_epochs):
    fusion_model.train()
    running_loss = 0.0

    for images, aux_features, labels in trainloader:
        images = images.to(device)
        aux_features = aux_features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = fusion_model(images, aux_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # max(..., 1) guards against an empty loader.
    avg_loss = running_loss / max(len(trainloader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    if EVAL_TRAIN_AUC:
        mean_auc = evaluate_mean_auc(fusion_model, trainloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Train AUC: {mean_auc:.4f}")

In [None]:
TEST_LIST = "/Users/yichi/Desktop/datathon/test_list.txt"

# Same preprocessing as training; the missing ")" here was a SyntaxError.
test_dataset = ChestXrayWithAux(
    data_dir=IMAGE_DIR,
    image_list_file=TEST_LIST,
    age_gender_dict=age_gender_dict,
    transform=data_transforms,
)

# shuffle=False keeps predictions in the same order as test_dataset.image_names.
testloader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
)