In [1]:
import clip
import argparse
import numpy as np

from loaders import CxrDataLoader
from datasets import _build_iu_xray_sampler
from classifier.model import ChexpertXRClassifier
from classifier.trainer import ChexpertClassifierTrainer

In [2]:
LABELS = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture',
          'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax',
          'Support Devices']

args = argparse.Namespace(
    dataset_name='iu-xray',
    batch_size=16,
    num_workers=0,
    max_seq_length=248
)

In [3]:
clip_model, preprocess = clip.load("ViT-B/32", load_from_clip=True)
clip_model.eval()

vision_model = clip_model.visual
input_dim = clip_model.vision_width
hidden_dim = clip_model.transformer_width

model = ChexpertXRClassifier(vision_model, input_dim, hidden_dim, num_classes=len(LABELS))

In [4]:
sampler = _build_iu_xray_sampler('train')

In [5]:
dataloaders = {
    'train': CxrDataLoader(args, split='train', transform=preprocess, sampler=sampler),
    'val': CxrDataLoader(args, split='val', transform=preprocess),
    'test': CxrDataLoader(args, split='test', transform=preprocess)
}

In [6]:
trainer = ChexpertClassifierTrainer(model, train_loader=dataloaders['train'], val_loader=dataloaders['val'], log_interval=10)

2025-04-05 22:51:13,392 - INFO - Total parameters: 88,250,126
2025-04-05 22:51:13,393 - INFO - Trainable parameters: 88,250,126 (100.00%)


In [7]:
trainer.train(epochs=10)

2025-04-05 22:51:13,460 - INFO - Start training
2025-04-05 22:51:13,462 - INFO - Starting epoch 1/10
  5%|▌         | 9/169 [00:11<03:31,  1.32s/it]2025-04-05 22:51:26,590 - INFO - Step 10: Loss = 0.3460
 11%|█         | 19/169 [00:24<03:05,  1.23s/it]2025-04-05 22:51:39,179 - INFO - Step 20: Loss = 0.1695
 17%|█▋        | 29/169 [00:37<03:01,  1.29s/it]2025-04-05 22:51:51,925 - INFO - Step 30: Loss = 0.2207
 23%|██▎       | 39/169 [00:49<02:45,  1.27s/it]2025-04-05 22:52:04,602 - INFO - Step 40: Loss = 0.3786
 29%|██▉       | 49/169 [01:02<02:30,  1.26s/it]2025-04-05 22:52:17,216 - INFO - Step 50: Loss = 0.3620
 35%|███▍      | 59/169 [01:15<02:18,  1.26s/it]2025-04-05 22:52:29,779 - INFO - Step 60: Loss = 0.3945
 41%|████      | 69/169 [01:27<02:07,  1.28s/it]2025-04-05 22:52:42,532 - INFO - Step 70: Loss = 0.4019
 47%|████▋     | 79/169 [01:40<01:54,  1.28s/it]2025-04-05 22:52:55,335 - INFO - Step 80: Loss = 0.3331
 53%|█████▎    | 89/169 [01:53<01:42,  1.28s/it]2025-04-05 22:53:08,

KeyboardInterrupt: 

In [8]:
test_loader = dataloaders['test']
test_loader_iter = iter(test_loader)
batch = next(test_loader_iter)

labels = batch['labels'][0]
labels[labels == -1] = 0.0
image = batch['image'][0]
image = image.unsqueeze(0)
predictions = model(image)

print(predictions['probs'])
print(labels)

tensor([[0.0911, 0.2139, 0.0697, 0.0712, 0.0050, 0.1004, 0.0641, 0.3036, 0.0780,
         0.1526, 0.0532, 0.0622, 0.1173, 0.1548]], grad_fn=<SigmoidBackward0>)
tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
