# Setup

In [21]:
import create_models
from topodiff.resample import UniformSampler
from topodiff.image_datasets_diffusion_model import load_data
from topodiff.fp16_util import MixedPrecisionTrainer
import torch
import copy

# Model

In [None]:
# Spatial resolution (pixels per side) requested from the data loader.
image_size = 64
# Forwarded to MixedPrecisionTrainer in the training cell.
mean_variance_use_fp16 = True

# Denoising network; with learn_sigma=True below it presumably predicts both the
# noise and the variance (hence the name) — TODO confirm in create_models.
mean_variance = create_models.mean_variance()
print(mean_variance)

# Gaussian diffusion process settings, forwarded verbatim to the factory.
diffusion_kwargs = dict(
    steps=1000,
    learn_sigma=True,
    sigma_small=False,
    noise_schedule="cosine",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
)
diffusion = create_models.gaussian_diffusion(**diffusion_kwargs)

# Chooses which diffusion timesteps each training example is noised to.
schedule_sampler = UniformSampler(diffusion)

cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
print("Using device:", device)
# NOTE(review): mean_variance_use_fp16 stays True on the CPU fallback, where
# half-precision convolutions may be unsupported — confirm before running on CPU.
mean_variance = mean_variance.to(device)

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(6, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-3): 3 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=512, out_features=256, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.3, inplace=False)
          (3): Conv2d(128, 128, kernel_size=(3, 3)

# Data

In [None]:
# Number of images per mini-batch drawn from the training set.
batch_size = 32

# Relative to the notebook's working directory. Forward slashes are accepted on
# Windows, Linux and macOS alike; a backslash-separated literal such as
# ".\data\..." is a single odd filename (not a directory path) on POSIX systems.
data_dir = "./data/dataset_1_diff/training_data"

# NOTE(review): load_data is consumed as a generator of
# (images, constraints, extra) batches by the training cell — confirm the
# tuple layout in topodiff.image_datasets_diffusion_model.
data = load_data(
    data_dir=data_dir,
    batch_size=batch_size,
    image_size=image_size,
)

# Optimizing the Model Parameters

In [None]:
# --- Optimisation settings -------------------------------------------------
# topodiff's load_data (like guided-diffusion's) is an endless generator
# (`while True: yield from loader`), so an epoch loop of the form
# `for ... in data` never finishes its first epoch. Train for a fixed number
# of optimiser steps instead; tune to dataset size / compute budget.
train_steps = 10_000
learning_rate = 1e-4
ema_rate = 0.9999   # decay of the exponential moving average of the weights
log_interval = 100  # print the training loss every `log_interval` steps

# MixedPrecisionTrainer's constructor is keyword-only (`*, model, use_fp16, ...`)
# in guided-diffusion, which topodiff's fp16_util mirrors; positional
# arguments raise a TypeError.
trainer = MixedPrecisionTrainer(model=mean_variance, use_fp16=mean_variance_use_fp16)

# The optimiser must own the trainer's *master* parameters. With fp16 enabled,
# trainer.optimize() unscales the gradients onto fp32 master copies, steps the
# optimiser and then copies master -> model weights; an optimiser built on
# mean_variance.parameters() would see loss-scaled fp16 gradients and have its
# updates overwritten by the untouched master copies.
optimizer = torch.optim.AdamW(trainer.master_params, lr=learning_rate)

# EMA shadow of the master parameters, refreshed after every successful step.
ema_parameters = copy.deepcopy(trainer.master_params)


def update_ema(targets, sources, rate):
    """Blend `sources` into `targets` in place.

    For each pair: target <- rate * target + (1 - rate) * source.
    """
    for target, source in zip(targets, sources):
        target.detach().mul_(rate).add_(source, alpha=1 - rate)


mean_variance.train()
for step in range(train_steps):
    # One batch per step from the endless generator.
    images, constraints, _ = next(data)
    images = images.to(device)
    constraints = constraints.to(device)

    # One timestep per example; size from the actual batch, not the configured one.
    time_steps, weights = schedule_sampler.sample(images.shape[0], device)
    losses = diffusion.training_losses(mean_variance, images, constraints, time_steps)
    loss = (losses["loss"] * weights).mean()

    trainer.zero_grad()
    trainer.backward(loss)
    # optimize() returns False when the step is skipped (presumably non-finite
    # fp16 gradients); only update the EMA after a real parameter update.
    if trainer.optimize(optimizer):
        update_ema(ema_parameters, trainer.master_params, ema_rate)

    if step % log_interval == 0:
        print(f"step {step:>7d}  loss {loss.item():.5f}")
    