# NCSN, Noise Conditional Score Network
기존 Score Based Model에서 p(x)를 특정 trick을 통해 normal distribution으로 잡고 훈련.  
  
이전의 여러 문제점이 있었는데 추가로 perturbed data를 사용해서 극복하려고 함.  
  
\+ Annealed Langevin Sampling

In [1]:
! pip install einops

Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1


In [2]:
import math
from functools import partial
from inspect import isfunction
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
import numpy as np

# ScoreNet (DDPM에서 사용했던 UNET 그대로 가져옴.)

In [3]:
def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return not (x is None)


def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    ``d`` may be a plain value or a Python function. A function (as recognised
    by ``inspect.isfunction``) is called with no arguments to produce the
    fallback; any other object, including classes and builtins, is returned
    as-is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val


class EMA:
    """Exponential moving average helper for keeping a smoothed copy of a model.

    ``beta`` is the weight given to the previous average; ``1 - beta`` is the
    weight given to the newest value.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``new`` into ``old``; when there is no previous average, return ``new``."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

    def update_model_average(self, ma_model, current_model):
        """Update every parameter of ``ma_model`` in place from ``current_model``.

        Parameters are paired positionally through ``parameters()`` of both
        models, and the averaged model's ``.data`` is overwritten.
        """
        paired = zip(current_model.parameters(), ma_model.parameters())
        for source_param, averaged_param in paired:
            averaged_param.data = self.update_average(averaged_param.data, source_param.data)

class Residual(nn.Module):
    """Skip-connection wrapper: calls ``fn`` and adds its first input to the result."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded untouched to the wrapped callable.
        branch = self.fn(x, *args, **kwargs)
        return branch + x

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalar positions.

    For each input value the output is ``[sin(x * f_0..f_{h-1}), cos(x * f_0..f_{h-1})]``
    with ``h = dim // 2`` frequencies spaced geometrically from 1 down to 1/10000.
    The output width is therefore ``2 * (dim // 2)``.

    NOTE: ``dim // 2 - 1`` is used as a divisor, so ``dim`` must be at least 4.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # assumes x is a 1-D tensor of shape (batch,) - TODO confirm against callers
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -log_step)
        # Outer product: one row per input value, one column per frequency.
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

def Upsample(dim):
    """Build a channel-preserving transposed convolution (kernel 4, stride 2, padding 1).

    With these settings the spatial size of the input is doubled.
    """
    return nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)

def Downsample(dim):
    """Build a channel-preserving strided convolution (kernel 4, stride 2, padding 1).

    With these settings an even spatial size is halved.
    """
    return nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)

class LayerNorm(nn.Module):
    """Normalisation over dimension 1 with a learned per-channel gain and bias.

    Mean and (biased) variance are computed along dim 1 only, so each position
    of the remaining dimensions is normalised independently. ``g`` and ``b``
    have shape ``(1, dim, 1, 1)`` and broadcast over the other dimensions.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        variance = x.var(dim=1, unbiased=False, keepdim=True)
        centre = x.mean(dim=1, keepdim=True)
        # eps is added to the variance before the square root for numerical stability.
        normalised = (x - centre) / (variance + self.eps).sqrt()
        return normalised * self.g + self.b

class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input and then hand the result to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        # Assignment order (fn, then norm) is kept so submodule ordering is unchanged.
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

# building block modules

class Block(nn.Module):
    """Basic convolutional unit: 3x3 convolution -> GroupNorm -> SiLU.

    The convolution uses padding 1, so the spatial size is preserved while the
    channel count goes from ``dim`` to ``dim_out``.
    """

    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        conv = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)
        norm = nn.GroupNorm(groups, dim_out)
        self.block = nn.Sequential(conv, norm, nn.SiLU())

    def forward(self, x):
        return self.block(x)

class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual connection and optional conditioning.

    When ``time_emb_dim`` is given, a SiLU -> Linear projection maps the
    conditioning vector to ``dim_out`` channels; the result is added to the
    output of the first block, broadcast over the spatial dimensions.
    The residual path is a 1x1 convolution when the channel count changes and
    an identity otherwise.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        if dim == dim_out:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, time_emb = None):
        hidden = self.block1(x)

        # Conditioning is applied only when both the projection and an embedding are present.
        if exists(self.mlp) and exists(time_emb):
            shift = rearrange(self.mlp(time_emb), 'b c -> b c 1 1')
            hidden = shift + hidden

        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)

class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    Queries are soft-maxed over the per-head feature dimension and keys over
    the flattened spatial positions. A per-head ``context`` matrix is built
    from keys and values first, so no position-by-position attention map is
    materialised. The output is projected back to ``dim`` channels and passed
    through ``LayerNorm``.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Sequential(nn.Conv2d(inner_dim, dim, 1), LayerNorm(dim))

    def forward(self, x):
        _, _, height, width = x.shape

        # Split the fused projection into q, k, v and flatten the spatial grid per head.
        q, k, v = [
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = 1)
        ]

        q = q.softmax(dim = -2) * self.scale
        k = k.softmax(dim = -1)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)

        merged = rearrange(attended, 'b h c (x y) -> b (h c) x y', h = self.heads, x = height, y = width)
        return self.to_out(merged)

class Attention(nn.Module):
    """Multi-head softmax self-attention over the spatial positions of a feature map.

    Every position attends to every other position: the similarity matrix has
    one row and one column per flattened spatial location.
    """

    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape

        # Split the fused projection into q, k, v and flatten the spatial grid per head.
        q, k, v = [
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = 1)
        ]
        q = q * self.scale

        scores = einsum('b h d i, b h d j -> b h i j', q, k)
        # Subtract the row-wise maximum before the softmax for numerical stability.
        scores = scores - scores.amax(dim = -1, keepdim = True).detach()
        weights = scores.softmax(dim = -1)

        mixed = einsum('b h i j, b h d j -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h (x y) d -> b (h d) x y', x = height, y = width)
        return self.to_out(merged)

# model

class Unet(nn.Module):
    """U-Net used as the score network, conditioned on a per-sample label.

    The network is built from ``ResnetBlock`` pairs followed by residual
    ``LinearAttention`` at every resolution, a bottleneck with full
    ``Attention``, and a mirrored decoder that concatenates encoder outputs
    (skip connections) along the channel dimension.

    ``forward(x, time)`` returns a tensor with ``self.out_dim`` channels.
    """

    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        with_time_emb = True,
        resnet_block_groups = 8,
        learned_variance = False
    ):
        """Build the encoder, bottleneck, decoder and output head.

        Args:
            dim: base channel width; stage ``i`` uses ``dim * dim_mults[i]`` channels.
            init_dim: channels produced by the first convolution
                (defaults to ``dim // 3 * 2``).
            out_dim: output channels (defaults to ``channels``, or
                ``channels * 2`` when ``learned_variance`` is true).
            dim_mults: per-stage channel multipliers; one encoder stage per entry.
            channels: number of input image channels.
            with_time_emb: when false, no conditioning MLP is created and the
                resnet blocks receive ``None`` as embedding.
            resnet_block_groups: ``groups`` argument forwarded to every ``ResnetBlock``.
            learned_variance: doubles the default number of output channels.
        """
        super().__init__()

        # determine dimensions

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding = 3)

        # dims = [init_dim, dim*m0, dim*m1, ...]; in_out pairs consecutive widths per stage.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embeddings

        if with_time_emb:
            time_dim = dim * 4
            # SinusoidalPosEmb(dim) emits 2 * (dim // 2) features, which matches the
            # following Linear(dim, ...) only when dim is even.
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # Encoder: every stage except the last ends with a stride-2 Downsample.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        # Bottleneck at the lowest resolution: resnet -> full attention -> resnet.
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        # Decoder: iterates over in_out[1:] reversed, i.e. one stage fewer than the encoder.
        # Because this loop has only num_resolutions - 1 iterations, `is_last` below is
        # never true here and every decoder stage ends with an Upsample.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                # dim_out * 2: the input is the current features concatenated with a skip tensor.
                block_klass(dim_out * 2, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        # NOTE(review): the last decoder stage outputs dims[1] = dim * dim_mults[0] channels,
        # while this head expects `dim` channels, so it relies on dim_mults[0] == 1.
        self.final_conv = nn.Sequential(
            block_klass(dim, dim),
            nn.Conv2d(dim, self.out_dim, 1)
        )

    def forward(self, x, time): # time -> sigma label
        """Run the U-Net on ``x`` conditioned on ``time``.

        ``time`` is passed through ``self.time_mlp`` (when present) and the
        resulting embedding is given to every resnet block of the encoder,
        bottleneck and decoder.
        """
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Stack of encoder outputs, one per stage, recorded before downsampling.
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Skips are consumed most-recent-first. The decoder has one stage fewer than the
        # encoder, so the first (highest-resolution) entry of `h` is never popped.
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)


In [None]:

def extract(element, t):
    """Pick ``element[t[i]]`` for every entry of the index tensor ``t``.

    ``t`` has shape ``(batch_size,)``: one index per sample in the batch, so
    the result holds one value of ``element`` per sample.
    """
    return element.gather(0, t)

# def annealed_langevin_sampling(x, alphas, sigmas, times, )
# NOTE: sigma labels are presumably required to be long (int64) tensors, since they are used as torch.gather indices.
class NCSN(nn.Module):
    """Noise Conditional Score Network: training loss and annealed Langevin sampler.

    ``score_func(x, labels)`` is the conditional score model; ``labels`` are
    integer indices into the noise schedule ``sigmas``. Calling the module on
    a batch returns the denoising score matching loss for randomly drawn noise
    levels; ``sampling`` generates images with annealed Langevin dynamics.

    Buffers:
        sigmas: ``sigma_steps`` noise levels spaced geometrically from
            ``sigma_begin`` to ``sigma_end``.
        alphas: ``(sigma_i / sigma_last) ** 2``, the per-level step-size factor
            that is multiplied by ``step_lr`` during sampling.
    """

    def __init__(self, score_func, sigma_begin=1, sigma_end=0.01, sigma_steps=10, image_shape=(3, 64, 64), sampling_steps=200, step_lr=0.00004):
        super().__init__()
        self.score_func = score_func
        self.image_shape = image_shape
        self.step_lr = step_lr
        self.sampling_steps = sampling_steps

        # Evenly spaced in log-space, i.e. a geometric sequence of noise levels.
        log_levels = np.linspace(np.log(sigma_begin), np.log(sigma_end), sigma_steps)
        sigmas = torch.tensor(np.exp(log_levels)).float()
        smallest = torch.full(sigmas.size(), sigmas[-1])

        self.register_buffer('sigmas', sigmas)
        self.register_buffer("alphas", (sigmas / smallest) ** 2)

    def ncsn_loss(self, x, labels):
        """Denoising score matching loss for the noise levels selected by ``labels``.

        ``x`` is the clean batch (the original note says it has already been
        perturbed once beforehand - presumably by the caller's dequantisation noise).
        """
        batch = x.size(0)
        sigmas = extract(self.sigmas, labels).view(-1, 1, 1, 1)  # (bs, 1, 1, 1)
        x_q = x + torch.randn_like(x) * sigmas

        # The network must be fed the noised x_q during training; the original
        # author notes that getting this wrong kept training from working.
        score = self.score_func(x_q, labels).view(batch, -1)   # (bs, c, h, w) -> (bs, -1)
        target = -((x_q - x) / (sigmas ** 2)).view(batch, -1)  # (bs, c, h, w) -> (bs, -1)

        # Per-sample squared error, weighted by sigma^2 and halved.
        squared_error = ((score - target) ** 2).sum(dim=-1)
        per_sample = (squared_error * sigmas.squeeze() ** 2.) / 2
        return per_sample.mean()

    def forward(self, x):
        """Draw one random noise level per sample and return the training loss."""
        assert all(x.size(axis) == self.image_shape[axis] for axis in (-1, -2, -3)), "Different Image Shape"

        # One label per sample, indexing into sigmas; the U-Net turns it into a positional embedding.
        labels = torch.randint(0, len(self.sigmas), (x.size(0),), device=x.device)
        return self.ncsn_loss(x, labels)

    @torch.no_grad()
    def langevin_loop(self, x, l, alpha):
        """One Langevin update: ``x + alpha / 2 * score + sqrt(alpha) * noise``."""
        noise = torch.randn(*x.size(), device=x.device)
        score = self.score_func(x, l)

        drift = alpha / 2 * score
        diffusion = torch.sqrt(alpha) * noise
        return x + drift + diffusion

    @torch.no_grad()
    def annealed_langevin_func(self, x, l):
        """Run ``sampling_steps`` Langevin updates at the noise level with index ``l``."""
        batch = x.size(0)

        # Broadcast the scalar level index to one long label per sample.
        labels = torch.full((batch,), l, device=x.device, dtype=torch.long)
        alpha = extract(self.alphas, labels).view(-1, 1, 1, 1) * self.step_lr

        for _ in range(self.sampling_steps):
            x = self.langevin_loop(x, labels, alpha)

        return x

    @torch.no_grad()
    def sampling(self, bs):
        """Generate ``bs`` samples, returning one clamped CPU snapshot per noise level.

        Starts from standard normal noise and anneals through every noise
        level in order. Each snapshot is clamped to ``[0, 1]``; the running
        state itself is not clamped.
        """
        snapshots = []
        print("# Sampling Start")
        x = torch.randn(bs, *self.image_shape, device=self.sigmas.device)
        for level in range(len(self.alphas)):
            x = self.annealed_langevin_func(x, level)
            snapshots.append(x.clamp(0.0, 1.0).cpu())
        print("# Sampling End")
        return snapshots
        

if __name__ == "__main__":
    # Smoke test: one loss evaluation and one full sampling run on random input.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img = torch.randn(4, 3, 64, 64).to(device)
    scorenet = Unet(dim=64, dim_mults=(1, 2, 4, 8))
    model = NCSN(
        scorenet,
        sigma_begin=1,
        sigma_end=0.01,
        sigma_steps=10,
        image_shape=(3, 64, 64),
        sampling_steps=200,
        step_lr=0.00002,
    ).to(device)
    print(model(img).data)
    print(model.sampling(1))

In [None]:
from tqdm import tqdm
import copy
from torchvision.datasets import CIFAR10, MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.utils as utils

def cycle(dl):
    """Yield items from ``dl`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from dl

# Preprocessing: resize to 64 pixels, random horizontal flip, convert to a [0, 1] tensor.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)

dataset = MNIST(root="./data", train=True, download=True, transform=transform)

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

# Endless iterator over the loader so training can be counted in iterations.
dl = cycle(dataloader)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scorenet = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1)
model = NCSN(
    scorenet,
    sigma_begin=1,
    sigma_end=0.01,
    sigma_steps=10,
    image_shape=(1, 64, 64),
    sampling_steps=200,
    step_lr=0.00002,
).to(device)
optim = torch.optim.Adam(params=model.parameters(), lr=1e-4)
sampling_term = 1000

# Training is measured in iterations, so what matters is not only how many
# steps are run but also how many images have been seen.
iteration_nums = 50000
losses = []

for i in tqdm(range(1, iteration_nums + 1)):
    x, y = next(dl)
    x = x.to(device)
    # Images are in [0, 1] when perturbed, so the uniform noise is scaled down to 1/256 as well.
    x = x / 256. * 255. + torch.rand_like(x) / 256.

    out = model(x)
    optim.zero_grad()
    out.backward()
    optim.step()

    # Only every `sampling_term`-th iteration logs the loss and writes samples.
    if i % sampling_term != 0:
        continue

    with torch.no_grad():
        model.eval()
        losses.append(out.data)
        print(f"Loss : {out.data}")
        im = model.sampling(1)

        # One image per noise level, saved together as a single grid.
        utils.save_image(torch.cat(im, dim=0), f"sample_{i}.png")
        model.train()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  cpuset_checked))
  2%|▏         | 999/50000 [05:06<4:14:39,  3.21it/s]

Loss : 190.40484619140625
# Sampling Start


  2%|▏         | 1000/50000 [05:41<146:19:23, 10.75s/it]

# Sampling End


  4%|▍         | 1999/50000 [10:51<4:07:37,  3.23it/s]

Loss : 132.60789489746094
# Sampling Start


  4%|▍         | 2000/50000 [11:26<144:18:43, 10.82s/it]

# Sampling End


  6%|▌         | 2999/50000 [16:36<4:03:27,  3.22it/s]

Loss : 89.14021301269531
# Sampling Start


  6%|▌         | 3000/50000 [17:13<145:40:11, 11.16s/it]

# Sampling End


  8%|▊         | 3999/50000 [22:23<3:57:51,  3.22it/s]

Loss : 84.02326202392578
# Sampling Start


  8%|▊         | 4000/50000 [23:01<146:38:14, 11.48s/it]

# Sampling End


 10%|▉         | 4999/50000 [28:11<3:51:48,  3.24it/s]

Loss : 83.92194366455078
# Sampling Start


 10%|█         | 5000/50000 [28:48<138:24:29, 11.07s/it]

# Sampling End


 12%|█▏        | 5999/50000 [33:58<3:47:16,  3.23it/s]

Loss : 86.76551055908203
# Sampling Start


 12%|█▏        | 6000/50000 [34:34<133:48:18, 10.95s/it]

# Sampling End


 14%|█▍        | 6999/50000 [39:44<3:42:03,  3.23it/s]

Loss : 73.1364517211914
# Sampling Start


 14%|█▍        | 7000/50000 [40:20<130:16:36, 10.91s/it]

# Sampling End


 16%|█▌        | 7999/50000 [45:31<3:36:43,  3.23it/s]

Loss : 69.85159301757812
# Sampling Start


 16%|█▌        | 8000/50000 [46:07<128:43:22, 11.03s/it]

# Sampling End


 18%|█▊        | 8999/50000 [51:17<3:31:54,  3.22it/s]

Loss : 63.29292297363281
# Sampling Start


 18%|█▊        | 9000/50000 [51:53<124:15:04, 10.91s/it]

# Sampling End


 20%|█▉        | 9999/50000 [57:03<3:26:15,  3.23it/s]

Loss : 66.36909484863281
# Sampling Start


 20%|██        | 10000/50000 [57:39<121:02:39, 10.89s/it]

# Sampling End


 22%|██▏       | 10999/50000 [1:02:49<3:22:11,  3.21it/s]

Loss : 54.58304214477539
# Sampling Start


 22%|██▏       | 11000/50000 [1:03:25<117:52:08, 10.88s/it]

# Sampling End


 24%|██▍       | 11999/50000 [1:08:36<3:16:52,  3.22it/s]

Loss : 43.473731994628906
# Sampling Start


 24%|██▍       | 12000/50000 [1:09:12<114:50:23, 10.88s/it]

# Sampling End


 24%|██▍       | 12238/50000 [1:10:26<3:14:57,  3.23it/s]