In [1]:
from models.unet import Unet
from models.diffusion_model import DiffusionModelTest
import copy
import torch
from utils.trainer import Trainer
from utils.model_utils import load_config_from_yaml

In [2]:
# Load the saved configuration and checkpoint, rebuild the UNet, its EMA copy
# and the diffusion wrapper from it, then draw one class-conditional sample
# batch from the raw weights and one from the EMA weights.
DEVICE = "cpu"
CHECKPOINT_PATH = "checkpoints/checkpoint_1.pth.tar"
CONFIG_PATH = "configs/config.yaml"

# map_location remaps any tensors saved on another device (e.g. a GPU) onto
# DEVICE, so the checkpoint also loads on a CPU-only machine.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
config = load_config_from_yaml(CONFIG_PATH)

config_unet = config['model']
config_diffusion_model = config['diffusion']

# Create the UNet and restore its raw (non-EMA) weights.
unet_model = Unet(**config_unet)
unet_model.load_state_dict(checkpoint["unet_model_state"])

# Create the EMA model as a frozen, eval-mode copy of the UNet architecture,
# then overwrite its weights with the saved EMA weights.
ema_model = copy.deepcopy(unet_model).eval().requires_grad_(False)
ema_model.load_state_dict(checkpoint["ema_model_state"])

# Wrap the UNet in the diffusion model and restore the wrapper's saved state.
diffusion_model = DiffusionModelTest(model=unet_model, **config_diffusion_model)
diffusion_model.load_state_dict(checkpoint["diffusion_model_state"])

# One sample per class label; n_samples is derived from the labels so the two
# values cannot drift apart.
N_CLASSES = 10
COND_WEIGHT = 2  # forwarded as `cond_weight` to DiffusionModelTest.sample
classes = torch.arange(N_CLASSES, device=DEVICE)
n_samples = len(classes)

# Sample in eval mode (ema_model is already in eval mode) and without
# recording an autograd graph; both are harmless if `sample` already does
# this internally. Train mode is restored afterwards for the resume cell.
diffusion_model.eval()
with torch.no_grad():
    samples = diffusion_model.sample(n_samples=n_samples, ema_model=None, classes=classes, cond_weight=COND_WEIGHT)
    samples_ema = diffusion_model.sample(n_samples=n_samples, ema_model=ema_model, classes=classes, cond_weight=COND_WEIGHT)
diffusion_model.train()

self.out_dim 1
dims [32, 32, 64, 64]
final 32 1 Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))


Sampling Time Step:: 1it [00:01,  1.17s/it]
Sampling Time Step:: 1it [00:00,  1.59it/s]


In [3]:
# Resume training from the restored checkpoint: rebuild an Adam optimizer over
# the UNet parameters, restore its saved internal state, and continue the loop
# from the saved epoch.
resume_lr = config['optimizer']['lr']
optimizer = torch.optim.Adam(unet_model.parameters(), lr=resume_lr)
optimizer.load_state_dict(checkpoint["optimizer_state"])

trainer = Trainer(**config['trainer'], diffusion_model=diffusion_model, optimizer=optimizer)

# Override the configured schedule. nb_epochs is the total epoch count: the
# logged run below continues with epochs 2-5.
resume_nb_epochs = 5
trainer.nb_epochs = resume_nb_epochs
trainer.start_epoch = checkpoint['epoch']

trainer.train_loop()


Epoch: 2


Training Loop:   0%|          | 0/938 [00:22<?, ?it/s]


Epoch 2 Loss: 1.0512707233428955
We sample the following classes: 
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


Sampling Time Step:: 1it [00:00,  1.32it/s]
Sampling Time Step:: 1it [00:00,  2.86it/s]


Epoch: 3


Training Loop:   0%|          | 0/938 [00:16<?, ?it/s]


Epoch 3 Loss: 1.0436534881591797
Epoch: 4


Training Loop:   0%|          | 0/938 [00:29<?, ?it/s]


Epoch 4 Loss: 1.046654224395752
We sample the following classes: 
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


Sampling Time Step:: 1it [00:00,  1.28it/s]
Sampling Time Step:: 1it [00:00,  2.84it/s]


Epoch: 5


Training Loop:   0%|          | 0/938 [00:15<?, ?it/s]


Epoch 5 Loss: 1.0216786861419678
We sample the following classes: 
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


Sampling Time Step:: 1it [00:00,  1.39it/s]
Sampling Time Step:: 1it [00:00,  2.99it/s]


In [7]:
# Train a fresh model from scratch with the same configuration, to reproduce
# results. No checkpoint state is used.
# NOTE: this cell rebinds config, unet_model, diffusion_model, optimizer and
# trainer, replacing the objects restored from the checkpoint above.
# NOTE(review): confirm Trainer does not write to the checkpoint file loaded
# earlier in this notebook.

# Seed torch's global RNG before any model is built. Randomness drawn from
# torch's default generator (e.g. weight initialisation) then repeats across
# runs; without a seed the "reproduce results" goal cannot hold.
# Bit-exact GPU reproducibility may additionally require
# torch.use_deterministic_algorithms(True).
SEED = 0  # NOTE(review): arbitrary choice; the logged run was not seeded
torch.manual_seed(SEED)

config = load_config_from_yaml("configs/config.yaml")

config_unet = config['model']
config_diffusion_model = config['diffusion']

unet_model = Unet(**config_unet)
diffusion_model = DiffusionModelTest(model=unet_model, **config_diffusion_model)

optimizer = torch.optim.Adam(unet_model.parameters(), lr=config['optimizer']['lr'])
trainer = Trainer(**config['trainer'], diffusion_model=diffusion_model, optimizer=optimizer)
trainer.train_loop()

self.out_dim 1
dims [32, 32, 64, 64]
final 32 1 Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
Epoch: 1


Training Loop:   0%|          | 0/938 [00:26<?, ?it/s]


Epoch 1 Loss: 1.1690998077392578
We sample the following classes: 
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


Sampling Time Step:: 1it [00:01,  1.21s/it]
Sampling Time Step:: 1it [00:00,  1.20it/s]
