In [None]:
from models.unet import Unet
from models.diffusion_model import DiffusionModel
import copy
import torch
from utils.trainer_utils import Trainer
from utils.data_utils import load_config_from_yaml, plot_figure

In [None]:
# Loading the model and configurations and 
checkpoint = torch.load('checkpoints/checkpoint_1.pth.tar')
config = load_config_from_yaml("configs/config.yaml")

config_unet = config['model']
config_diffusion_model = config['diffusion']

# create instance of unet
unet_model = Unet(**config_unet)
unet_model.load_state_dict(checkpoint["unet_model_state"])

# create instance of ema mdoel
ema_model = (copy.deepcopy(unet_model).eval().requires_grad_(False))
ema_model.load_state_dict(checkpoint["ema_model_state"])

# create instance of DiffusionModel
diffusion_model = DiffusionModel(model=unet_model, **config_diffusion_model)
diffusion_model.load_state_dict(checkpoint["diffusion_model_state"])

n_samples = 10
classes = torch.arange(0, 10).to("cpu") 

samples_ddim = diffusion_model.sample(n_samples=n_samples, ema_model=None, classes=classes, cond_weight=1, use_ddim=False, eta=0)
samples_ema_ddim = diffusion_model.sample(n_samples=n_samples, ema_model=ema_model, classes=classes, cond_weight=1, use_ddim=False, eta=0)


In [None]:
fig_eta0 = plot_figure(samples_ddim_eta0, n_samples)
fig_ema_eta0 = plot_figure(samples_ema_ddim_eta0, n_samples)

In [None]:
fig = plot_figure(samples_ddim, n_samples)
fig_ema = plot_figure(samples_ema_ddim, n_samples)

In [None]:
# train again
optimizer = torch.optim.Adam(unet_model.parameters(), lr=config['optimizer']['lr'])
optimizer.load_state_dict(checkpoint["optimizer_state"])
trainer = Trainer(**config['trainer'], diffusion_model=diffusion_model, optimizer=optimizer)
trainer.nb_epochs = 4
trainer.start_epoch = checkpoint['epoch']
trainer.train_loop()


In [None]:
# train model from beginning with same configurations: to reproduce results
# makes no use of existing states
config = load_config_from_yaml("configs/config.yaml")

config_unet = config['model']
config_diffusion_model = config['diffusion']

unet_model = Unet(**config_unet)
diffusion_model = DiffusionModel(model=unet_model, **config_diffusion_model)

optimizer = torch.optim.Adam(unet_model.parameters(), lr=config['optimizer']['lr'])
trainer = Trainer(**config['trainer'], diffusion_model=diffusion_model, optimizer=optimizer)
trainer.train_loop()