In [None]:
import os

from models.imagen_pytorch.imagen_pytorch import Unet, Imagen
from models.imagen_pytorch.trainer import ImagenTrainer
from models.imagen_pytorch.data import NLMCXRDataset
from models.imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

In [None]:
# Model setup: a single U-Net wrapped in a text-conditioned Imagen,
# driven by an ImagenTrainer placed on the GPU.

# Base U-Net. Self-attention is enabled only on the last of the four
# resolution levels; cross-attention layers are switched off everywhere.
unet = Unet(
    dim=32,
    dim_mults=(1, 2, 4, 8),
    num_resnet_blocks=1,
    layer_attns=(False, False, False, True),
    layer_cross_attns=False,
)

# Imagen wrapper around the U-Net above.
# NOTE: condition_on_text is True, so this is the text-conditioned setup
# (text embeddings are passed to `trainer.sample` further down). It would
# have to be False for unconditional training.
imagen = Imagen(
    unets=unet,
    condition_on_text=True,
    image_sizes=64,  # same resolution the datasets below are built with
    channels=1,      # single-channel images (presumably grayscale X-rays)
    timesteps=1000,
    cond_drop_prob=0.1,  # presumably the classifier-free-guidance dropout rate -- confirm
)

# Trainer. Validation data is registered explicitly later on, so nothing
# is split off from the training set.
trainer = ImagenTrainer(
    imagen=imagen,
    split_valid_from_train=False,
)
trainer = trainer.cuda()

# Fixed placeholder prompts, encoded once with T5. The resulting embeddings
# are reused as conditioning when preview images are sampled in the
# training loop.
texts = [f'example text {i}' for i in range(10)]
text_embeds, text_masks = t5_encode_text(
    texts,
    DEFAULT_T5_NAME,
    return_attn_mask=True,
)  # NOTE(review): text_masks is not referenced again in the visible code

# Train / validation datasets: same locations and image size, differing only
# in `mode`. The first path looks like the image directory and the second
# like the report directory -- confirm against NLMCXRDataset.
# NOTE(review): training is text-conditioned, so the dataset presumably
# yields text embeddings alongside the images -- confirm in NLMCXRDataset.
dataset_train, dataset_val = (
    NLMCXRDataset(
        '/home/guo/git/Rad-ReStruct/data/radrestruct/images/',
        '/home/guo/data/ecgen-radiology/',
        image_size=64,
        mode=split,
    )
    for split in ('train', 'val')
)

# Register both datasets with the trainer, which builds the dataloaders.
trainer.add_train_dataset(dataset_train, batch_size=16)
trainer.add_valid_dataset(dataset_val, batch_size=16)

# Training loop: one optimisation step per iteration, a validation step
# every 50 iterations, and a preview sample written to disk every 500.

# PIL's Image.save does not create missing directories. Make sure the output
# folder exists up front so the first sample (at i == 0) cannot fail after
# the full sampling pass has already been paid for.
os.makedirs('../tmp', exist_ok = True)

for i in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 64)
    # print(f'loss: {loss}')

    if i % 50 == 0:
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss in {i}: {valid_loss}')

    if i % 500 == 0 and trainer.is_main: # is_main makes sure this can run in distributed
        # Only images[0] is saved, so condition on the first prompt alone.
        # Upstream imagen-pytorch takes the sample batch size from
        # text_embeds when text-conditioned (TODO confirm for this fork), so
        # passing all 10 embeddings would generate 10 images and discard 9.
        # One embedding also matches the requested batch_size of 1.
        images = trainer.sample(batch_size = 1, return_pil_images = True, text_embeds = text_embeds[:1]) # returns List[Image]
        # Sampling happens every 500 steps, so the file index advances in
        # steps of 5 (sample-0, sample-5, ...); kept as-is so existing output
        # names do not change.
        images[0].save(f'../tmp/sample-{i // 100}.png')