In [1]:
import clip
import argparse
import numpy as np

from loaders import CxrDataLoader
from classifier.model import ChexpertClassifier
from classifier.trainer import ChexpertClassifierTrainer

In [2]:
# Names of the classification targets. The classifier built further down is
# sized with num_classes=len(LABELS), so it gets one output per entry here.
LABELS = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

# Settings object handed to every CxrDataLoader. A Namespace is used so the
# loader can read the values as attributes, the same way it would read parsed
# command-line arguments, without going through an ArgumentParser.
args = argparse.Namespace()
args.dataset_name = 'iu-xray'
args.batch_size = 16
args.num_workers = 0
args.max_seq_length = 248

In [3]:
# Load the "ViT-B/32" CLIP model along with its paired image transform
# (`preprocess` is later passed to the data loaders as `transform`), then call
# .eval() on the loaded CLIP model.
clip_model, preprocess = clip.load("ViT-B/32", load_from_clip=True)
clip_model.eval()

# Pieces of the CLIP model that the classifier is built from: the visual
# sub-module and the two width attributes used as input / hidden sizes.
hidden_dim = clip_model.transformer_width
input_dim = clip_model.vision_width
vision_model = clip_model.visual

# One output per entry in LABELS.
model = ChexpertClassifier(
    vision_model,
    input_dim,
    hidden_dim,
    num_classes=len(LABELS),
)

# Total element count over every tensor returned by model.parameters().
param_count = np.sum([int(np.prod(weight.shape)) for weight in model.parameters()])
print("Model parameters:", f"{param_count:,}")

Model parameters: 88,250,126


In [4]:
# One CxrDataLoader per split, keyed by split name. Every loader receives the
# same `args` settings and the same `preprocess` transform; only `split` differs.
dataloaders = {
    split_name: CxrDataLoader(args, split=split_name, transform=preprocess)
    for split_name in ('train', 'val', 'test')
}

In [5]:
# Build the trainer around the classifier, wiring in the train and val loaders.
# The test loader is not passed here. log_interval=10 lines up with the
# "Step N: Loss" messages that appear every 10 steps in the training output.
trainer = ChexpertClassifierTrainer(
    model,
    train_loader=dataloaders['train'],
    val_loader=dataloaders['val'],
    log_interval=10,
)

03/19/2025 01:04:31 - INFO - Total parameters: 88,250,126
03/19/2025 01:04:31 - INFO - Trainable parameters: 400,910 (0.45%)


In [6]:
# Run one training epoch with the trainer constructed in the previous cell.
# TODO: check if evaluation metrics are computed correctly: log predicted labels
# TODO: handle uncertain labels
# NOTE(review): in the recorded output the logged loss is 0.6932 at step 10 and
# 0.6931 at every later logged step (20 through 90), i.e. it never decreases.
# 0.6931 is ln(2); if the trainer uses a binary cross-entropy loss (not visible
# from this notebook), that is the value produced by a constant 0.5 prediction.
# Presumably the model is not learning -- confirm learning rate, which
# parameters are trainable, and label encoding inside the trainer before
# relying on results from this run.
trainer.train(epochs=1)

03/19/2025 01:04:31 - INFO - Start training
03/19/2025 01:04:31 - INFO - Starting epoch 1/1
  5%|▌         | 9/169 [00:07<02:10,  1.23it/s]03/19/2025 01:04:40 - INFO - Step 10: Loss = 0.6932
 11%|█         | 19/169 [00:16<02:11,  1.14it/s]03/19/2025 01:04:48 - INFO - Step 20: Loss = 0.6931
 17%|█▋        | 29/169 [00:24<01:54,  1.22it/s]03/19/2025 01:04:57 - INFO - Step 30: Loss = 0.6931
 23%|██▎       | 39/169 [00:32<01:50,  1.18it/s]03/19/2025 01:05:05 - INFO - Step 40: Loss = 0.6931
 29%|██▉       | 49/169 [00:41<01:40,  1.20it/s]03/19/2025 01:05:14 - INFO - Step 50: Loss = 0.6931
 35%|███▍      | 59/169 [00:49<01:30,  1.22it/s]03/19/2025 01:05:22 - INFO - Step 60: Loss = 0.6931
 41%|████      | 69/169 [00:57<01:24,  1.18it/s]03/19/2025 01:05:30 - INFO - Step 70: Loss = 0.6931
 47%|████▋     | 79/169 [01:06<01:12,  1.23it/s]03/19/2025 01:05:38 - INFO - Step 80: Loss = 0.6931
 53%|█████▎    | 89/169 [01:14<01:06,  1.20it/s]03/19/2025 01:05:47 - INFO - Step 90: Loss = 0.6931
 59%|████