In [None]:
# please use train_imagen.py to train the model!

In [None]:
from models.imagen_pytorch.imagen_pytorch import Unet, Imagen
from models.imagen_pytorch.trainer import ImagenTrainer
from models.imagen_pytorch.data import NLMCXRDataset
import os

In [None]:
# --- Run configuration -------------------------------------------------------

# Number of training steps (one `train_step` call each), not a number of batches.
iterations = 100000
# Batch size used for both the training and the validation dataloader.
batch_size = 16

# Periods, in training steps, for validation / sampling / checkpointing.
validate_every = 100
sample_every = 1000
save_every = 1000

# Directory that receives the sample images generated during training.
sample_save_path = '../tmp/'
# Dataset locations (machine specific).
image_path = '/home/guo/git/Rad-ReStruct/data/radrestruct/images/'
text_path = '/home/guo/data/ecgen-radiology/'
# Directory that receives the checkpoints, and how many of the most recent
# ones are kept on disk.
save_path = '../checkpoints/'
num_save_checkpoint = 3

# Set `load_model` to True to resume from `load_model_path`.
# Checkpoints are written by the training loop as `imagen-<n>.pt`, so the path
# to resume from must use the same `.pt` extension.
load_model = False
load_model_path = '../checkpoints/imagen-1.pt'

# Which U-Net of the cascade to train (1-based).
unet_number = 1

In [None]:
# U-Nets for the Imagen cascade.

# Settings that are identical for both U-Nets.
_shared_unet_kwargs = dict(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
)

# First U-Net: attention and cross-attention on every stage except the first.
unet1 = Unet(
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True),
    **_shared_unet_kwargs,
)

# Second U-Net: per-stage resnet block counts, attention only on the last stage.
unet2 = Unet(
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True),
    **_shared_unet_kwargs,
)

# Imagen model wrapping the U-Nets defined above
# (the base U-Net followed by the super-resolution one).

imagen = Imagen(
    # the cascade: one target image size per U-Net, in the same order
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    # single-channel images
    channels = 1,
    timesteps = 1000,
    # text conditioning; must be False for an unconditional Imagen
    condition_on_text = True,
    cond_drop_prob = 0.1,
)

# Build the trainer around the Imagen model, then move it to the GPU.
trainer = ImagenTrainer(
    imagen = imagen,
    # do not split a validation set off the training dataset
    split_valid_from_train = False,
)
trainer = trainer.cuda()

# Optionally restore trainer state from an existing checkpoint.
if load_model:
    trainer.load(load_model_path)

# Build the train / validation / test datasets (constructed in that order).
# Per the original note, each item is a tuple ordered as
# (image, text embedding, text mask), which is what the DDPM expects.
dataset_train, dataset_val, dataset_test = (
    NLMCXRDataset(image_path, text_path, image_size=64, mode=split_name)
    for split_name in ('train', 'val', 'test')
)

# Only the train and validation splits are registered with the trainer.
trainer.add_train_dataset(dataset_train, batch_size = batch_size)
trainer.add_valid_dataset(dataset_val, batch_size = batch_size)

# working training loop

# Create the output directories up front: the first sample and the first
# checkpoint are written at iteration 0 and would otherwise fail when the
# directories are missing.
os.makedirs(sample_save_path, exist_ok = True)
os.makedirs(save_path, exist_ok = True)


def _checkpoint_path(index):
    """Return the path of the checkpoint file with the given index."""
    return os.path.join(save_path, f'imagen-{index}.pt')


def _save_checkpoint(index):
    """Save checkpoint `index` and delete the one that fell out of the window.

    After the call at most `num_save_checkpoint` of the most recent
    checkpoints written by this run remain in `save_path`.
    """
    trainer.save(_checkpoint_path(index))
    print(f'saved model imagen-{index}.pt')

    # Index of the checkpoint that is now too old to keep. Negative values
    # simply mean fewer than `num_save_checkpoint` checkpoints exist yet.
    stale_index = index - num_save_checkpoint
    if stale_index >= 0:
        stale_path = _checkpoint_path(stale_index)
        if os.path.isfile(stale_path):
            os.remove(stale_path)


for i in range(iterations):
    loss = trainer.train_step(unet_number = unet_number)
    # print(f'loss: {loss}')

    if i % validate_every == 0:
        valid_loss = trainer.valid_step(unet_number = unet_number)
        print(f'valid loss in iter {i}: {valid_loss}')

    if i % sample_every == 0 and trainer.is_main: # is_main makes sure this can run in distributed
        idx = 0
        sample_index = i // sample_every
        # condition on the text embedding (tuple element 1) of one test item
        images = trainer.sample(batch_size = 1, return_pil_images = True, text_embeds=dataset_test[idx][1].unsqueeze(0), stop_at_unet_number=unet_number) # returns List[Image]
        images[0].save(os.path.join(sample_save_path, f'sample-{sample_index}.png'))
        print(f'saved sample {sample_index}')

    if i % save_every == 0 and trainer.is_main:
        _save_checkpoint(i // save_every)

# The loop stops at i == iterations - 1, which is usually not a multiple of
# `save_every`; without a final save the last steps of training are lost.
if iterations > 0 and (iterations - 1) % save_every != 0 and trainer.is_main:
    _save_checkpoint((iterations - 1) // save_every + 1)
