# Setting

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import math
import numpy as np
from tqdm.auto import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from torchvision import transforms
import torchvision

In [None]:
# Seed both the NumPy and the PyTorch random number generators with the same value.
SEED = 42
np.random.seed(seed=SEED)
torch.manual_seed(seed=SEED)

# Model

In [None]:
def swish(x):
    """Swish activation used by the DDPM network: ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate

def get_timestep_embedding(t, channel):
    """Return sinusoidal timestep embeddings of shape ``(B, channel)``.

    DDPM feeds the timestep ``t`` to the network alongside the noisy image.
    This maps a batch of timesteps to embedding vectors using the same
    sinusoidal scheme as the Transformer positional encoding: the first
    ``channel // 2`` columns are sines and the next ``channel // 2`` columns
    are cosines of ``t * 10000 ** (-i / (channel // 2))``.

    Args:
        t: 1-D tensor of ``B`` timesteps (any numeric dtype; cast to float).
        channel: Width of the embedding. If it is odd, one trailing zero
            column is appended so the result is still exactly ``channel`` wide.

    Returns:
        Float tensor of shape ``(B, channel)`` on the same device as ``t``.
    """
    half = channel // 2
    device = t.device

    # Build the frequency table directly on the target device.
    freq_idx = torch.arange(half, dtype=torch.float32, device=device)
    inv_freq = torch.exp(-math.log(10000) * freq_idx / half)
    args = t.float().unsqueeze(1) * inv_freq.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    # sin/cos only fill 2 * half columns; pad the last one for odd widths so
    # downstream layers that expect `channel` features do not break.
    if channel % 2 == 1:
        emb = F.pad(emb, (0, 1))

    return emb

class GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` that takes the channel count first and defaults to 8 groups, eps=1e-6."""

    def __init__(self, num_channels, num_groups=8, eps=1e-6):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)

def conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True, init_scale=1.0):
    """Create an ``nn.Conv2d`` (3x3, stride 1, padding 1 by default) with rescaled weights.

    The freshly initialised kernel is multiplied by ``init_scale``; the bias is
    left as initialised. With ``init_scale=0`` the kernel starts at all zeros.
    """
    layer = nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    with torch.no_grad():
        layer.weight.data.mul_(init_scale)
    return layer

def nin(in_ch, out_ch, init_scale=1.0):
    """Create a 1x1 convolution whose initial weights are multiplied by ``init_scale``.

    Only the weight is rescaled; the bias keeps its default initialisation.
    """
    pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
    with torch.no_grad():
        pointwise.weight.data.mul_(init_scale)
    return pointwise

def linear(in_features, out_features, init_scale=1.0):
    """Create an ``nn.Linear`` whose initial weights are multiplied by ``init_scale``.

    Only the weight matrix is rescaled; the bias keeps its default initialisation.
    """
    dense = nn.Linear(in_features, out_features)
    with torch.no_grad():
        dense.weight.data.mul_(init_scale)
    return dense

class DownsampleBlock(nn.Module):
    """Halve the spatial resolution: ``(B, C, 2H, 2W) -> (B, C, H, W)``.

    With ``with_conv=True`` a stride-2 3x3 convolution is used; otherwise a
    2x2 average pooling with stride 2.
    """

    def __init__(self, channels, with_conv=True):
        super().__init__()
        self.op = (
            nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
            if with_conv
            else nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        """Apply the configured downsampling operator to ``x``."""
        return self.op(x)

class UpsampleBlock(nn.Module):
    """Double the spatial resolution: ``(B, C, H, W) -> (B, C, 2H, 2W)``.

    Nearest-neighbour interpolation is always applied; with ``with_conv=True``
    it is followed by a 3x3 convolution that keeps the channel count.
    """

    def __init__(self, channels, with_conv=True):
        super().__init__()
        self.with_conv = with_conv
        # The convolution only exists when requested, so no `conv` attribute
        # (and no parameters) is registered otherwise.
        if self.with_conv:
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Upsample ``x`` by a factor of two, then optionally convolve."""
        upsampled = F.interpolate(x, scale_factor=2.0, mode='nearest')
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)

class ResnetBlock(nn.Module):
    """Residual block conditioned on the timestep embedding.

    Data flow: GroupNorm -> swish -> 3x3 conv, add the projected timestep
    embedding, GroupNorm -> swish -> dropout -> 3x3 conv, then add a 1x1
    projection of the input as the shortcut. Every layer is created with
    ``init_scale=1.0``. The shortcut projection is always applied, even when
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, temb_channels=256, dropout=0.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.temb_channels = temb_channels  # width of the timestep embedding
        self.dropout = dropout

        # Layers are created in this fixed order so that parameter
        # initialisation consumes random numbers in a stable sequence.
        self.norm1 = GroupNorm(in_channels)
        self.conv1 = conv2d(in_channels, out_channels,
                            kernel_size=3, stride=1, padding=1, init_scale=1.0)
        self.temb_proj = linear(temb_channels, out_channels, init_scale=1.0)
        self.norm2 = GroupNorm(out_channels)
        self.conv2 = conv2d(out_channels, out_channels,
                            kernel_size=3, stride=1, padding=1, init_scale=1.0)
        self.conv_shortcut = nin(in_channels, out_channels)

    def forward(self, x, temb):
        """Map ``x`` of shape (B, in_channels, H, W) to (B, out_channels, H, W).

        ``temb`` has shape (B, temb_channels); it is projected to
        ``out_channels`` and broadcast over the spatial dimensions.
        """
        features = self.conv1(swish(self.norm1(x)))

        # (B, out_channels) -> (B, out_channels, 1, 1) so it broadcasts over H and W.
        time_bias = self.temb_proj(swish(temb)).unsqueeze(-1).unsqueeze(-1)
        features = features + time_bias

        features = swish(self.norm2(features))
        features = F.dropout(features, p=self.dropout, training=self.training)
        features = self.conv2(features)

        shortcut = self.conv_shortcut(x)
        return shortcut + features

class AttnBlock(nn.Module):
    """Single-head self-attention block over spatial positions.

    The input is group-normalised, projected to queries, keys and values with
    1x1 convolutions, attended over all ``H * W`` positions, projected back
    with a 1x1 convolution and added to the input as a residual. The output
    projection's weight is created with ``init_scale=0.0``.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = GroupNorm(channels)
        self.q = nin(channels, channels)
        self.k = nin(channels, channels)
        self.v = nin(channels, channels)
        self.proj_out = nin(channels, channels, init_scale=0.0)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, C, H, W); the shape is preserved."""
        B, C, H, W = x.shape
        normed = self.norm(x)

        def to_tokens(feature_map):
            # (B, C, H, W) -> (B, T, C) with T = H * W
            return feature_map.view(B, C, H * W).permute(0, 2, 1)

        q = to_tokens(self.q(normed))
        k = to_tokens(self.k(normed))
        v = to_tokens(self.v(normed))

        # Scaled dot-product attention; the scale is 1 / sqrt(C).
        scores = torch.bmm(q, k.transpose(1, 2)) * (q.shape[-1] ** -0.5)
        weights = torch.softmax(scores, dim=-1)

        # (B, T, C) -> (B, C, H, W)
        attended = torch.bmm(weights, v).permute(0, 2, 1).reshape(B, C, H, W)

        return x + self.proj_out(attended)

class DDPMModel(nn.Module):
    """U-Net style DDPM network that predicts the noise in a noisy image.

    The model consists of a downsampling path, a middle stage and an
    upsampling path, each built from ``ResnetBlock`` and ``AttnBlock`` layers.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted noise.
        ch: Base channel width; also the width of the sinusoidal timestep
            embedding before it is expanded to ``4 * ch``.
        ch_mult: Per-level channel multipliers; its length is the number of
            resolution levels.
        num_res_blocks: Number of ``ResnetBlock`` layers per level on the
            downsampling path (the upsampling path uses one more per level).
        attn_resolutions: Spatial resolutions at which ``AttnBlock`` layers
            are inserted. Resolutions are derived from ``init_resolution``,
            not from the actual input size. Any iterable of ints is accepted.
        dropout: Dropout probability passed to every ``ResnetBlock``.
        resamp_with_conv: Whether down/upsampling blocks use a convolution.
        init_resolution: Nominal input resolution used to decide where
            attention is placed.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        ch=64,
        ch_mult=(1,2,4),
        num_res_blocks=2,
        attn_resolutions={32},
        dropout=0.0,
        resamp_with_conv=False,
        init_resolution=64
    ):
        super().__init__()
        # Copy the container: the default is a mutable set shared by every
        # call, so storing it directly would alias it across model instances
        # (and with whatever container the caller passed in).
        attn_resolutions = set(attn_resolutions)

        self.ch = ch
        self.ch_mult = ch_mult
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.dropout = dropout
        self.num_levels = len(ch_mult)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.resamp_with_conv = resamp_with_conv
        self.init_resolution = init_resolution

        # Timestep embedding channel
        self.temb_ch = ch * 4

        # Timestep embedding layers
        self.temb_dense0 = linear(self.ch, self.temb_ch)
        self.temb_dense1 = linear(self.temb_ch, self.temb_ch)

        # Input conv
        self.conv_in = conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # Downsample path. `down_blocks` alternates between a ModuleList of
        # per-level layers and (for every level but the last) a DownsampleBlock.
        self.down_blocks = nn.ModuleList()

        curr_ch = ch
        curr_res = init_resolution
        for level in range(self.num_levels):
            level_blocks = nn.ModuleList()
            out_ch = ch * ch_mult[level]
            for _ in range(num_res_blocks):
                level_blocks.append(ResnetBlock(curr_ch, out_ch, temb_channels=self.temb_ch, dropout=dropout))
                if curr_res in attn_resolutions:
                    level_blocks.append(AttnBlock(out_ch))
                curr_ch = out_ch
            self.down_blocks.append(level_blocks)
            if level != self.num_levels - 1:
                self.down_blocks.append(DownsampleBlock(curr_ch, with_conv=resamp_with_conv))
                curr_res //= 2

        # Middle stage: ResnetBlock -> AttnBlock -> ResnetBlock at the lowest resolution.
        self.mid_block = nn.ModuleList([
            ResnetBlock(curr_ch, curr_ch, temb_channels=self.temb_ch, dropout=dropout),
            AttnBlock(curr_ch),
            ResnetBlock(curr_ch, curr_ch, temb_channels=self.temb_ch, dropout=dropout)
        ])

        # Upsample path, built from the deepest level to the shallowest, so
        # `up_blocks[0]` is the lowest-resolution level. The first ResnetBlock
        # of each level consumes the concatenation of the running features
        # and that level's skip tensor.
        self.up_blocks = nn.ModuleList()
        for level in reversed(range(self.num_levels)):
            level_blocks = nn.ModuleList()
            out_ch = ch * ch_mult[level]
            level_res = init_resolution // (2 ** level)
            use_attn = level_res in attn_resolutions

            level_blocks.append(ResnetBlock(curr_ch + out_ch, out_ch, temb_channels=self.temb_ch, dropout=dropout))
            if use_attn:
                level_blocks.append(AttnBlock(out_ch))
            curr_ch = out_ch
            for _ in range(num_res_blocks):
                level_blocks.append(ResnetBlock(curr_ch, curr_ch, temb_channels=self.temb_ch, dropout=dropout))
                if use_attn:
                    level_blocks.append(AttnBlock(curr_ch))
            if level != 0:
                level_blocks.append(UpsampleBlock(curr_ch, with_conv=resamp_with_conv))
            self.up_blocks.append(level_blocks)

        # Output conv (weights start at zero because init_scale=0.0)
        self.norm_out = GroupNorm(curr_ch)
        self.conv_out = conv2d(curr_ch, out_channels, kernel_size=3, stride=1, padding=1, init_scale=0.0)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timesteps ``t``.

        Args:
            x: Noisy images of shape (B, in_channels, H, W). H and W must stay
                even through every downsampling step so that the skip tensors
                line up with the upsampled features.
            t: 1-D tensor of B timesteps.

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        # Timestep embedding
        temb = get_timestep_embedding(t, self.ch) # B*ch
        temb = self.temb_dense0(temb) # B*4ch
        temb = swish(temb)
        temb = self.temb_dense1(temb) # B*4ch

        skips = []
        h = self.conv_in(x)

        # Downsample path: run each level's layers, remember the level output
        # as a skip tensor, and apply the DownsampleBlock entries in between.
        # ResnetBlock takes (h, temb); every other layer takes only h.
        for entry in self.down_blocks:
            if isinstance(entry, DownsampleBlock):
                h = entry(h)
                continue
            for layer in entry:
                h = layer(h, temb) if isinstance(layer, ResnetBlock) else layer(h)
            skips.append(h)

        # Middle stage
        for layer in self.mid_block:
            h = layer(h, temb) if isinstance(layer, ResnetBlock) else layer(h)

        # Upsample path: concatenate the matching skip tensor along the
        # channel axis, then run the level's layers (the trailing
        # UpsampleBlock, when present, prepares h for the next level).
        for blocks in self.up_blocks:
            h = torch.cat([h, skips.pop()], dim=1)
            for layer in blocks:
                h = layer(h, temb) if isinstance(layer, ResnetBlock) else layer(h)

        # Output
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h

# Dataset

In [None]:
# Copy the CelebA archive onto the Colab runtime VM.
# celeba_dataset.zip must already exist in the source directory below.
!cp /content/sample_directory/celeba_dataset.zip /content/

# Extract the archive on the VM.
# All images are stored under /content/data.
!unzip -q /content/celeba_dataset.zip -d /content/data/

In [None]:
class CelebADataset(Dataset):
    """CelebA image dataset restricted to a given list of image numbers.

    Image files are expected to be named ``<number zero-padded to 6 digits>.jpg``
    inside ``img_dir``.

    Args:
        img_dir: Directory containing the image files.
        img_num_list: Iterable of integer image numbers to include.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, img_dir, img_num_list, transform=None):
        # Only images whose number appears in img_num_list are part of the dataset.
        self.img_paths = [os.path.join(img_dir, f'{i:06d}.jpg') for i in img_num_list]
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if a transform is set."""
        path = self.img_paths[idx]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed right away instead of waiting for garbage collection.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

In [None]:
# Only 30000 images in total are used for training and validation.
num_data = 30000
train_test_ratio = 0.9
batch_size = 64

# Training pipeline: resize to 64x64, random horizontal flip, convert to a
# tensor and normalise every channel with mean 0.5 and std 0.5.
train_transform = transforms.Compose(
    [
        transforms.Resize([64,64]),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ]
)
# Evaluation pipeline: identical, but without the random flip.
test_transform = transforms.Compose(
    [
        transforms.Resize([64,64]),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ]
)

# Image numbers start at 1; the first `train_test_ratio` share goes to
# training and the remainder to testing.
train_img_num, test_img_num = np.split(
    np.arange(1,num_data+1), [int(num_data*train_test_ratio)]
)

train_dataset = CelebADataset(
    '/content/data/img_align_celeba',
    img_num_list=train_img_num,
    transform=train_transform,
)
test_dataset = CelebADataset(
    '/content/data/img_align_celeba',
    img_num_list=test_img_num,
    transform=test_transform,
)
print(f'Train dataset size: {len(train_dataset)}') # Ex) 27000
print(f'Test dataset size: {len(test_dataset)}') # Ex) 3000

# Training batches are shuffled and the incomplete last batch is dropped;
# the test loader keeps order and keeps every sample.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=False
)

# Train

In [None]:
def get_beta_alpha_linear(beta_start=0.0001, beta_end=0.02, num_timesteps=1000):
    """Build the linear DDPM noise schedule used for training and sampling.

    Returns:
        A tuple ``(betas, alphas, alphas_cumprod)`` of float32 tensors, each of
        length ``num_timesteps``: betas are linearly spaced from ``beta_start``
        to ``beta_end``, ``alphas = 1 - betas`` and ``alphas_cumprod`` is the
        running product of ``alphas``.
    """
    schedule = np.linspace(beta_start, beta_end, num_timesteps, dtype=np.float32)
    betas = torch.tensor(schedule)
    alphas = 1.0 - betas
    return betas, alphas, torch.cumprod(alphas, dim=0)

def q_sample(x0, t, noise, alphas_cumprod):
    """Diffuse clean samples ``x0`` forward by ``t`` steps in closed form.

    Instead of iterating the forward process step by step, this uses the fact
    that ``x_t`` given ``x0`` is Gaussian:
    ``x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.

    Args:
        x0: Clean samples of shape (B, ...); any number of trailing dimensions.
        t: Integer timesteps of shape (B,), 1-based: step ``t`` reads
            ``alphas_cumprod[t - 1]``. Callers must keep ``t >= 1``; ``t = 0``
            would wrap around to the last schedule entry.
        noise: Noise tensor with the same shape as ``x0``.
        alphas_cumprod: 1-D tensor holding the cumulative product of alphas.

    Returns:
        The noised samples ``x_t``, same shape as ``x0``.
    """
    # Index on the schedule's own device (a CPU schedule cannot be indexed
    # with CUDA timesteps), then move the gathered values next to x0.
    idx = (t - 1).to(alphas_cumprod.device)
    alpha_bar = alphas_cumprod[idx].to(x0.device)

    # (B,) -> (B, 1, ..., 1) so the coefficient broadcasts over every
    # non-batch dimension of x0, whatever its rank.
    alpha_bar = alpha_bar.reshape(-1, *([1] * (x0.dim() - 1)))

    sqrt_alpha_bar = torch.sqrt(alpha_bar)
    sqrt_one_minus_alpha_bar = torch.sqrt(1.0 - alpha_bar)
    x_t = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise

    return x_t

In [None]:
def compute_mse_loss(model, x_t, t, eps):
    """Return the mean squared error between the model's noise prediction and ``eps``.

    ``model`` is called as ``model(x_t, t)`` and its output is compared with
    the true noise ``eps`` that was used to produce ``x_t``.
    """
    return F.mse_loss(model(x_t, t), eps)

In [None]:
def train_epoch(model,train_loader,alphas_cumprod,device,optimizer,use_gradient_clipping=True):
    """Run one training epoch and return the mean batch loss.

    For every batch a random timestep per image and Gaussian noise are drawn,
    the images are diffused with ``q_sample`` and the model is trained to
    predict the noise (``compute_mse_loss``).

    Args:
        model: Noise-prediction network called as ``model(x_t, t)``.
        train_loader: Iterable yielding image batches.
        alphas_cumprod: 1-D tensor with the cumulative product of alphas; its
            length defines the number of diffusion timesteps.
        device: Device the batches, noise and timesteps are moved to.
        optimizer: Optimizer updating ``model``'s parameters.
        use_gradient_clipping: If true, clip the gradient norm to 1.0 before
            each optimizer step for more stable training.

    Returns:
        The average of the per-batch losses of this epoch.
    """
    # Timesteps are 1-based and span the whole schedule, so the upper bound
    # follows the schedule length instead of a hard-coded 1000.
    num_timesteps = alphas_cumprod.size(0)
    max_grad_norm = 1.0

    train_loss_sum = 0.0
    train_loss_cnt = 0

    model.train()
    for image in tqdm(train_loader, desc='Training'):
        image = image.to(device)
        eps = torch.randn(image.shape).to(device)
        # One timestep per image: tensor of shape (batch_size,)
        t = torch.randint(1, num_timesteps + 1, (image.size(0),), dtype=torch.long).to(device)

        # Noisy image x_t obtained from image (x0), t and eps
        x_t = q_sample(image, t, eps, alphas_cumprod)

        loss = compute_mse_loss(model, x_t, t, eps)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        if use_gradient_clipping:
            clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        train_loss_sum += loss.item()
        train_loss_cnt += 1

    train_loss = train_loss_sum / train_loss_cnt
    return train_loss

def test_epoch(model, test_loader, device, alphas_cumprod=None):
    """Evaluate the model for one pass over ``test_loader`` and return the mean batch loss.

    Args:
        model: Noise-prediction network called as ``model(x_t, t)``.
        test_loader: Iterable yielding image batches.
        device: Device the batches, noise and timesteps are moved to.
        alphas_cumprod: 1-D tensor with the cumulative product of alphas. When
            omitted, the module-level ``alphas_cumprod`` is used, which keeps
            existing three-argument calls working.

    Returns:
        The average of the per-batch losses.

    Raises:
        NameError: If ``alphas_cumprod`` is omitted and no module-level
            ``alphas_cumprod`` exists.
    """
    if alphas_cumprod is None:
        # Backward compatibility: this function used to read the schedule
        # from a global variable instead of receiving it explicitly.
        try:
            alphas_cumprod = globals()['alphas_cumprod']
        except KeyError:
            raise NameError("name 'alphas_cumprod' is not defined") from None

    # Timesteps are 1-based and span the whole schedule.
    num_timesteps = alphas_cumprod.size(0)

    test_loss_sum = 0.0
    test_loss_cnt = 0

    model.eval()
    with torch.no_grad():
        for image in tqdm(test_loader, desc='Evaluating'):
            image = image.to(device)
            eps = torch.randn(image.shape).to(device)
            t = torch.randint(1, num_timesteps + 1, (image.size(0),), dtype=torch.long).to(device)

            # Same procedure as in training: diffuse, then score the prediction.
            x_t = q_sample(image, t, eps, alphas_cumprod)
            loss = compute_mse_loss(model, x_t, t, eps)

            test_loss_sum += loss.item()
            test_loss_cnt += 1

        test_loss = test_loss_sum / test_loss_cnt
    return test_loss

def train(model,train_loader,test_loader,alphas_cumprod,device,optimizer,num_epochs,use_gradient_clipping=True):
    """Train for ``num_epochs`` epochs, evaluating after each one.

    Prints the train and test loss of every epoch.

    Returns:
        A tuple ``(train_losses, test_losses)`` with one entry per epoch.
    """
    train_losses = []
    test_losses = []

    for epoch in tqdm(range(1, num_epochs+1)):
        train_loss = train_epoch(model, train_loader, alphas_cumprod, device, optimizer, use_gradient_clipping)
        # Evaluate with the same schedule that was used for training.
        test_loss = test_epoch(model, test_loader, device, alphas_cumprod)
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print(f"Epoch {epoch}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    return train_losses, test_losses

In [None]:
# Recommended model hyperparameter settings (passed to DDPMModel below)
ch = 64
ch_mult = (1, 2, 4)
num_res_blocks = 2
attn_resolutions = {32}
dropout = 0.0
resamp_with_conv = False

In [None]:
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0'if torch.cuda.is_available() else 'cpu')

# Only the cumulative alpha product of the default linear schedule is needed here.
_, _, alphas_cumprod = get_beta_alpha_linear()
alphas_cumprod = alphas_cumprod.to(device)

# Build the model from the hyperparameters above and move it to the device
# before the optimizer is created.
model = DDPMModel(
    ch=ch,
    ch_mult=ch_mult,
    num_res_blocks=num_res_blocks,
    attn_resolutions=attn_resolutions,
    dropout=dropout,
    resamp_with_conv=resamp_with_conv,
    init_resolution=64
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=2e-3, weight_decay=1e-4)

# Resume from a previously saved checkpoint holding model and optimizer state.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch supports it.
model_path = '/content/drive/MyDrive/cv프로젝트/부트캠프/_90.pth'
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Override the learning rate AFTER restoring the optimizer state; doing it
# earlier would be overwritten by the values stored in the checkpoint.
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-6

# Continue training for 10 epochs with gradient clipping enabled.
if __name__ == '__main__':
    train_losses, test_losses = train(model, train_loader, test_loader, alphas_cumprod, device, optimizer, 10, True)

# Save the updated model and optimizer state.
# NOTE(review): this runs outside the __main__ guard above, so the checkpoint
# is written even when training was skipped — confirm that is intended.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, '/content/drive/MyDrive/cv프로젝트/부트캠프/_100.pth')