In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image

Defining the data directories and csv file names

In [2]:
path = 'data/MURA-v1.1'
train_image_paths_csv = "train_image_paths.csv"
valid_image_paths_csv = "valid_image_paths.csv"

Reading the CSV file for training images and extracting the image paths

In [3]:
train_images_paths = pd.read_csv(os.path.join(path,train_image_paths_csv),dtype=str,header=None)
train_images_paths.columns = ['image_path']

valid_images_paths = pd.read_csv(os.path.join(path,valid_image_paths_csv),dtype=str,header=None)
valid_images_paths.columns = ['image_path']

Extracting the labels

In [4]:
train_images_paths['label'] = train_images_paths['image_path'].map(
    lambda x: 1 if 'positive' in x else 0)

valid_images_paths['label'] = valid_images_paths['image_path'].map(
    lambda x: 1 if 'positive' in x else 0)

Extracting other relevant information and displaying the dataframe

In [5]:
train_images_paths['category']  = train_images_paths['image_path'].apply(
    lambda x: x.split('/')[2])  
train_images_paths['patientId']  = train_images_paths['image_path'].apply(
    lambda x: x.split('/')[3].replace('patient',''))

valid_images_paths['category']  = valid_images_paths['image_path'].apply(
    lambda x: x.split('/')[2])  
valid_images_paths['patientId']  = valid_images_paths['image_path'].apply(
    lambda x: x.split('/')[3].replace('patient',''))

train_images_paths.head()

Unnamed: 0,image_path,label,category,patientId
0,MURA-v1.1/train/XR_SHOULDER/patient00001/study...,1,XR_SHOULDER,1
1,MURA-v1.1/train/XR_SHOULDER/patient00001/study...,1,XR_SHOULDER,1
2,MURA-v1.1/train/XR_SHOULDER/patient00001/study...,1,XR_SHOULDER,1
3,MURA-v1.1/train/XR_SHOULDER/patient00002/study...,1,XR_SHOULDER,2
4,MURA-v1.1/train/XR_SHOULDER/patient00002/study...,1,XR_SHOULDER,2


Printing train dataset information

In [6]:
total_number_of_training_images = np.shape(train_images_paths)[0]
print("total number of images:",total_number_of_training_images )
print ("\nnumber of null values\n", train_images_paths.isnull().sum())


categories_counts = pd.DataFrame(train_images_paths['category'].value_counts())
print ('\n\ncategories:\n',categories_counts )
print('\n\nnumber of patients:',train_images_paths['patientId'].nunique())
print('\nnumber of labels:',train_images_paths['label'].nunique())
print ('\npositive casses:',len(train_images_paths[train_images_paths['label']==1]))
print ('\nnegative casses:',len(train_images_paths[train_images_paths['label']==0]))

total number of images: 36808

number of null values
 image_path    0
label         0
category      0
patientId     0
dtype: int64


categories:
              count
category          
XR_WRIST      9752
XR_SHOULDER   8379
XR_HAND       5543
XR_FINGER     5106
XR_ELBOW      4931
XR_FOREARM    1825
XR_HUMERUS    1272


number of patients: 11184

number of labels: 2

positive casses: 14873

negative casses: 21935


In [7]:
counts = (
    train_images_paths
    .groupby(['category', 'patientId'])
    .size()
    .reset_index(name='n_images')
)

dist = (
    counts
    .groupby(['category', 'n_images'])
    .size()
    .reset_index(name='n_patients')
    .sort_values(['category', 'n_images'])
)

pivot = dist.pivot(
    index='n_images', columns='category', values='n_patients').fillna(0).astype(int)

print(pivot)

category  XR_ELBOW  XR_FINGER  XR_FOREARM  XR_HAND  XR_HUMERUS  XR_SHOULDER  \
n_images                                                                      
1               31        336          84       28          22          289   
2              691        155         651      484         478          487   
3              608       1172          91     1306          61          920   
4              272        132          32       65          20          834   
5               68         27           5       19           5           82   
6               27         30           1       36           1           33   
7               11          6           1        3           0           21   
8                1          5           0        2           0           17   
9                1          1           0        0           0            1   
10               1          1           0        1           0            2   
11               0          0           0        1  

Define augmentations

In [8]:
train_tf = T.Compose([
    T.Grayscale(num_output_channels=1),
    T.Resize((256,256)),
    T.RandomRotation(5),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomAffine(degrees=0, translate=(0.1,0.1), scale=(0.9,1.1)),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5])   # samplewise ≈ per-image norm
])

valid_tf = T.Compose([
    T.Grayscale(num_output_channels=1),
    T.Resize((256,256)),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5])
])

Define pytorch dataset objects

In [9]:
class PatientDataset(Dataset):
    """
    Each item is one patient's full set of X-ray images plus their label.
    """
    def __init__(self, df, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Group by patient; keep the first label (assume all images share it)
        self.patients = (
            df.groupby("patientId")
            .agg(image_paths=("image_path", list), label=("label", "first"))
            .reset_index()
        )

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        row = self.patients.iloc[idx]
        imgs = []
        for p in row["image_paths"]:
            img = Image.open(os.path.join(self.root_dir, p)).convert("L")
            if self.transform:
                img = self.transform(img)
            imgs.append(img)

        # (N, 1, H, W) — variable N per patient
        images = torch.stack(imgs, dim=0)
        label = torch.tensor(row["label"], dtype=torch.float32)
        return images, label


def patient_collate_fn(batch):
    """
    Returns a list of image tensors (one per patient) and a stacked label tensor.
    """
    image_list = [item[0] for item in batch]   # list of (N_i, 1, H, W)
    labels = torch.stack([item[1] for item in batch])  # (B,)
    return image_list, labels


Create datasets and dataloaders

In [10]:
parent_path = 'data'

train_ds = PatientDataset(train_images_paths, parent_path, train_tf)
valid_ds = PatientDataset(valid_images_paths, parent_path, valid_tf)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True,
                          num_workers=2, pin_memory=True,
                          collate_fn=patient_collate_fn)
val_loader   = DataLoader(valid_ds, batch_size=16, shuffle=False,
                          num_workers=2, pin_memory=True,
                          collate_fn=patient_collate_fn)

Sanity check

In [11]:
x,y = next(iter(train_loader))
print('lengths of x, y:\n', len(x), len(y))
print('shapes of x, y:\n', x[0].shape, y[0].shape)

lengths of x, y:
 16 16
shapes of x, y:
 torch.Size([2, 1, 256, 256]) torch.Size([])


In [12]:
x,y = next(iter(val_loader))
print('lengths of x, y:\n', len(x), len(y))
print('shapes of x, y:\n', x[0].shape, y[0].shape)

lengths of x, y:
 16 16
shapes of x, y:
 torch.Size([4, 1, 256, 256]) torch.Size([])


Custom Visual Transformer

In [13]:
# from architectures.custom_vit import Custom_ViT
#vit = Custom_ViT(img_size=256, patch_size=16, in_chans=1,
#     embed_dim=256, depth=6, heads=8)

Pretrained Resnets

In [15]:
#from architectures.resnet50 import ResNet50_Backbone
#backbone = ResNet50_Backbone(embed_dim=256, freeze_until='layer4')

#from architectures.resnet101 import ResNet101_Backbone
#backbone = ResNet101_Backbone(embed_dim=256, freeze_until='layer4')

from architectures.resnet152 import ResNet152_Backbone
backbone = ResNet152_Backbone(embed_dim=256, freeze_until='layer4')

unfreeze_groups = [
    backbone.backbone[7],
    backbone.backbone[6],
    backbone.backbone[5],
    backbone.backbone[4],
]

Pretrained ViTs

In [None]:
# vit_b_16
# vit_b_32
# vit_l_16
# vit_l_32

Parent Model

In [16]:
from architectures.classifier import Classifier
model = Classifier(backbone, embed_dim=256, mlp_depth=2)

In [None]:
from trainer import fit

# class imbalance: ~21935 neg / ~14873 pos ≈ 1.47
model = fit(
    model, train_loader, val_loader,
    n_epochs=50,
    lr=1e-4,
    pos_weight=1.47,
    unfreeze_groups=unfreeze_groups,
    scheduler_patience=3,
    unfreeze_patience=2,
    unfreeze_lr_scale=0.5,
    checkpoint_path="models/best_model_resnet152.pt"
)

[Unfreezer] 4 layer group(s) queued for progressive unfreezing.
