In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset archive, checkpoints and
# result images used by later cells all live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is also the default device of GaussianDiffusion defined further down.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

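Extract the CelebA-HQ 256 dataset from the zip archive stored on Google Drive. (The CelebA eval/partition txt file is only needed by `CelebADataset`; this cell does not fetch it.)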

In [None]:
import os
import requests
import zipfile
import urllib.request

# Folder on the mounted Google Drive that holds the dataset archive.
drive_dir = "/content/drive/MyDrive/"
# Local Colab storage the archive is unpacked into.
zip_dir = "/content/"
os.makedirs(zip_dir, exist_ok=True)

zip_path = os.path.join(drive_dir, "celeba-hq-256.zip")
extracted_path = os.path.join(zip_dir, "celeba-hq-256")
os.makedirs(extracted_path, exist_ok=True)

# Later cells read the images from the extracted copy, not from Drive.
data_dir = extracted_path

# The marker is written only after extractall() returns, so an interrupted
# extraction is retried while a completed one is skipped on re-run.
extraction_marker = os.path.join(extracted_path, ".extraction_complete")

if not os.path.exists(extraction_marker):
    print("Extracting files...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extracted_path)
    with open(extraction_marker, 'w'):
        pass
    print("Extraction complete!")
else:
    print("Dataset already downloaded and extracted!")

print(f"Images are in: {extracted_path}")


In [None]:
import os
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class CelebADataset(Dataset):
    """Image dataset selected through a partition file.

    Each non-empty line of ``partition_file`` is ``"<image_name> <partition>"``
    where partition is 0 (train), 1 (val) or 2 (test). Only the images whose
    partition matches ``split`` are kept, in file order.

    Args:
        data_dir: Directory that contains the image files.
        partition_file: Path of the text file described above.
        split: ``'train'``, ``'val'`` or ``'test'``. Any other value matches
            no partition and yields an empty dataset.
        transform: Optional callable applied to each opened image.
    """

    # Partition code stored in the file for each split name.
    _SPLIT_TO_PARTITION = {'train': 0, 'val': 1, 'test': 2}

    def __init__(self, data_dir, partition_file, split='train', transform=None):
        self.data_dir = data_dir
        self.transform = transform

        wanted_partition = self._SPLIT_TO_PARTITION.get(split)

        self.image_labels = []
        with open(partition_file, 'r') as f:
            for line in f:
                line = line.strip()
                # Blank lines (e.g. a trailing newline) carry no entry; skip
                # them instead of failing the two-value unpack below.
                if not line:
                    continue
                image_name, partition = line.split()
                if int(partition) == wanted_partition:
                    self.image_labels.append(image_name)

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.image_labels)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed if a transform is set."""
        img_name = os.path.join(self.data_dir, self.image_labels[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image

class CelebAHQ256(Dataset):
    """Dataset over the numbered ``.jpg`` files found directly in ``data_dir``.

    A file is kept when the part of its name before the first ``.`` parses as
    an integer in ``0..29999``. Files are ordered by that number.

    Args:
        data_dir: Directory that contains the image files.
        transform: Optional callable applied to each opened image.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        numbered_files = []
        for file in os.listdir(data_dir):
            if not file.endswith('.jpg'):
                continue
            try:
                file_num = int(file.split('.')[0])
            except ValueError:
                # A .jpg whose stem is not a number is not part of the
                # numbered set; ignore it instead of aborting.
                continue
            if 0 <= file_num <= 29999:
                numbered_files.append((file_num, file))

        # Sort by image number (00000 ~ 29999). For zero-padded names this is
        # the same as sorting by name, but it stays correct without padding.
        numbered_files.sort()
        self.img_list = [file for _, file in numbered_files]

    def __len__(self):
        """Number of images kept."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed if a transform is set."""
        img_name = os.path.join(self.data_dir, self.img_list[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image

# transform 정의
def _scale_to_signed_range(t):
    """Affine map t -> 2t - 1 (sends the range [0, 1] to [-1, 1])."""
    return (t * 2) - 1

# Preprocessing applied to every image: convert to a tensor, then rescale.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_scale_to_signed_range),
])

# Build the training dataset from the extracted image folder.
train_dataset = CelebAHQ256(data_dir, transform=transform)

In [None]:
import torch, math
from torch import nn
import torch.nn.functional as F
from einops import *
from einops.layers.torch import Rearrange
from functools import partial

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    For ``time`` of shape (N,), ``forward`` returns a tensor of shape
    (N, 2 * (time_dim // 2)): the sines of ``time * freq`` followed by the
    cosines, where the ``time_dim // 2`` frequencies fall geometrically from
    1 down to ``1 / theta``.
    """

    def __init__(self, time_dim, theta=10000):
        super().__init__()
        self.theta = theta
        self.time_dim = time_dim

    def forward(self, time):
        n_freqs = self.time_dim // 2
        # Log-spacing between consecutive frequencies.
        log_step = math.log(self.theta) / (n_freqs - 1)
        positions = torch.arange(n_freqs, device=time.device)
        freqs = torch.exp(positions * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class RMSNorm(nn.Module):
    """
    Channel-wise RMS-style normalisation for (N, C, H, W) tensors.
    Modified version of https://arxiv.org/abs/1910.07467

    Each spatial position's channel vector is scaled to unit L2 norm, then
    multiplied by the learnable per-channel ``gamma`` and by sqrt(dim).
    """

    def __init__(self, dim):
        super().__init__()
        self.sqrt_dim = math.sqrt(dim)
        # One learnable gain per channel, broadcast over batch and space.
        self.gamma = nn.Parameter(torch.ones((1, dim, 1, 1)))

    def forward(self, x):
        unit_x = F.normalize(x, dim=1)
        return unit_x * self.gamma * self.sqrt_dim

class ResBlock(nn.Module):
    """3x3 convolution -> RMSNorm -> optional scale/shift -> SiLU.

    Args:
        dim_in: Input channels.
        dim_out: Output channels.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1)
        self.norm = RMSNorm(dim_out)
        self.act = nn.SiLU()

    def forward(self, x, time_cond=None):
        """Run the block; ``time_cond`` is an optional ``(scale, shift)`` pair."""
        h = self.norm(self.conv(x))

        if time_cond is not None:
            scale, shift = time_cond
            # Modulate the normalised features with the time conditioning.
            h = h * (scale + 1) + shift

        return self.act(h)

class Resnet(nn.Module):
    """Two ``ResBlock``s with a 1x1-projected residual connection.

    The time embedding is projected to ``2 * dim_out`` values that are split
    into a (scale, shift) pair and applied in the first block only.

    Args:
        dim_in: Input channels.
        dim_out: Output channels.
        time_dim: Size of the incoming time embedding.
    """

    def __init__(self, dim_in, dim_out, time_dim):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, dim_out * 2),
        )

        self.block1 = ResBlock(dim_in, dim_out)
        self.block2 = ResBlock(dim_out, dim_out)
        # 1x1 convolution so the skip path matches the output channel count.
        self.proj_conv = nn.Conv2d(dim_in, dim_out, 1)

    def forward(self, x, time_emb):
        cond = rearrange(self.time_mlp(time_emb), 'n c -> n c 1 1')
        scale_shift = cond.chunk(2, dim=1)

        h = self.block1(x, scale_shift)
        h = self.block2(h)

        # Residual connection
        return h + self.proj_conv(x)

class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The input (N, C, H, W) is normalised, projected by a 1x1 convolution to
    queries, keys and values of ``heads * channels_per_head`` channels each,
    attended over the H*W positions, and projected back to ``dim`` channels.
    The module returns only the attended result; callers add any residual.

    Args:
        dim: Input and output channels.
        heads: Number of attention heads.
        channels_per_head: Channels of q/k/v per head.
    """

    def __init__(self, dim, heads=4, channels_per_head=32):
        super().__init__()
        # NOTE(review): scores are scaled by dim ** -0.5 (the input channel
        # count) rather than channels_per_head ** -0.5. Left unchanged because
        # altering it changes the behaviour of existing checkpoints -- confirm
        # whether this is intended.
        self.scale = dim ** -0.5
        self.heads = heads
        self.channels_per_head = channels_per_head

        hidden = heads * channels_per_head
        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden * 3, 1, bias=False)
        self.conv = nn.Conv2d(hidden, dim, 1)

    def forward(self, x):
        _, _, H, W = x.shape

        normed = self.norm(x)

        # Split into q, k, v and flatten space: (b, heads, channels, H*W).
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(normed).chunk(3, dim=1)
        )

        # Similarity of every query position i with every key position j.
        weights = torch.einsum("b h d i, b h d j -> b h i j", q, k) * self.scale
        weights = weights.softmax(dim=-1)

        mixed = torch.einsum("b h i j, b h d j -> b h i d", weights, v)
        mixed = rearrange(mixed, 'b h (x y) d -> b (h d) x y', x=H, y=W)
        return self.conv(mixed)

class Upsample(nn.Module):
    """Double the spatial resolution and map ``dim_in`` to ``dim_out`` channels.

    Args:
        dim_in: Input channels.
        dim_out: Output channels.
        use_deconv: If true, use one stride-2 transposed convolution;
            otherwise nearest-neighbour upsampling followed by a 3x3 conv.
    """

    def __init__(self, dim_in, dim_out, use_deconv=False):
        super().__init__()

        if use_deconv:
            upsampler = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1)
        else:
            upsampler = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
            )

        self.up = upsampler

    def forward(self, x):
        return self.up(x)

class Downsample(nn.Module):
    """Halve the spatial resolution and map ``dim_in`` to ``dim_out`` channels.

    Every 2x2 spatial patch is folded into the channel axis (4x the channels),
    then a 1x1 convolution mixes them down to ``dim_out``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        space_to_depth = Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)
        channel_mix = nn.Conv2d(dim_in * 4, dim_out, 1)
        self.down = nn.Sequential(space_to_depth, channel_mix)

    def forward(self, x):
        return self.down(x)

class Unet(nn.Module):
    """U-Net that maps an image batch and a timestep batch to a tensor with
    ``in_channel`` channels (used by the diffusion code as the noise predictor).

    Args:
        in_channel: Channels of the input and of the output.
        dim1: Channels produced by the first convolution; level widths are
            ``dim1 * m`` for each ``m`` in ``dim_mults``.
        dim_mults: Width multipliers. Each consecutive pair defines one down
            level and, in reverse order, one up level.
        emb_theta: ``theta`` passed to ``SinusoidalPosEmb``.
        use_deconv: Forwarded to ``Upsample``.
        attn_heads: Heads of every ``Attention`` block.
        attn_channels_per_head: Channels per head of every ``Attention`` block.
        attn_num: ``None`` puts attention in every level; otherwise attention
            is used in the last ``attn_num`` down levels and the first
            ``attn_num`` up levels.
    """
    def __init__(self,
                 in_channel=3,
                 dim1=64, #first channel
                 dim_mults=(1, 2, 4, 8, 16), #n entries -> n-1 levels, of which n-2 down-/up-sample (the last level of each path uses a plain conv)
                 emb_theta=10000,
                 use_deconv=False,
                 attn_heads=4,
                 attn_channels_per_head=32,
                 attn_num=None):
        super(Unet, self).__init__()

        # Dimensions
        dim_multed = list(map(lambda t: t * dim1, dim_mults))
        mid_dim = dim_multed[-1]

        # Time embeddings
        time_dim = dim1 * 4
        time_encoder = SinusoidalPosEmb(dim1, theta=emb_theta)
        self.time_mlp = nn.Sequential(
            time_encoder,
            nn.Linear(dim1, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # NeuralNets
        ResNet = partial(Resnet, time_dim=time_dim)
        Attn = partial(Attention, heads=attn_heads, channels_per_head=attn_channels_per_head)
        self.front_conv = nn.Conv2d(in_channel, dim1, kernel_size=7, padding=3)

        self.downsamples = nn.ModuleList([])
        self.upsamples = nn.ModuleList([])

        # Set attention block number
        dim_length = len(dim_multed[1:])

        if attn_num is None:
            # Attention in every level of both paths.
            attn_down_idx = 0
            attn_up_idx = dim_length-1
        else:
            # Attention only in the attn_num levels closest to the bottleneck.
            attn_down_idx = dim_length - attn_num
            attn_up_idx = attn_num - 1

        # Downsample
        for dim_in, dim_out, idx in zip(dim_multed[:-1], dim_multed[1:], range(0,dim_length)):
            last = (dim_length-1 == idx)

            # Per level: two Resnets at dim_in, optional attention, then either
            # a Downsample or (last level) a 3x3 conv that keeps the resolution.
            self.downsamples.append(nn.ModuleList([
                ResNet(dim_in, dim_in),
                ResNet(dim_in, dim_in),
                Attn(dim_in) if idx >= attn_down_idx else nn.Identity(),
                Downsample(dim_in, dim_out) if not last else nn.Conv2d(dim_in, dim_out, kernel_size=3, padding=1)
            ]))

        # Upsample
        for dim_in, dim_out, idx in zip(*map(reversed, (dim_multed[1:], dim_multed[:-1], )), range(0,dim_length)):
            last = (dim_length-1 == idx)

            # Each Resnet takes dim_in + dim_out channels because a skip tensor
            # with dim_out channels is concatenated in front of it in forward().
            self.upsamples.append(nn.ModuleList([
                ResNet(dim_in + dim_out, dim_in),
                ResNet(dim_in + dim_out, dim_in),
                Attn(dim_in) if idx <= attn_up_idx else nn.Identity(),
                Upsample(dim_in, dim_out, use_deconv) if not last else nn.Conv2d(dim_in, dim_out, kernel_size=3, padding=1)
            ]))

        self.mid_resnet1 = ResNet(mid_dim, mid_dim)
        self.mid_attn1 = Attn(mid_dim)
        self.mid_resnet2 = ResNet(mid_dim, mid_dim)

        # Takes dim1*2 channels: the up-path output concatenated with the
        # front_conv output kept in forward().
        self.end_resnet1 = nn.Conv2d(dim1*2, dim1, 1)
        self.end_conv = nn.Sequential(nn.Conv2d(dim1, 64, 1), nn.Conv2d(64, in_channel, 1))

    def forward(self, x, time):
        """Return the network output for images ``x`` at timesteps ``time``.

        ``time`` is a 1-D batch (it is indexed as ``time[:, None]`` inside the
        time embedding).
        """
        t = self.time_mlp(time)
        x = self.front_conv(x)
        # Copy of the stem features, concatenated back in just before the head.
        r = x.clone()

        # Stack of skip tensors: pushed on the way down, popped on the way up.
        skip_con_buf = []

        for net1, net2, attn, down in self.downsamples:
            # Block1
            x = net1(x, t)
            skip_con_buf.append(x)

            # Block2
            x = net2(x, t)
            if not isinstance(attn, nn.Identity): x = attn(x) + x # Attention returns only the attended update, so add the input back (residual).
            skip_con_buf.append(x)

            # Downsampling
            x = down(x)

        x = self.mid_resnet1(x, t)
        # NOTE(review): unlike the per-level attention above, the bottleneck
        # attention is applied without adding x back -- confirm this is intended.
        x = self.mid_attn1(x)
        x = self.mid_resnet2(x, t)

        for net1, net2, attn, up in self.upsamples:
            # Block1
            x = torch.cat((x, skip_con_buf.pop()), dim=1)
            x = net1(x, t)

            # Block2
            x = torch.cat((x, skip_con_buf.pop()), dim=1)
            x = net2(x, t)
            if not isinstance(attn, nn.Identity): x = attn(x) + x

            # Upsampling
            x = up(x)

        x = torch.cat((x, r), dim=1)
        x = self.end_resnet1(x)
        x = self.end_conv(x)
        return x

In [None]:
import PIL, torch, PIL.Image
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as np

def _unit_to_signed(t):
    """Affine map t -> 2t - 1 (sends the range [0, 1] to [-1, 1])."""
    return (t*2)-1

def _signed_to_unit(t):
    """Affine map t -> (t + 1) / 2 (sends the range [-1, 1] to [0, 1])."""
    return (t+1)/2

# PIL image -> tensor, rescaled with _unit_to_signed.
img_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_unit_to_signed),
])

# Inverse direction: rescale with _signed_to_unit, then tensor -> PIL image.
tensor_to_img = transforms.Compose([
    transforms.Lambda(_signed_to_unit),
    transforms.ToPILImage(),
])

class ImageManager:
    """Save and load images and loss plots under a single directory.

    Args:
        path: Output directory; created if it does not exist.
        format: File extension used by ``get_image`` when reading.
        use_clamp: If true, tensors are clamped to [-1, 1] before being
            converted to images; otherwise each image is min-max rescaled
            to [-1, 1].
        transform_i_to_t: Callable converting a PIL image to a tensor.
        transform_t_to_i: Callable converting a tensor to a PIL image.
    """

    def __init__(self, path, format='png', use_clamp=True, transform_i_to_t=img_to_tensor, transform_t_to_i=tensor_to_img):
        self.path = path
        self.format = format
        self.use_clamp = use_clamp
        self.transform_i_to_t = transform_i_to_t
        self.transform_t_to_i = transform_t_to_i

        if use_clamp:
            self.transform_adjust_range = transforms.Lambda(lambda x: torch.clamp(x, -1.0, 1.0))
        else:
            self.transform_adjust_range = transforms.Lambda(self._rescale_to_signed_range)

        # Make dir
        os.makedirs(self.path, exist_ok=True)

    @staticmethod
    def _rescale_to_signed_range(x):
        """Min-max rescale each image (last three dims) to [-1, 1]."""
        lo = x.amin(dim=(-3, -2, -1), keepdim=True)
        hi = x.amax(dim=(-3, -2, -1), keepdim=True)
        # A constant image has hi == lo; bound the divisor away from zero so
        # the result is finite instead of NaN.
        span = torch.clamp(hi - lo, min=1e-8)
        return ((x - lo) / span - 0.5) * 2

    def save_plot(self, loss_hist, name, scatter=True):
        """Plot ``loss_hist`` against its index and save it as ``<name>.png``."""
        fig, ax = plt.subplots()
        x = range(0, len(loss_hist))
        if scatter:
            ax.scatter(x, loss_hist)
        else:
            ax.plot(x, loss_hist)
        fig.savefig(f'{self.path}/{name}.png')
        # Close this figure explicitly so repeated calls do not accumulate figures.
        plt.close(fig)

    def get_image(self, name):
        """Load ``<name>.<format>`` from the output directory as a tensor."""
        img = PIL.Image.open(f'{self.path}/{name}.{self.format}')
        img_tensor = self.transform_i_to_t(img)
        return img_tensor

    def save_img(self, tensors, name):
        """Save one image (3-D tensor) or a batch (4-D tensor) as PNG files.

        A batch is written as ``<name>-<idx>.png`` per image; a single image
        is written as ``<name>.png``.
        """
        tensors = self.transform_adjust_range(tensors)

        if tensors.dim() == 4:
            for idx, tensor in enumerate(tensors):
                img = self.transform_t_to_i(tensor)
                img.save(f'{self.path}/{name}-{idx}.png')
        else:
            img = self.transform_t_to_i(tensors)
            img.save(f'{self.path}/{name}.png')

    def plot_images_with_spacing(self, tensors, name):
        """Save a grid of images as ``<name>.png``.

        ``tensors`` has shape (N, B, C, H, W); the grid has B rows and N
        columns, with image ``tensors[n, b]`` at row b, column n.
        """
        N, B, C, H, W = tensors.shape

        # squeeze=False always returns a 2-D array of axes, so the B == 1
        # and/or N == 1 cases need no special handling.
        fig, axes = plt.subplots(B, N, figsize=(N * 3, 3 * B), squeeze=False)

        for b_idx in range(B):
            for n_idx in range(N):
                # Range adjustment works per image on the last three dims.
                tensor = self.transform_adjust_range(tensors[n_idx, b_idx])
                img = self.transform_t_to_i(tensor)
                ax = axes[b_idx][n_idx]
                ax.imshow(np.array(img))
                ax.axis('off')

        fig.tight_layout()
        fig.savefig(f'{self.path}/{name}.png')
        plt.close(fig)


In [None]:
import math
from torch.optim.lr_scheduler import _LRScheduler

"""
https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py
"""

class CosineAnnealingWarmupRestarts(_LRScheduler):
    """
    Learning-rate schedule made of repeated cycles: a linear warmup from
    min_lr to max_lr followed by a cosine decay back to min_lr.

        optimizer (Optimizer): Wrapped optimizer.
        first_cycle_steps (int): First cycle step size.
        cycle_mult(float): Cycle steps magnification. Default: 1.
        max_lr(float): First cycle's max learning rate. Default: 0.1.
        min_lr(float): Min learning rate. Default: 0.001.
        warmup_steps(int): Linear warmup step size. Default: 0.
        gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self,
                 optimizer : torch.optim.Optimizer,
                 first_cycle_steps : int,
                 cycle_mult : float = 1.,
                 max_lr : float = 0.1,
                 min_lr : float = 0.001,
                 warmup_steps : int = 0,
                 gamma : float = 1.,
                 last_epoch : int = -1
        ):
        assert warmup_steps < first_cycle_steps

        self.first_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle_mult = cycle_mult # cycle steps magnification
        self.base_max_lr = max_lr # first max learning rate
        self.max_lr = max_lr # max learning rate in the current cycle
        self.min_lr = min_lr # min learning rate
        self.warmup_steps = warmup_steps # warmup step size
        self.gamma = gamma # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle = 0 # cycle count
        self.step_in_cycle = last_epoch # step size of the current cycle

        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)

        # set learning rate min_lr
        self.init_lr()

    def init_lr(self):
        """Set every param group's lr, and ``base_lrs``, to ``min_lr``."""
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self):
        """Return the learning rate of each param group for the current step."""
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # Linear warmup from base_lr (min_lr) towards max_lr.
            return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            # Cosine decay from max_lr back to base_lr over the rest of the cycle.
            return [base_lr + (self.max_lr - base_lr) \
                    * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance by one step, or jump to ``epoch`` when it is given, and
        write the resulting learning rates into the optimizer."""
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Current cycle finished: start the next one and resize it
                # (only the non-warmup part is multiplied by cycle_mult).
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            # Explicit epoch: recompute cycle index and position from scratch.
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # Cycle lengths form a geometric series; n is the number of
                    # complete cycles that fit before `epoch`.
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        # Peak lr shrinks by a factor of gamma for every completed cycle.
        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

In [None]:
import torch, time, math, os
from torch.optim import Adam
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

def linear_beta_scheduler(beta_start, beta_end, timesteps):
    """Return ``timesteps`` values evenly spaced from ``beta_start`` to ``beta_end`` (inclusive)."""
    betas = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return betas

def extract(a, t, x_shape):
    """Pick ``a[t]`` for every index in the batch ``t`` and reshape the result
    to ``(len(t), 1, ..., 1)`` with as many dimensions as ``x_shape``, so it
    broadcasts against a tensor of that shape."""
    batch, *_ = t.shape
    trailing_ones = (1,) * (len(x_shape) - 1)
    picked = a.gather(-1, t)
    return picked.reshape(batch, *trailing_ones)

class GaussianDiffusion(nn.Module):
    """Forward (noising) and reverse (sampling) processes around a
    noise-prediction model, with a linear beta schedule.

    Args:
        model: Network called as ``model(x, t)`` that returns the predicted noise.
        beta_start: First value of the beta schedule.
        beta_end: Last value of the beta schedule.
        timesteps: Number of diffusion steps T.
        device: Device the schedule buffers and the model are moved to.
        save_per_p_steps: ``p_sample`` keeps an intermediate image every this
            many steps.
    """

    def __init__(self,
                 model,
                 beta_start=0.0001,
                 beta_end=0.02,
                 timesteps=1000,
                 device=device,
                 save_per_p_steps=100
                 ):
        super(GaussianDiffusion, self).__init__()
        self.device = device
        self.timesteps = timesteps
        self.save_per_p_steps = save_per_p_steps

        # Basic components
        betas = linear_beta_scheduler(beta_start, beta_end, timesteps).to(self.device)
        alphas = 1. - betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # alphas_bar shifted right by one step, with alpha_bar_{-1} := 1.
        alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value = 1.)
        sqrt_one_minus_alphas_bar = torch.sqrt(1. - alphas_bar)

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) #buffer for coefficients

        # Register basic components in buffer
        register_buffer('betas', betas)
        register_buffer('alphas_bar', alphas_bar)
        register_buffer('alphas_bar_prev', alphas_bar_prev)

        # Register coefficients components in buffer
        register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))
        register_buffer('sqrt_one_minus_alphas_bar', sqrt_one_minus_alphas_bar)
        register_buffer('one_minus_alphas', 1. - alphas)
        register_buffer('sqrt_alphas_recip', torch.sqrt(1/alphas))
        register_buffer('coef_pred_noise', betas / sqrt_one_minus_alphas_bar)

        posterior_variance = betas * (1. - alphas_bar_prev) / (1. - alphas_bar)

        register_buffer('sqrt_posterior_variance', torch.sqrt(posterior_variance))

        self.model = model.to(device)

    def q_sample(self, x, time, noise=None):
        """
        q(x_t | x_0) : forward process

        Returns ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x: Clean images x_0; the first dimension is the batch.
            time: One integer timestep index per batch element.
            noise: Noise with the shape of ``x``; drawn from a standard
                normal when omitted.
        """
        x_shape = x.shape
        x = x.to(self.device)
        time = time.to(self.device)

        assert x_shape[0] == time.shape[0], 'mismatch of tensor shape'

        if noise is None:
            noise = torch.randn_like(x, device=self.device)

        # The signal coefficient is sqrt(alpha_bar_t), not alpha_bar_t itself;
        # this is the parametrisation the update rule in p_sample assumes.
        sqrt_alphas_bar_t = extract(self.sqrt_alphas_bar, time, x_shape)
        sqrt_one_minus_alphas_bar_t = extract(self.sqrt_one_minus_alphas_bar, time, x_shape)

        noisy_x = sqrt_alphas_bar_t * x + sqrt_one_minus_alphas_bar_t * noise

        return noisy_x

    @torch.inference_mode()
    def p_sample(self, x):
        """
        p(x_0 | x_T) : denoise process

        Starting from ``x`` (treated as x_T), runs all ``timesteps`` reverse
        steps. Returns a stacked tensor holding the state before every step
        ``i`` with ``(i + 1) % save_per_p_steps == 0`` followed by the final
        result, i.e. shape (num_saved + 1, N, C, H, W).
        """
        N = x.shape[0]
        xs = []

        for i in range(self.timesteps-1, -1, -1):
            t = torch.ones((N,), dtype=torch.int64, device=self.device) * i
            if (i+1) % self.save_per_p_steps == 0: xs.append(x)

            sqrt_alphas_recip_t = extract(self.sqrt_alphas_recip, t, x.shape)
            one_minus_alphas_t = extract(self.one_minus_alphas, t, x.shape)
            sqrt_one_minus_alphas_bar_t = extract(self.sqrt_one_minus_alphas_bar, t, x.shape)
            sqrt_posterior_variance_t = extract(self.sqrt_posterior_variance, t, x.shape)

            pred_noise = self.model(x, t)
            noise = torch.randn_like(x, device=self.device)

            # Posterior mean computed from the predicted noise.
            x_mean = sqrt_alphas_recip_t * (x - pred_noise * one_minus_alphas_t / sqrt_one_minus_alphas_bar_t)
            # Stochastic term; the posterior variance is 0 at i == 0 because
            # alphas_bar_prev[0] == 1, so the final step adds no noise.
            x_var = sqrt_posterior_variance_t * noise

            x = x_mean + x_var

        xs.append(x)

        xs = torch.stack(xs)

        return xs

class Trainer:
    """Training loop with gradient accumulation, checkpointing, periodic
    sampling and loss-history plots for a ``GaussianDiffusion`` model.

    Args:
        diffusion_model: Object exposing ``timesteps``, ``alphas_bar``,
            ``model``, ``q_sample`` and ``p_sample``.
        dataset: Dataset yielding image tensors.
        save_path: Directory of the checkpoint file.
        save_name: Checkpoint file name without the ``.pt`` extension.
        lr_scheduler: ``None`` or a callable ``optimizer -> scheduler``. The
            scheduler itself is built later by ``load`` or
            ``ensure_scheduler_ready``.
        batch_size, shuffle: Passed to the ``DataLoader``.
        ep: Number of epochs run by ``train``.
        lr, betas: Adam hyper-parameters.
        accumulation_steps: Mini-batches accumulated per optimizer step.
        device: Device on which timesteps and noise are created.
        sample_num: Images generated at each sampling point (0 disables it).
        save_interval_steps: Checkpoint every this many iterations of an epoch.
        track_hist: Record and plot the loss history.
        use_p2: Use ``p2_weighting_loss`` instead of plain ``loss_func``.
        p2_gamma: Exponent of the P2 weight.
        loss_func: Element-wise loss; must accept ``reduction='none'`` when
            ``use_p2`` is true.
        imagemanager: Object used to save samples and plots. When ``None``,
            nothing is written to images.
    """

    def __init__(self,
                 diffusion_model,
                 dataset,
                 save_path,
                 save_name,
                 lr_scheduler=None,  # callable (optimizer -> scheduler) or None
                 batch_size=16,
                 shuffle=True,
                 ep=10,
                 lr=1e-4,
                 betas=(0.9, 0.999),
                 accumulation_steps=1,
                 device='cuda',
                 sample_num=1,
                 save_interval_steps=500,
                 track_hist=False,
                 use_p2=True,
                 p2_gamma=0.5,
                 loss_func=F.mse_loss,
                 imagemanager=None):

        self.device = device
        self.timesteps = diffusion_model.timesteps
        self.accumulation_steps = accumulation_steps
        self.save_interval_steps = save_interval_steps

        # Basic bookkeeping
        self.grad_update = 0
        self.total_train_time_prev = 0
        self.total_train_time = 0
        self.save_name = save_name
        self.save_path = save_path
        self.imagemanager = imagemanager
        self.ep = ep
        self.dataloader = DataLoader(dataset, batch_size, shuffle=shuffle)

        # Model and optimizer
        self.diffusion_model = diffusion_model
        self.optim = Adam(self.diffusion_model.model.parameters(), lr=lr, betas=betas)

        # Only the scheduler factory is stored here
        self.lr_scheduler_fn = lr_scheduler
        self.lr_scheduler = None  # the actual scheduler is created in load() or ensure_scheduler_ready()

        self.use_p2 = use_p2
        self.p2_gamma = p2_gamma
        self.loss_func = loss_func
        self.loss = self.p2_weighting_loss if use_p2 else loss_func

        self.sample_num = sample_num
        self.track_hist = track_hist
        self.hist = []
        self.hist_mean = []

    def save(self):
        """Write model, optimizer, counters and (if any) scheduler state to
        ``<save_path>/<save_name>.pt``."""
        save_dict = {
            'model_state_dict': self.diffusion_model.model.state_dict(),
            'optim_state_dict': self.optim.state_dict(),
            'grad_update': self.grad_update,
            'total_train_time': self.total_train_time + self.total_train_time_prev
        }
        if self.lr_scheduler is not None:
            save_dict['lr_scheduler_state_dict'] = self.lr_scheduler.state_dict()

        # torch.save does not create missing directories.
        os.makedirs(self.save_path, exist_ok=True)
        torch.save(save_dict, f'{self.save_path}/{self.save_name}.pt')

    def load(self, current_step_for_scheduler=None):
        """Restore the state written by ``save`` and prepare the scheduler.

        Args:
            current_step_for_scheduler: Step to move a freshly created
                scheduler to when the checkpoint holds no scheduler state.
        """
        checkpoint = torch.load(f'{self.save_path}/{self.save_name}.pt', map_location=self.device)

        self.diffusion_model.model.load_state_dict(checkpoint['model_state_dict'])
        self.optim.load_state_dict(checkpoint['optim_state_dict'])
        self.grad_update = checkpoint['grad_update']
        self.total_train_time_prev = checkpoint['total_train_time']

        # Create the scheduler from its factory if that has not happened yet
        if self.lr_scheduler_fn is not None and self.lr_scheduler is None:
            self.lr_scheduler = self.lr_scheduler_fn(self.optim)

        # Restore the scheduler state when the checkpoint contains one
        if self.lr_scheduler is not None and 'lr_scheduler_state_dict' in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

        # Otherwise at least align the scheduler's step position (warmup start)
        elif self.lr_scheduler is not None and current_step_for_scheduler is not None:
            self.lr_scheduler.step(current_step_for_scheduler)

    def ensure_scheduler_ready(self, current_step_for_scheduler=None):
        """Create the scheduler when training starts without ``load``."""
        # Initialise the scheduler if it has not been created yet
        if self.lr_scheduler is None and self.lr_scheduler_fn is not None:
            self.lr_scheduler = self.lr_scheduler_fn(self.optim)
            if current_step_for_scheduler is not None:
                self.lr_scheduler.step(current_step_for_scheduler)

    def p2_weighting_loss(self, pred_noise, noise, time):
        """Mean over the batch of ``(1 - alpha_bar_t) ** p2_gamma * loss_i``,
        where ``loss_i`` is the per-sample mean of ``loss_func``.

        Args:
            pred_noise: Predicted noise, shape (N, C, H, W).
            noise: Target noise, same shape.
            time: Integer timestep index per sample, shape (N,).
        """
        # One alpha_bar per sample, shape (N,). It must stay 1-D: an
        # (N, 1, 1, 1) weight multiplied by the (N,) per-sample loss would
        # broadcast to (N, 1, 1, N) and average every weight with every loss.
        alpha_bar_t = self.diffusion_model.alphas_bar.gather(-1, time)
        p2_weight = (1 - alpha_bar_t) ** self.p2_gamma
        per_sample_loss = self.loss_func(pred_noise, noise, reduction='none').mean(dim=[1, 2, 3])
        return (p2_weight * per_sample_loss).mean()

    def train(self):
        """Run ``ep`` epochs over the dataloader.

        Every ``accumulation_steps`` iterations the optimizer (and scheduler,
        if any) steps. Every ``save_interval_steps`` iterations a checkpoint
        is written; on every third such point samples are generated, and the
        loss history is plotted when ``track_hist`` is set.
        """
        max_itera = len(self.dataloader)
        start_time = time.time()

        for ep in range(0, self.ep):
            # Passing total lets tqdm show progress and a time estimate.
            process_bar = tqdm(enumerate(self.dataloader), total=max_itera,
                               desc=f"Training (Epoch {ep+1}/{self.ep})", ncols=1000)
            self.optim.zero_grad()
            for itera, data in process_bar:
                N, C, H, W = data.shape
                x = data

                # t, epsilon
                random_t = torch.randint(0, self.timesteps, (N,), device=self.device)
                noise = torch.randn_like(x, device=self.device)

                # q(x_t | x_0)
                noisy_x = self.diffusion_model.q_sample(x=x, time=random_t, noise=noise)

                # Predict Noise
                pred_noise = self.diffusion_model.model(noisy_x, random_t)

                # Loss, scaled so accumulated gradients average over the micro-batches
                if self.use_p2:
                    loss = self.loss(pred_noise, noise, random_t) / self.accumulation_steps
                else:
                    loss = self.loss(pred_noise, noise) / self.accumulation_steps

                # Process description
                process_bar.set_postfix(
                    loss=str(float(loss))[:6],
                    lr=f"{self.optim.param_groups[0]['lr']:.1e}",
                    grad_update=self.grad_update
                )

                # Train
                loss.backward()
                if (itera+1) % self.accumulation_steps == 0:

                    # Track hist
                    if self.track_hist: self.hist.append(float(loss))

                    # Optimize
                    self.grad_update += 1
                    self.optim.step()
                    if self.lr_scheduler is not None: self.lr_scheduler.step()
                    self.optim.zero_grad()

                if (itera+1) % self.save_interval_steps == 0:
                    # Save model
                    self.total_train_time = time.time() - start_time
                    self.save()

                    # Without an image manager there is nowhere to write
                    # samples or plots, so both are skipped.
                    can_write_images = self.imagemanager is not None

                    if can_write_images and self.sample_num != 0 and (itera+1) % (self.save_interval_steps*3) == 0:
                        # Sampling : p(x_0 | x_T)
                        gaussian_noise = torch.randn((self.sample_num, C, H, W), device=self.device)
                        generated_samples = self.diffusion_model.p_sample(gaussian_noise)

                        # Save generated images
                        self.imagemanager.plot_images_with_spacing(generated_samples, f'p_sample_{ep}_{itera+1}')

                    # Track hist (self.hist is empty until the first optimizer step)
                    if self.track_hist and self.hist:
                        loss_mean = sum(self.hist) / len(self.hist)
                        self.hist_mean.append(loss_mean)

                        # Save hist
                        if can_write_images:
                            self.imagemanager.save_plot(self.hist_mean, 'mean_hist')
                            self.imagemanager.save_plot(self.hist, 'hist')


In [None]:
from diffusers.optimization import get_cosine_schedule_with_warmup
unet = Unet(in_channel=3, dim1=64, dim_mults=(1,2,2,4,4,6,6), attn_num=3)#1,2,2,4,4,8,8
diffusionModel = GaussianDiffusion(model=unet)
imageManager = ImageManager(path='/content/drive/MyDrive/results', use_clamp=False)

lr = 1.e-4

def make_warm_restart_scheduler(optim):
    """Alternative schedule: cosine annealing with warmup and restarts."""
    return CosineAnnealingWarmupRestarts(
        optimizer=optim,
        first_cycle_steps=500,
        cycle_mult=2.0,
        max_lr=lr,
        min_lr=1e-6,
        warmup_steps=100,
        gamma=0.9
    )

def make_cosine_warmup_scheduler(optim):
    """Single cosine decay with linear warmup (diffusers helper)."""
    return get_cosine_schedule_with_warmup(
        optimizer=optim,
        num_training_steps=500_000,
        num_warmup_steps=2_000
    )

# Scheduler factory actually handed to the Trainer. Previously both factories
# were assigned to `lr_scheduler`, so the first one was silently overwritten.
lr_scheduler = make_cosine_warmup_scheduler

trainer = Trainer(diffusion_model=diffusionModel,
                  dataset=train_dataset,
                  save_path='/content/drive/MyDrive/hist',
                  save_name='Diffusion_celeba-hq_1000T_cosine',
                  batch_size=4,
                  lr_scheduler=lr_scheduler,
                  lr=lr,
                  loss_func=F.mse_loss,
                  accumulation_steps=4,
                  use_p2=False,
                  save_interval_steps=500,
                  sample_num=1,
                  ep=100,
                  imagemanager=imageManager,
                  track_hist=True)

trainer.load() # Pass current_step_for_scheduler only on the first load after adding a scheduler.
# trainer.ensure_scheduler_ready() # Use this instead when training with an lr_scheduler without loading a checkpoint.
trainer.train()

Training (Epoch 1/100): 7500it [2:04:09,  1.01it/s, grad_update=63499, loss=0.0003, lr=1.0e-04]
Training (Epoch 2/100): 7500it [2:03:53,  1.01it/s, grad_update=65374, loss=0.0026, lr=1.0e-04]
Training (Epoch 3/100): 574it [08:58,  1.09it/s, grad_update=65518, loss=0.0047, lr=1.0e-04]