In [3]:
import yaml
import argparse

from blip.blip import BLIP
from blip.pretrain import BLIPTrainer
from loaders import CxrDataLoader, build_transform

In [5]:
# Location of the pretraining configuration, relative to this notebook.
config_path = '../configs/pretrain.yaml'

# Parse the YAML config inside a context manager so the file handle is closed
# as soon as parsing finishes instead of lingering until garbage collection.
with open(config_path) as config_file:
    # NOTE(review): yaml.Loader can construct arbitrary Python objects from
    # tagged YAML, so it is only appropriate for a trusted config file. If the
    # config contains nothing but plain scalars/lists/mappings, yaml.safe_load
    # is the safer choice — confirm before switching.
    config = yaml.load(config_file, Loader=yaml.Loader)

In [6]:
# Options bundle passed to every CxrDataLoader below. Only the batch size is
# taken from the YAML config; the remaining values are fixed for this notebook.
args = argparse.Namespace(**{
    'dataset_name': 'mimic-cxr',
    'batch_size': config['batch_size'],
    'num_workers': 0,
    'drop_last': True,
    'use_minio': False,
})

In [4]:
# One transform per dataset split, each built for the configured image size.
transforms = {
    split: build_transform(split, config['image_size'])
    for split in ('train', 'val', 'test')
}

In [5]:
# One CxrDataLoader per split, each paired with that split's transform and
# sharing the same `args` options.
dataloaders = {
    split: CxrDataLoader(args, split=split, transform=transforms[split])
    for split in ('train', 'val', 'test')
}

In [6]:
# Build the BLIP model. Every constructor keyword shares its name with the
# config entry it is read from, so the keywords are gathered by key.
model = BLIP(**{
    key: config[key]
    for key in (
        'image_size',
        'vit',
        'vit_grad_ckpt',
        'vit_ckpt_layer',
        'queue_size',
    )
})

/embeddings/word_embeddings is tied
/embeddings/position_embeddings is tied
/embeddings/LayerNorm is tied
/encoder/layer/0/crossattention/self/query is tied
/encoder/layer/0/crossattention/self/key is tied
/encoder/layer/0/crossattention/self/value is tied
/encoder/layer/0/crossattention/output/dense is tied
/encoder/layer/0/crossattention/output/LayerNorm is tied
/encoder/layer/0/intermediate/dense is tied
/encoder/layer/0/output/dense is tied
/encoder/layer/0/output/LayerNorm is tied
/encoder/layer/1/crossattention/self/query is tied
/encoder/layer/1/crossattention/self/key is tied
/encoder/layer/1/crossattention/self/value is tied
/encoder/layer/1/crossattention/output/dense is tied
/encoder/layer/1/crossattention/output/LayerNorm is tied
/encoder/layer/1/intermediate/dense is tied
/encoder/layer/1/output/dense is tied
/encoder/layer/1/output/LayerNorm is tied
/encoder/layer/2/crossattention/self/query is tied
/encoder/layer/2/crossattention/self/key is tied
/encoder/layer/2/crossat

In [7]:
# Wire the model, the config and the three split loaders into the trainer,
# leaving the mixed_precision flag off. The loader keywords
# (train_loader / val_loader / test_loader) are generated from the split names.
trainer = BLIPTrainer(
    model,
    config,
    **{
        f'{split}_loader': dataloaders[split]
        for split in ('train', 'val', 'test')
    },
    mixed_precision=False,
)

2025-04-19 17:03:10,454 - INFO - Total parameters: 475,892,799
2025-04-19 17:03:10,454 - INFO - Trainable parameters: 252,441,919 (53.05%)


In [8]:
# Start training with the trainer constructed in the previous cell.
# NOTE(review): the saved output of this cell ends in a KeyboardInterrupt
# during epoch 1/5, so no completed run is recorded here — rerun to completion
# or clear the output before sharing the notebook.
trainer.train()

2025-04-19 17:03:10,524 - INFO - Start training
2025-04-19 17:03:10,524 - INFO - Starting epoch 1/5
Train:   0%|          | 0/1714 [00:23<?, ?it/s]


KeyboardInterrupt: 