In [None]:
import torch
from data_processing import make_gpu_dataloader
import warnings
from unet import train_model, CatDiffusion, DiffusionConfig
from unet_utils import tensor_to_image, visualize_denoising, compute_fid
import matplotlib.pyplot as plt
import os
import numpy as np
import random
warnings.filterwarnings("ignore")
def set_seed(seed):
    """Seed every RNG used by the notebook and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, torch's CPU
    generator and the generators of all CUDA devices. It then turns off cuDNN
    benchmark autotuning and asks cuDNN for deterministic algorithms.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Trade some GPU speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)

In [None]:
def save_samples(samples, path_prefix):
    """Write each sample as ``sample{i}.png`` inside the directory ``path_prefix``.

    The directory is created if it does not exist. Each sample is converted
    with ``tensor_to_image`` (from ``unet_utils``) and saved via the result's
    ``.save`` method (presumably a PIL image; confirm in unet_utils).

    Args:
        samples: Iterable of samples; files are numbered in iteration order.
        path_prefix: Output directory.
    """
    os.makedirs(path_prefix, exist_ok=True)
    index = 0
    for tensor in samples:
        image = tensor_to_image(tensor)
        destination = os.path.join(path_prefix, f"sample{index}.png")
        image.save(destination)
        index += 1
        
def save_loss(loss_list, path_prefix, param, val):
    """Save a loss history to ``<path_prefix>/<param>_<val>.npy``.

    Like ``save_samples``, this creates ``path_prefix`` if it is missing.
    Writing into a fresh results directory therefore works instead of raising
    FileNotFoundError.

    Args:
        loss_list: Sequence of loss values; converted with ``np.asarray``.
        path_prefix: Output directory.
        param: Label embedded in the filename (e.g. a hyperparameter name).
        val: Value embedded in the filename, formatted via f-string.

    Returns:
        Path of the written ``.npy`` file.
    """
    os.makedirs(path_prefix, exist_ok=True)
    # np.save would append ".npy" itself. Spelling it out keeps the on-disk
    # name unchanged and makes the returned path match the actual file.
    out_path = os.path.join(path_prefix, f"{param}_{val}.npy")
    np.save(out_path, np.asarray(loss_list))
    return out_path

In [None]:
class Config:
    """Static settings object passed to the dataloader factory.

    The attributes are read through an instance (``cfg = Config()``); ``device``
    is also forwarded to ``train_model``.
    """

    root_path: str = 'Data'   # relative path, resolved against the kernel's working directory
    image_size: int = 64      # presumably the square side length in pixels; confirm in data_processing
    batch_size: int = 128
    device: str = "cuda"      # hard-coded; a CPU-only machine will need this changed

# Hyperparameters forwarded to unet.train_model.
lr = 0.0001       # learning rate
wd = 0.001        # passed as `weight_decay`
timesteps = 500   # passed as `timesteps` (presumably the diffusion schedule length; confirm in unet)
epochs = 10

cfg = Config()
# Only `make_gpu_dataloader` is imported from data_processing. The previous
# call to the undefined `make_dataloader` raised NameError on a fresh kernel.
# Assumes the factory takes the Config instance; TODO confirm its signature.
loader_cats = make_gpu_dataloader(cfg)
model, loss = train_model(
    loader_cats,
    epochs=epochs,
    lr=lr,
    weight_decay=wd,
    timesteps=timesteps,
    device=cfg.device,
)

In [None]:
# Visualize the trained model's denoising via unet_utils.visualize_denoising
# (presumably plots intermediate reverse-diffusion steps; confirm in unet_utils).
visualize_denoising(model)