In [50]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import os
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict
from torch import nn, optim
from torchvision import datasets, transforms, utils, models
import torchvision
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F

In [18]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device

device(type='mps')

In [22]:
class ChestXrayDataSet(Dataset):
    def __init__(self, data_dir, image_list_file, transform=None):
        """
        Args:
            data_dir: path to image directory.
            image_list_file: path to the file containing images
                with corresponding labels.
            transform: optional transform to be applied on a sample.
        """
        image_names = []
        labels = []
        with open(image_list_file, "r") as f:
            for line in f:
                items = line.split()
                image_name= items[0]
                label = items[1:]
                label = [int(i) for i in label]
                image_name = os.path.join(data_dir, image_name)
                image_names.append(image_name)
                labels.append(label)

        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item

        Returns:
            image and its labels
        """
        image_name = self.image_names[index]
        image = Image.open(image_name).convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        return len(self.image_names)

In [26]:
class DenseNet121(nn.Module):
    """Model modified.

    The architecture of our model is the same as standard DenseNet121
    except the classifier layer which has an additional sigmoid function.

    """
    def __init__(self, out_size):
        super(DenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(pretrained=True)
        num_ftrs = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.densenet121(x)
        return x

In [31]:
def compute_AUCs(gt, pred):
    """Computes Area Under the Curve (AUC) from prediction scores.

    Args:
        gt: Pytorch tensor on GPU, shape = [n_samples, n_classes]
          true binary labels.
        pred: Pytorch tensor on GPU, shape = [n_samples, n_classes]
          can either be probability estimates of the positive class,
          confidence values, or binary decisions.

    Returns:
        List of AUROCs of all classes.
    """
    AUROCs = []
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    for i in range(N_CLASSES):
        AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs

In [27]:
import torchvision.transforms as transforms

TRAIN_LIST = "/Users/yichi/Desktop/datathon/train_list.txt"
IMAGE_DIR = "/Users/yichi/Desktop/datathon/images"

data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = ChestXrayDataSet(IMAGE_DIR, TRAIN_LIST, transform=data_transforms)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=False)

In [28]:
images, labels = next(iter(trainloader))
print("Images shape:", images.shape) 
print("Labels shape:", labels.shape)

Images shape: torch.Size([64, 3, 224, 224])
Labels shape: torch.Size([64, 14])


In [29]:
CKPT_PATH = '/Users/yichi/Desktop/datathon/model.pth.tar'
N_CLASSES = 14
CLASS_NAMES = [ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']
DATA_DIR = '/Users/yichi/Desktop/datathon/images'
TEST_IMAGE_LIST = '/Users/yichi/Desktop/datathon/test_list.txt'
BATCH_SIZE = 64

In [42]:
model = DenseNet121(N_CLASSES).to(device)
model = torch.nn.DataParallel(model).to(device)

if os.path.isfile(CKPT_PATH):
    print("=> loading checkpoint")
    checkpoint = torch.load(CKPT_PATH, map_location=torch.device("mps"))
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    print("=> loaded checkpoint")
else:
    print("=> no checkpoint found")

normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                [0.229, 0.224, 0.225])

test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=transforms.Compose([
                                        transforms.Resize(256),
                                        transforms.TenCrop(224),
                                        transforms.Lambda
                                        (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                                        transforms.Lambda
                                        (lambda crops: torch.stack([normalize(crop) for crop in crops]))
                                    ]))
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=0, pin_memory=True)




=> loading checkpoint


  checkpoint = torch.load(CKPT_PATH, map_location=torch.device("mps"))


=> loaded checkpoint


In [45]:
# gt = torch.FloatTensor()
# gt = gt.to(device)
# pred = torch.FloatTensor()
# pred = pred.to(device)

# # switch to evaluate mode
# model.eval()

# for i, (inp, target) in enumerate(test_loader):
#     target = target.to(device)
#     gt = torch.cat((gt, target), 0)
#     bs, n_crops, c, h, w = inp.size()
#     input_var = torch.autograd.Variable(inp.view(-1, c, h, w).to(device), volatile=True)
#     output = model(input_var)
#     output_mean = output.view(bs, n_crops, -1).mean(1)
#     pred = torch.cat((pred, output_mean.data), 0)

# AUROCs = compute_AUCs(gt, pred)
# AUROC_avg = np.array(AUROCs).mean()
# print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
# for i in range(N_CLASSES):
#     print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))

gt = torch.FloatTensor().to(device)
pred = torch.FloatTensor().to(device)

# Switch to evaluation mode
model.eval()

for i, (inp, target) in enumerate(test_loader):
    target = target.to(device)
    gt = torch.cat((gt, target), dim=0)

    bs, n_crops, c, h, w = inp.size()

    with torch.no_grad():
        input_var = inp.view(-1, c, h, w).to(device)
        output = model(input_var)

    output_mean = output.view(bs, n_crops, -1).mean(1)
    pred = torch.cat((pred, output_mean), dim=0)

# Evaluate AUROC
AUROCs = compute_AUCs(gt, pred)
AUROC_avg = np.array(AUROCs).mean()

print(f'\n✅ Average AUROC: {AUROC_avg:.3f}')
for i in range(N_CLASSES):
    print(f'AUROC for {CLASS_NAMES[i]}: {AUROCs[i]:.3f}')



✅ Average AUROC: 0.546
AUROC for Atelectasis: 0.469
AUROC for Cardiomegaly: 0.628
AUROC for Effusion: 0.570
AUROC for Infiltration: 0.621
AUROC for Mass: 0.544
AUROC for Nodule: 0.441
AUROC for Pneumonia: 0.335
AUROC for Pneumothorax: 0.468
AUROC for Consolidation: 0.652
AUROC for Edema: 0.599
AUROC for Emphysema: 0.552
AUROC for Fibrosis: 0.566
AUROC for Pleural_Thickening: 0.471
AUROC for Hernia: 0.728


In [48]:
core_model = model.module
core_model

DenseNet121(
  (densenet121): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, mome

In [49]:
feature_extractor = core_model.densenet121.features
feature_extractor

Sequential(
  (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu0): ReLU(inplace=True)
  (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (denseblock1): _DenseBlock(
    (denselayer1): _DenseLayer(
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (denselayer2): _DenseLayer(
      (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(96, 128, ke

In [51]:
def extract_features_from_image(image_tensor):
    """
    image_tensor: a single image tensor of shape (3, H, W), already normalized
    Returns: feature vector of shape (1, 1024)
    """
    image_tensor = image_tensor.unsqueeze(0).to(device)  # Add batch dim

    with torch.no_grad():
        features = feature_extractor(image_tensor)         # (1, 1024, 7, 7)
        pooled = F.adaptive_avg_pool2d(features, (1, 1))   # (1, 1024, 1, 1)
        flat = pooled.view(pooled.size(0), -1)             # (1, 1024)
    
    return flat 