In [5]:
import torch
import torchvision
from torchvision.models import resnet50

In [6]:
# Add the parent directory to the import path so the local `helpers` package
# can be imported. '../' is resolved against the kernel's current working
# directory, so the notebook must be started from its own folder.
# Guarded so re-running this cell does not stack duplicate sys.path entries.
import sys

if '../' not in sys.path:
    sys.path.insert(1, '../')

## Load RadImageNet Model

In [None]:
from helpers import radimagenet

# Backbone weights are loaded below; the single-output classifier head
# (num_class=1) is left at its fresh initialisation - no weights are loaded
# for it in this cell.
backbone = radimagenet.RadImageNetBackbone()
classifier = radimagenet.RadImageNetClassifier(num_class=1)

# map_location='cpu' (matching the RGB checkpoint load further down) lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine;
# load_state_dict then copies the tensors into the backbone's parameters.
backbone.load_state_dict(
    torch.load("../models/radimagenet_resnet50.pt", map_location='cpu')
)

In [20]:
# Full RadImageNet model: weight-loaded backbone followed by the classifier
# head. Uses the fully qualified torch.nn (no bare `nn` alias is imported).
radimagenet_model = torch.nn.Sequential(backbone, classifier)

## Load ImageNet Models

In [11]:
# The checkpoint is a dict whose "model" entry holds the ResNet-50 state dict.
# weights_only=False fully unpickles the file, which can execute arbitrary
# code - only use it with checkpoints from a trusted source.
rgb_weights = torch.load("../models/rgb_3c_model_89.pth", map_location='cpu', weights_only=False)
rgb_model = torchvision.models.get_model("resnet50", weights=None, num_classes=1000)
rgb_model.load_state_dict(rgb_weights["model"])

<All keys matched successfully>

In [21]:
# Rebuild the full ResNet-50 from the already-loaded checkpoint so this cell
# stays idempotent (re-running it never truncates an already-truncated model),
# then keep only the first 9 child modules. For torchvision's ResNet those are
# conv1 .. layer4 plus avgpool, i.e. everything except the `fc` head.
rgb_model = torchvision.models.get_model(
    "resnet50", weights=None, num_classes=1000
)
rgb_model.load_state_dict(rgb_weights["model"])
rgb_model = torch.nn.Sequential(*tuple(rgb_model.children())[:9])

In [None]:
# full_model = get_model("resnet50", weights=None, num_classes=1000)
# weights = torch.load("model.pth", map_location='cpu', weights_only=False)
# full_model.load_state_dict(weights["model"])

In [15]:
# Truncated view of the RadImageNet backbone keeping its first 9 child
# modules, mirroring the [:9] slice applied to the RGB ResNet-50.
# NOTE(review): which layers survive depends on RadImageNetBackbone's child
# layout in helpers.radimagenet - confirm [:9] is the intended cut.
# nn.Sequential registers the same module objects (no copy), so this shares
# parameters with `backbone`, and `classifier` is the same head instance
# already used in `radimagenet_model`.
model2_backbone = torch.nn.Sequential(*list(backbone.children())[:9])
model2_comparable = torch.nn.Sequential(model2_backbone, classifier)

## Load Data

In [23]:
from torchvision.transforms import v2
from torch import float32 as tfloat32

# Per-dataset intensity statistics, one table per CXR processing method.
# Every value is [[mean], [std]]: single-element lists, used as the mean/std
# arguments of v2.Normalize.
# The 'padcxr14' entries equal the unweighted average of the 'padchest' and
# 'cxr14' entries in the same table.

# Statistics for cropped images.
CROP_DICT = dict(
    cxr14=[[162.7414], [44.0700]],
    openi=[[157.6150], [41.8371]],
    jsrt=[[161.7889], [41.3950]],
    padchest=[[160.3638], [44.8449]],
)

# Statistics for arch-segmented images.
ARCH_SEG_DICT = dict(
    cxr14=[[128.2716], [76.7148]],
    openi=[[127.7211], [69.7704]],
    jsrt=[[139.9666], [72.4017]],
    padchest=[[129.5006], [72.6308]],
    padcxr14=[[128.8861], [74.6728]],
)

# Statistics for lung-segmented images.
# NOTE(review): 'openi' and 'padchest' have an identical std and means that
# differ by only 1e-4 - possibly a copy-paste; verify against the source
# statistics before relying on them.
LUNG_SEG_DICT = dict(
    cxr14=[[60.6809], [68.9660]],
    openi=[[60.5483], [66.5276]],
    jsrt=[[66.5978], [72.6493]],
    padchest=[[60.5482], [66.5276]],
    padcxr14=[[60.61455], [67.7468]],
)


def get_cxr_eval_transforms(crop_size, normalise):
    """
    Build the evaluation transform pipeline for CXR images.

    Steps, in order: convert to an image tensor, resize with antialiasing,
    cast to float32 *without* rescaling (pixel values keep their original
    range), then apply `normalise`.

    Args:
    - crop_size: forwarded to v2.Resize as `size` (an int resizes the
      shorter edge; note that no cropping is performed).
    - normalise: transform applied last, e.g. the result of
      get_cxr_dataset_normalisation.

    Returns:
    - v2.Compose chaining the steps above.
    """
    to_image = v2.ToImage()
    resize = v2.Resize(size=crop_size, antialias=True)
    to_float = v2.ToDtype(tfloat32, scale=False)
    return v2.Compose([to_image, resize, to_float, normalise])


def get_cxr_single_eval_transforms(crop_size, normalise):
    """
    Build the evaluation transform pipeline for CXR images, producing a
    single-channel (grayscale) output.

    Steps, in order: convert to an image tensor, reduce to one grayscale
    channel, resize with antialiasing, cast to float32 *without* rescaling,
    then apply `normalise`.

    Args:
    - crop_size: forwarded to v2.Resize as `size` (an int resizes the
      shorter edge; note that no cropping is performed).
    - normalise: transform applied last, e.g. the result of
      get_cxr_dataset_normalisation.

    Returns:
    - v2.Compose chaining the steps above.
    """
    steps = [
        v2.ToImage(),
        v2.Grayscale(num_output_channels=1),
        v2.Resize(size=crop_size, antialias=True),
        v2.ToDtype(tfloat32, scale=False),
    ]
    steps.append(normalise)
    return v2.Compose(steps)


def get_cxr_dataset_normalisation(dataset, process):
    """
    Returns normalisation transform for given dataset/config. Pass
    in dataset name and the image processing method used.

    Args:
    - dataset (str): Name of CXR dataset (case-insensitive). Expects
      ("cxr14", "padchest", "openi", "jsrt"); "padcxr14" is additionally
      accepted for the "arch" and "lung" processes, whose tables store
      statistics for it.
    - process (str): Name of CXR processing applied (case-insensitive).
      Expects ("crop", "arch", "lung").

    Returns:
    - torchvision.transforms.v2.Normalize built from the stored
      [mean] / [std] lists for that dataset and process.

    Raises:
    - ValueError: if `process` is unknown, or if the table for that process
      has no entry for `dataset`.
    """
    # Map each processing method to its statistics table. The tables are read
    # at call time, so the module-level dictionaries stay the single source of
    # truth for which datasets are supported.
    stats_by_process = {
        "crop": CROP_DICT,
        "arch": ARCH_SEG_DICT,
        "lung": LUNG_SEG_DICT,
    }

    process_key = process.lower()
    if process_key not in stats_by_process:
        # Adjacent literals (not backslash continuations) keep stray
        # indentation whitespace out of the message.
        raise ValueError(
            f"Unexpected CXR processing type: {process}! "
            "Please choose from (crop, arch, lung)."
        )

    stats = stats_by_process[process_key]
    dataset_key = dataset.lower()
    if dataset_key not in stats:
        raise ValueError(
            f"Unexpected CXR dataset type: {dataset}! "
            f"Please choose from ({', '.join(stats)})."
        )

    mean, std = stats[dataset_key]
    return v2.Normalize(mean, std)



In [24]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
# from helpers.cxr import get_cxr_dataset_normalisation, get_cxr_eval_transforms, get_cxr_single_eval_transforms 
from PIL import Image
import os
from tqdm import tqdm

class ImageFolderWithPaths(ImageFolder):
    """torchvision ImageFolder whose items are (img, label, img_path).

    The extra path element is the file path stored in ``self.imgs`` for the
    same index, so callers know which file each item came from.
    """

    def __getitem__(self, index):
        """Return the transformed image, its class label and its file path."""
        image, target = super().__getitem__(index)
        source_path = self.imgs[index][0]
        return (image, target, source_path)

def load_dataset_with_paths_direct_path(dataset_path, dataset_name, process="arch",
                                        crop_size=512, batch_size=4, shuffle=True,
                                        single_channel=False):
    """Wrapper helper to load a dataset with img paths.

    Builds an ImageFolderWithPaths rooted at `dataset_path` using the CXR
    evaluation transforms, normalised with the statistics stored for
    (`dataset_name`, `process`), and wraps it in a DataLoader.

    Args:
    - dataset_path: root directory passed to ImageFolder.
    - dataset_name / process: forwarded to get_cxr_dataset_normalisation
      (raises ValueError for unsupported combinations).
    - crop_size: forwarded as the transform's resize size.
    - batch_size, shuffle: forwarded to DataLoader. NOTE(review): shuffle
      defaults to True even for these eval transforms, so batch order is
      random.
    - single_channel: if truthy, use the grayscale single-channel pipeline.

    Returns:
    - DataLoader yielding (images, labels, paths) batches.
    """
    # Normalisation is resolved first, exactly as when it was an argument of
    # the transform factory call.
    normaliser = get_cxr_dataset_normalisation(dataset=dataset_name, process=process)

    if single_channel:
        build_transforms = get_cxr_single_eval_transforms
    else:
        build_transforms = get_cxr_eval_transforms

    dataset = ImageFolderWithPaths(
        root=dataset_path,
        transform=build_transforms(crop_size=crop_size, normalise=normaliser),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)