In [1]:
import clip
import argparse
import numpy as np

from clip.trainer import CLIPTrainer
from loaders import CxrDataLoader



In [7]:
# Options handed to CxrDataLoader, packaged as an argparse-style namespace so
# the loader can read them as attributes (args.batch_size, ...).
args = argparse.Namespace(**{
    'dataset_name': 'mimic-cxr',   # dataset identifier passed to the loader
    'batch_size': 64,
    'num_workers': 0,
    'max_seq_length': 248,         # NOTE(review): same value as the extended context length reported by the model cell — confirm they are meant to stay in sync
    'use_minio': False,
})

In [6]:
# Load the ViT-B/32 model and its matching image preprocessing transform.
# NOTE(review): `load_from_clip` / `extended_context` are options of this
# project's `clip.load`; presumably they start from the original CLIP weights
# and enable the longer text context — confirm against clip.load.
model, preprocess = clip.load(
    "ViT-B/32",
    load_from_clip=True,
    extended_context=True,
)

# Architecture facts read off the loaded model for the summary below.
input_resolution = model.visual.input_resolution
context_length = model.context_length
extended_context_length = model.extended_context_length
vocab_size = model.vocab_size

# Total scalar weights: element count of each parameter (product of its
# shape), summed over every parameter of the model.
total_params = np.sum([int(np.prod(param.shape)) for param in model.parameters()])

print("Model parameters:", f"{total_params:,}")
print("Input resolution:", input_resolution)
print("Original context length:", context_length)
print("Extended context length:", extended_context_length)
print("Vocab size:", vocab_size)

Model parameters: 151,899,919
Input resolution: 224
Original context length: 77
Extended context length: 248
Vocab size: 49408


In [8]:
# One CxrDataLoader per dataset split, keyed by split name. Every split gets
# the same loader options (`args`) and the same image transform (`preprocess`).
dataloaders = {
    split: CxrDataLoader(args, split=split, transform=preprocess)
    for split in ('train', 'val', 'test')
}

In [9]:
# Build the trainer around the loaded model. The three loader keywords
# (train_loader, val_loader, test_loader) are generated from the split names
# and filled from the matching entries of `dataloaders`.
trainer = CLIPTrainer(
    model,
    **{f'{split}_loader': dataloaders[split] for split in ('train', 'val', 'test')},
    log_interval=50,
)

2025-04-11 21:43:06,131 - INFO - Total parameters: 151,899,919
2025-04-11 21:43:06,131 - INFO - Trainable parameters: 151,364,865 (99.65%)


In [10]:
# Run training for 5 epochs.
# NOTE(review): the output saved with this cell ends in a KeyboardInterrupt
# after the first step of epoch 1 (1/3427 iterations, ~15 s/it), so the
# recorded run did not complete — re-run to completion before relying on it.
trainer.train(epochs=5)

2025-04-11 21:43:13,550 - INFO - Starting training
2025-04-11 21:43:13,552 - INFO - Starting epoch 1/5
  0%|          | 1/3427 [00:14<14:13:36, 14.95s/it]


KeyboardInterrupt: 