In [None]:
import yaml
import argparse

from data import load_tokenizer
from clipxrgen import build_model
from clipxrgen.loss import build_loss
from data.loaders import CxrDataLoader
from configs import load_config_from_file
from clipxrgen.downstream import RadiologyReportDecoderTrainer

In [None]:
# Load the configuration for the tokenizer, prompt construction, model and loss,
# then build the runtime objects used by the trainer further down.

# Read the tokenizer YAML inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
# NOTE(review): if this file only contains plain YAML types, yaml.safe_load
# would be sufficient here.
with open("../configs/tokenizer/cxr_bert.yaml") as tokenizer_config_file:
    tokenizer_config = yaml.load(tokenizer_config_file, Loader=yaml.FullLoader)

prompt_constructor_config = load_config_from_file("../configs/prompts/teacher_forcing.yaml")
model_config = load_config_from_file("../configs/model/conditional_report_gen.yaml")
loss_config = load_config_from_file("../configs/loss/language_modeling.yaml")

# The tokenizer config is unpacked as keyword arguments, so it must be a mapping.
tokenizer = load_tokenizer(**tokenizer_config)
loss_fn = build_loss(loss_config)
model = build_model(model_config, tokenizer)

# FIXME(review): `prompt_constructor_config` is loaded above but never used, and
# the prompt constructor is built with the *model* builder and the *model*
# config, i.e. this creates a second model instance. This looks like a
# copy-paste slip; it should presumably be built from
# `prompt_constructor_config` with the project's prompt-constructor factory.
# Confirm the right builder and replace this line.
prompt_constructor = build_model(model_config, tokenizer)

In [7]:
# Load the image-transform and training configs. Each file is opened in a
# context manager so its handle is closed as soon as parsing finishes.
with open("../configs/transform/clahe.yaml") as transform_config_file:
    transform_config = yaml.load(transform_config_file, Loader=yaml.FullLoader)

with open("../configs/train/report_gen.yaml") as train_config_file:
    train_config = yaml.load(train_config_file, Loader=yaml.FullLoader)

# Settings handed to CxrDataLoader as a Namespace. `max_length` and
# `image_size` are taken from the training config so the two stay in sync;
# the remaining values are fixed for this notebook.
args = argparse.Namespace(
    dataset_name='mimic-cxr-mvs',
    batch_size=10,
    max_length=train_config["max_length"],
    image_size=train_config["image_size"],
    num_workers=0,
    drop_last=True,
    use_minio=False
)

In [8]:
# One CxrDataLoader per dataset split, keyed by split name. All splits share
# the same args, transform config and tokenizer; only `split` differs.
dataloaders = {
    split_name: CxrDataLoader(
        args,
        split=split_name,
        transform_config=transform_config,
        tokenizer=tokenizer,
    )
    for split_name in ("train", "val", "test")
}

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/tudormihaita/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/tudormihaita/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/tudormihaita/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [14]:
# Sanity check: report how many samples the validation split's dataset holds.
val_dataset = dataloaders["val"].dataset
print(len(val_dataset))

37664


In [None]:
# Assemble the report-decoder trainer from the objects built in earlier cells.
# Only the train and validation loaders are passed; the test loader is not
# used here. Mixed precision is explicitly disabled.
train_loader = dataloaders["train"]
val_loader = dataloaders["val"]

trainer = RadiologyReportDecoderTrainer(
    model=model,
    config=train_config,
    loss_fn=loss_fn,
    prompt_constructor=prompt_constructor,
    train_loader=train_loader,
    val_loader=val_loader,
    mixed_precision=False,
)

In [None]:
# Start training via the trainer's `train()` entry point.
# NOTE(review): presumably this runs the full training loop configured by
# `train_config` and may take a long time — confirm in
# RadiologyReportDecoderTrainer before running on a fresh kernel.
trainer.train()