# **一个Diffusion Model算法的简单实现**
## **算法概述**
- AI绘画模型使用的基础算法
- 分为前向过程和后向过程两个阶段
    - 前向过程为向原图片中逐步添加高斯噪声，最终原图片成为完全高斯噪声
    - 后向过程为从完全的高斯噪声恢复为原图片的过程（生成特指后向过程）
- 模型的主体架构为一个U-net网络，在下采样和上采样的对应层之间加入了注意力模块
- 模型的任务是预测在前向过程中添加的噪声

论文链接：*https://docs.popo.netease.com/ofedit/ba6c1ddd73f242e89c704f37244cc083*

In [1]:
import math
from inspect import isfunction
from functools import partial

from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

device = 0 if torch.cuda.is_available else "cpu"
device = "cpu"
# dataset_name = "mnist"
dataset_name = "fashion_mnist" # 一个时装数据集，以灰度图的形式保存了各种时装的照片

In [2]:
def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

In [3]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    
    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
    

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

class SinusoidaPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        
    def forward(self, time:torch.Tensor):
        device = time.device
        half_dim = self.dim // 2
        embedding = math.log(10000) / (half_dim - 1)
        embedding = torch.exp(torch.arange(half_dim, device=device) * - embedding)
        embedding = rearrange(time, 'i -> i 1') * rearrange(embedding, 'j -> 1 j')
        # embedding = time[:,None] * embedding[:,None]
        return torch.cat((embedding.sin(), embedding.cos()), dim = -1)
    
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=4):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()
        
    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)
        
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        
        x = self.act(x)
        return x
    
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups=4):
        super().__init__()
        self.mlp = (
        nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim) else None
        )
        
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
        
    def forward(self, x, time_emb=None):
        h = self.block1(x)
        
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") * h
        
        h = self.block2(h)
        return h + self.res_conv(x)
    
class ConvNextBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
        nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim) else None
        )
        
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        
        self.net = nn.Sequential(
            nn.GroupNorm(1,dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
        
    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)
        
        if exists(self.mlp) and exists(time_emb):
            assert exists(time_emb), "time_emb must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")
        
        h = self.net(h)
        return h + self.res_conv(x)
    

In [4]:
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim *3 ,1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
        
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t:rearrange(t, "b (h c) x y -> b h c (x y)",h=self.heads), qkv
        )
        q = q * self.scale
        
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        attn = sim.softmax(dim=-1)
        
        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim *3 ,1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                   nn.GroupNorm(1, dim))
        
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t:rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        
        q = q * self.scale
        context = einsum("b h d n, b h e n -> b h d e", k ,v)
        
        out = einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)
    
    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

In [5]:
class Unet(nn.Module):
    def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8),
                channels=3, with_time_emb=True, resnet_block_groups=4, use_convnext=False, convnext_mult=2):
        super().__init__()
        
        self.channels = channels
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
        
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        
        
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidaPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
        
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolution = len(in_out)
        
        for idx, (dim_in, dim_out) in enumerate(in_out):
            is_last = idx >= (num_resolution - 1)
            
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )
            
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        
        for idx, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = idx >= num_resolution - 1
            
            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )
            
        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim),
            nn.Conv2d(dim, out_dim, 1)
        )
    
    def forward(self, x, time):
        x = self.init_conv(x)
        
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        
        h = []
        
        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        
        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        
        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        res = self.final_conv(x)
        return res          

In [6]:
def consine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beata_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beata_start, beta_end, timesteps)

# and more beta_schedule

In [7]:
timesteps = 200

betas =  linear_beta_schedule(timesteps)

alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

sigma_t = torch.cat((torch.tensor([0]), alphas[1:] / betas[1:] + 1 / torch.pow(sqrt_one_minus_alphas_cumprod, 2)[:-1]))
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

In [8]:
from torchvision.transforms import Compose

def q_sample(x_0, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_0)
    
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_0.shape)
    
    return sqrt_one_minus_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise

def get_noisy_image(x_0, t):
    x_noisy = q_sample(x_0, t) 
    
    noisy_image = reverse_transform(x_noisy.squeeze())
    
    return noisy_image


In [9]:
def p_losses(denoise_model, x_0, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_0)
    
    x_noisy = q_sample(x_0, t, noise)
    predicted_noise = denoise_model(x_noisy, t)
    loss = F.l1_loss(noise, predicted_noise)
    
    return loss

In [10]:
from datasets import load_dataset

dataset = load_dataset(dataset_name)
image_size = 28
channels = 1
batch_size = 128

Found cached dataset fashion_mnist (/home/ygq/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1)


  0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
from torchvision import transforms
from torch.utils.data import DataLoader

transform = Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t:(t*2) - 1)
])

def transforms_(examples):
    examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
    del examples["image"]
    return examples

transformed_dataset = dataset.with_transform(transforms_).remove_columns("label")

dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)

In [12]:
import matplotlib.pyplot as plt
def generate_image(images):
    image = images[0,0,:,:]
    plt.imshow(image, cmap='gray')
    plt.show()

In [13]:
from torch.optim import Adam


model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

In [14]:
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def p_sample_loop(model, shape):
    b = shape[0]
    img = torch.randn(shape, device=device)
    imgs = []
    
    for i in reversed(range(0, 200)):
        img = p_sample(model, img, torch.full((b,), i, dtype=torch.long).to(device), i)
        imgs.append(img.cpu().numpy())
    return imgs

def sample(model, image_size, batch_size=64, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

In [None]:
epochs = 20

for epoch in range(epochs):
    model.train()
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        
        batch_size = batch["pixel_values"].shape[0]
        batch = batch["pixel_values"].to(device)
        
        t = torch.randint(0, timesteps, (batch_size,), device = device).long()
        
        loss = p_losses(model, batch, t)
        if step % 100 == 0:
            print(f"epoch {epoch} step {step}, loss is {loss.item()}")
        
        
        loss.backward()
        optimizer.step()
    # generate_image(model)
    sample_list = sample(model, image_size=image_size, batch_size=64,channels=channels)
    generate_image(sample_list[-1])

epoch 0 step 0, loss is 1.0175487995147705
epoch 0 step 100, loss is 0.26050490140914917
epoch 0 step 200, loss is 0.17977216839790344
epoch 0 step 300, loss is 0.1644842028617859
epoch 0 step 400, loss is 0.16392913460731506
