In [1]:
import clip
import argparse
import numpy as np

from clip.trainer import CLIPTrainer
from loaders import CxrDataLoader
from constants import VIT_TYPE

In [2]:
# Run configuration, kept in an argparse-style namespace. It is handed to
# CxrDataLoader for every split further down in this notebook.
# NOTE(review): which of these fields the loader actually reads is not visible
# from this notebook -- confirm against the `loaders` module.
args = argparse.Namespace()
args.dataset_name = "iu-xray"
args.batch_size = 16
args.num_workers = 0
# Same value as the "Extended context length" printed by the model-loading cell;
# presumably the two must agree -- TODO confirm.
args.max_seq_length = 248

In [4]:
# Load the CLIP model and the preprocessing object that is later passed to the
# data loaders as their `transform`.
# NOTE(review): `VIT_TYPE` is imported from `constants` but the architecture name
# is hard-coded here -- confirm whether the constant was meant to be used.
model, preprocess = clip.load(
    "ViT-B/32",
    load_from_clip=True,
    extended_context=False,
)

# Architecture facts read straight off the loaded model; kept as notebook-level
# names so later cells can refer to them.
input_resolution = model.visual.input_resolution
context_length = model.context_length
extended_context_length = model.extended_context_length
vocab_size = model.vocab_size

# Total element count over every parameter tensor of the model.
per_tensor_sizes = [int(np.prod(tensor.shape)) for tensor in model.parameters()]
param_total = np.sum(per_tensor_sizes)

# One "label value" line per entry, in the same order as before.
summary_rows = [
    ("Model parameters:", f"{param_total:,}"),
    ("Input resolution:", input_resolution),
    ("Original context length:", context_length),
    ("Extended context length:", extended_context_length),
    ("Vocab size:", vocab_size),
]
for row_label, row_value in summary_rows:
    print(row_label, row_value)

Model parameters: 151,685,391
Input resolution: 224
Original context length: 77
Extended context length: 248
Vocab size: 49408


In [4]:
# One CxrDataLoader per dataset split, keyed by split name. Every split receives
# the same `args` configuration and the same `preprocess` transform.
dataloaders = {
    split_name: CxrDataLoader(args, split=split_name, transform=preprocess)
    for split_name in ('train', 'val', 'test')
}

In [5]:
# Keyword arguments for the trainer: the three split loaders built above plus
# the logging cadence.
# The captured training log reports the loss at steps 20, 40, 60, ... with this
# log_interval value.
trainer_kwargs = dict(
    train_loader=dataloaders['train'],
    val_loader=dataloaders['val'],
    test_loader=dataloaders['test'],
    log_interval=20,
)

# Wrap the loaded CLIP model in the project's trainer.
trainer = CLIPTrainer(model, **trainer_kwargs)

Computing pos weights: 100%|██████████| 169/169 [01:44<00:00,  1.62it/s]
2025-03-30 16:36:10,749 - INFO - Total parameters: 151,899,919
2025-03-30 16:36:10,749 - INFO - Trainable parameters: 151,772,943 (99.92%)


In [6]:
# Number of passes over the training loader. The captured log shows 169 steps
# per epoch at roughly 2.8 s/step, i.e. about 8 minutes per epoch.
NUM_EPOCHS = 5

trainer.train(epochs=NUM_EPOCHS)

2025-03-30 16:36:10,765 - INFO - Starting training
2025-03-30 16:36:10,766 - INFO - Starting epoch 1/5
 11%|█         | 19/169 [00:53<06:58,  2.79s/it]2025-03-30 16:37:07,059 - INFO - Step 20: Loss = 3.4481
 23%|██▎       | 39/169 [01:48<05:56,  2.74s/it]2025-03-30 16:38:02,304 - INFO - Step 40: Loss = 3.7022
 35%|███▍      | 59/169 [02:44<05:02,  2.75s/it]2025-03-30 16:38:57,634 - INFO - Step 60: Loss = 3.1322
 47%|████▋     | 79/169 [03:39<04:08,  2.77s/it]2025-03-30 16:39:53,138 - INFO - Step 80: Loss = 3.2594
 59%|█████▊    | 99/169 [04:36<03:23,  2.90s/it]2025-03-30 16:40:50,149 - INFO - Step 100: Loss = 3.7771
 70%|███████   | 119/169 [05:32<02:23,  2.87s/it]2025-03-30 16:41:46,649 - INFO - Step 120: Loss = 3.3033
 82%|████████▏ | 139/169 [06:30<01:24,  2.82s/it]2025-03-30 16:42:43,559 - INFO - Step 140: Loss = 2.7930
 94%|█████████▍| 159/169 [07:25<00:27,  2.78s/it]2025-03-30 16:43:39,133 - INFO - Step 160: Loss = 3.0705
100%|██████████| 169/169 [07:52<00:00,  2.79s/it]
2025-03-