In [None]:
# Only required for cluster training: puts the project root on sys.path so that in-repo packages can be imported.
# import os
# import sys
#
# project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))  # add to sys path the project root to be able to use in-repo packages
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)
#
# print(sys.executable)
# for p in sys.path:
#     print(p)

In [1]:
import yaml
import blip
import argparse

from blip.pretrain import BLIPTrainer
from data.loaders import CxrDataLoader, TransformBuilder

In [2]:
# Training configuration; the path is resolved relative to the kernel's working directory.
config_path = '../configs/pretrain.yaml'

# `with` closes the file handle deterministically instead of leaking it for the lifetime
# of the kernel. `safe_load` only builds plain YAML types (mappings, lists, scalars), unlike
# `yaml.Loader`, which can instantiate arbitrary Python objects from tagged nodes.
# NOTE(review): if pretrain.yaml ever relies on custom `!!python/...` tags, this will raise.
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Options handed to CxrDataLoader as an argparse-style namespace (mirrors what a CLI would pass).
# Only the batch size comes from the YAML config; the rest are fixed for this notebook.
args = argparse.Namespace(**{
    'dataset_name': 'mimic-cxr',
    'batch_size': config['batch_size'],
    'num_workers': 0,      # presumably forwarded to a torch DataLoader (0 = main process) — confirm
    'drop_last': True,
    'use_minio': False,
})

In [4]:
transform_builder = TransformBuilder(image_size=config['image_size'])

# One loader per split, each paired with the transform pipeline built for that same split.
dataloaders = {
    split_name: CxrDataLoader(
        args,
        split=split_name,
        transform=transform_builder.build(split_name),
    )
    for split_name in ('train', 'val', 'test')
}

In [5]:
# Build the BLIP pre-training model. Every keyword argument shares its name with a
# config key, so the values are pulled from the YAML config in one pass.
model = blip.blip_pretrain(**{
    option: config[option]
    for option in (
        'pretrained',
        'image_size',
        'vit',
        'vit_grad_ckpt',
        'vit_ckpt_layer',
        'queue_size',
        'max_length',
    )
})

/embeddings/word_embeddings is tied
/embeddings/position_embeddings is tied
/embeddings/LayerNorm is tied
/encoder/layer/0/crossattention/self/query is tied
/encoder/layer/0/crossattention/self/key is tied
/encoder/layer/0/crossattention/self/value is tied
/encoder/layer/0/crossattention/output/dense is tied
/encoder/layer/0/crossattention/output/LayerNorm is tied
/encoder/layer/0/intermediate/dense is tied
/encoder/layer/0/output/dense is tied
/encoder/layer/0/output/LayerNorm is tied
/encoder/layer/1/crossattention/self/query is tied
/encoder/layer/1/crossattention/self/key is tied
/encoder/layer/1/crossattention/self/value is tied
/encoder/layer/1/crossattention/output/dense is tied
/encoder/layer/1/crossattention/output/LayerNorm is tied
/encoder/layer/1/intermediate/dense is tied
/encoder/layer/1/output/dense is tied
/encoder/layer/1/output/LayerNorm is tied
/encoder/layer/2/crossattention/self/query is tied
/encoder/layer/2/crossattention/self/key is tied
/encoder/layer/2/crossat

In [6]:
# Wrap the model with its config and the three split loaders, passed as
# train_loader / val_loader / test_loader keyword arguments.
trainer = BLIPTrainer(
    model,
    config,
    mixed_precision=False,  # mixed precision disabled for this run
    **{f'{split}_loader': dataloaders[split] for split in ('train', 'val', 'test')},
)

2025-04-26 20:28:37,560 - INFO - Total parameters: 475,892,799
2025-04-26 20:28:37,561 - INFO - Trainable parameters: 252,441,919 (53.05%)


In [8]:
# Launch training using the model, config and split loaders handed to BLIPTrainer above.
# Left as a bare call so whatever it returns is displayed as the cell output.
trainer.train()