In [1]:
import clip
import argparse
import numpy as np

from clip.trainer import CLIPTrainer
from loaders import CxrDataLoader

In [2]:
# Run configuration, bundled as a Namespace so it can be handed to the data
# loader the same way parsed command-line arguments would be.
# NOTE(review): max_seq_length=248 equals the "Extended context length" the
# loaded model reports further down; presumably the two must agree — confirm.
args = argparse.Namespace(**{
    'dataset_name': 'iu-xray',
    'batch_size': 16,
    'num_workers': 1,
    'max_seq_length': 248,
})

In [3]:
# Load the "ViT-B/32" model; the second return value is reused later as the
# image transform for the data loader.
model, preprocess = clip.load("ViT-B/32", load_from_clip=True)

# Read the architecture facts straight off the loaded model.
input_resolution, context_length, extended_context_length, vocab_size = (
    model.visual.input_resolution,
    model.context_length,
    model.extended_context_length,
    model.vocab_size,
)

# Total element count over every parameter, computed from the parameter shapes.
total_params = sum(int(np.prod(p.shape)) for p in model.parameters())

# Summary printout: one line per fact.
print("Model parameters:", f"{total_params:,}")
print("Input resolution:", input_resolution)
print("Original context length:", context_length)
print("Extended context length:", extended_context_length)
print("Vocab size:", vocab_size)

Model parameters: 151,899,919
Input resolution: 224
Original context length: 77
Extended context length: 248
Vocab size: 49408


In [4]:
# Build the loader for the 'train' split, configured by `args` and applying the
# preprocessing returned alongside the model as its transform.
data_loader = CxrDataLoader(
    args,
    split='train',
    transform=preprocess,
)

In [5]:
# Wire the model and loaders into the trainer.
# NOTE(review): the single loader built with split='train' fills all three
# positional loader slots. If those slots are the train/validation/test loaders
# (not verifiable from this notebook), any evaluation would run on training
# data — confirm, and pass dedicated loaders if so.
trainer = CLIPTrainer(
    model,
    data_loader,
    data_loader,
    data_loader,
    log_interval=10,  # the recorded log reports the loss every 10 steps
)

03/10/2025 11:52:54 - INFO - Total parameters: 151,899,919
03/10/2025 11:52:54 - INFO - Trainable parameters: 151,772,943 (99.92%)


In [6]:
# Start training. The log captured from the recorded run shows "epoch 1/415"
# with 241 iterations per epoch at roughly 2 s/it, so this cell is long-running.
trainer.train()

03/10/2025 11:52:54 - INFO - Starting training
03/10/2025 11:52:54 - INFO - Starting epoch 1/415
  4%|▎         | 9/241 [00:21<08:27,  2.19s/it]03/10/2025 11:53:18 - INFO - Step 10: Loss = 3.0192
  8%|▊         | 19/241 [00:42<07:51,  2.12s/it]03/10/2025 11:53:38 - INFO - Step 20: Loss = 2.9861
 12%|█▏        | 29/241 [01:02<07:08,  2.02s/it]03/10/2025 11:53:59 - INFO - Step 30: Loss = 3.0206
 16%|█▌        | 39/241 [01:23<07:00,  2.08s/it]03/10/2025 11:54:20 - INFO - Step 40: Loss = 2.9549
 20%|██        | 49/241 [01:43<06:33,  2.05s/it]03/10/2025 11:54:40 - INFO - Step 50: Loss = 2.8166
 24%|██▍       | 59/241 [02:03<06:05,  2.01s/it]03/10/2025 11:55:00 - INFO - Step 60: Loss = 2.7244
 29%|██▊       | 69/241 [02:24<06:00,  2.10s/it]03/10/2025 11:55:21 - INFO - Step 70: Loss = 2.8454
 33%|███▎      | 79/241 [02:45<05:26,  2.02s/it]03/10/2025 11:55:42 - INFO - Step 80: Loss = 2.8789
 37%|███▋      | 89/241 [03:06<05:16,  2.08s/it]03/10/2025 11:56:04 - INFO - Step 90: Loss = 2.7112
 41%