In [2]:
import os
from typing import Dict, Tuple

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML

from diffusion_utilities import *

# Network: context-conditioned U-Net (ContextUnet)

In [3]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):
        """U-Net conditioned on a diffusion timestep and an optional context vector.

        Args:
            in_channels: number of image channels (e.g. 1 for grayscale, 3 for RGB).
            n_feat: base number of feature channels (default 256).
            n_cfeat: dimensionality of the context vector (default 10).
            height: input image height; images are assumed square (h == w).
                Must be a positive multiple of 4 (e.g. 28/24/20/16) because the
                two down-sampling stages each halve the spatial size.

        Raises:
            ValueError: if `height` is not a positive multiple of 4.
        """
        super().__init__()

        # Enforce the documented precondition up front instead of failing later
        # with an opaque shape mismatch inside forward().
        if height <= 0 or height % 4 != 0:
            raise ValueError(f"height must be a positive multiple of 4, got {height}")

        self.in_channels = in_channels  # number of image channels
        self.n_feat = n_feat            # base feature width
        self.n_cfeat = n_cfeat          # context vector size
        self.h = height                 # image size (assumes h == w)

        # Initial convolution block (residual variant).
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Down-sampling path (two levels).
        self.down1 = UnetDown(n_feat, n_feat)      # -> [B, n_feat,   H/2, W/2]
        self.down2 = UnetDown(n_feat, 2 * n_feat)  # -> [B, 2*n_feat, H/4, W/4]

        # Collapse the (H/4 x W/4) bottleneck map to a 1x1 vector. The pooling
        # kernel must equal the bottleneck size; a fixed AvgPool2d(4) only did
        # that for height == 16 (it averaged just a corner for 28, and left a
        # 2x2 map for 32, which up0 then blew up to the wrong size).
        self.to_vec = nn.Sequential(
            nn.AvgPool2d(self.h // 4),
            nn.GELU(),
        )

        # Timestep and context embeddings, one pair per up-sampling level.
        self.timeembed1 = EmbedFC(1, 2 * n_feat)           # time -> 2*n_feat
        self.timeembed2 = EmbedFC(1, 1 * n_feat)           # time -> n_feat
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)  # context -> 2*n_feat
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)  # context -> n_feat

        # Up-sampling path.
        self.up0 = nn.Sequential(
            # kernel == stride == H/4: turns the 1x1 vector back into an H/4 x W/4 map
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)  # input: concat of up0 output and down2 (4*n_feat ch)
        self.up2 = UnetUp(2 * n_feat, n_feat)  # input: concat of up1 output and down1 (2*n_feat ch)

        # Output head: input is concat of up2 output and init_conv features (2*n_feat ch).
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, padding=1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, padding=1),  # back to image channels
        )

    def forward(self, x, t, c=None):
        """Run the conditioned U-Net.

        Args:
            x: (B, C, H, W) noisy images.
            t: normalised timestep(s); this notebook passes shape (B,) during
               training and (1, 1, 1, 1) during sampling. EmbedFC presumably
               flattens it to one scalar per row — TODO confirm.
            c: (B, n_cfeat) context vectors, or None for an all-zero context.

        Returns:
            (B, C, H, W) tensor; this notebook trains it to predict the added noise.

        Shapes (H = self.h, F = n_feat):
            init_conv -> [B, F,  H,   W]
            down1     -> [B, F,  H/2, W/2]
            down2     -> [B, 2F, H/4, W/4]
            to_vec    -> [B, 2F, 1,   1]
            up0       -> [B, 2F, H/4, W/4]
            up1       -> [B, F,  H/2, W/2]
            up2       -> [B, F,  H,   W]
            out       -> [B, C,  H,   W]
        """
        # Default to an all-zero context when none is given.
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # Initial convolution (note: `x` now holds features, not the raw input).
        x = self.init_conv(x)

        # Down-sampling path.
        down1 = self.down1(x)
        down2 = self.down2(down1)

        # Bottleneck vector.
        hiddenvec = self.to_vec(down2)

        # Conditioning embeddings reshaped to [B, channels, 1, 1] for broadcasting.
        cemb1 = self.contextembed1(c).view(-1, 2*self.n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, 2*self.n_feat, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        # Up-sampling path: scale by the context embedding, shift by the time
        # embedding, then merge with the matching skip connection.
        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1 + temb1, down2)
        up3 = self.up2(cemb2*up2 + temb2, down1)

        # Output head on the concatenation of up3 and the init_conv features.
        out = self.out(torch.cat((up3, x), 1))
        return out


# Hyperparameters

In [None]:

# diffusion hyperparameters
timesteps = 500  # number of diffusion steps T
beta1 = 1e-4     # beta at index 0 of the linear schedule
beta2 = 0.02     # beta at index T of the linear schedule

# network hyperparameters
# torch.device accepts a plain string; no need to nest a second torch.device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_feat = 64   # 64 hidden dimension feature
n_cfeat = 5   # context vector is of size 5
height = 16   # 16x16 image (must be a multiple of 4 for ContextUnet)
save_dir = './weights/'

# training hyperparameters
batch_size = 100
n_epoch = 32
lrate = 1e-3

In [None]:
# construct DDPM noise schedule (linear beta schedule over indices 0..timesteps)
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1  # beta_t
a_t = 1 - b_t  # alpha_t = 1 - beta_t

# alpha-bar_t = prod_{s <= t} alpha_s; cumprod computes this directly instead
# of the exp(cumsum(log(.))) round trip.
ab_t = torch.cumprod(a_t, dim=0)
# Index 0 means "no noise": alpha-bar = 1 gives coefficient 1 on the image and 0 on the noise.
ab_t[0] = 1

In [7]:
# Build the U-Net for 3-channel images using the configured sizes, then move
# it to the selected device (Module.to returns the same module).
nn_model = ContextUnet(
    in_channels=3,
    height=height,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
)
nn_model = nn_model.to(device)

# Training

In [None]:
# Sprite images plus their label arrays (CustomDataset and `transform` come
# from diffusion_utilities); the training cell below ignores the labels.
dataset = CustomDataset(
    "./sprites_1788_16x16.npy",
    "./sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
# Adam over every model parameter; the training loop rewrites lr each epoch.
optim = torch.optim.Adam(params=nn_model.parameters(), lr=lrate)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


In [10]:
# Forward (noising) process
def perturb_input(x, t, noise, alpha_bar=None):
    """Diffuse clean images to timestep t, i.e. sample from DDPM q(x_t | x_0).

        x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise

    The previous version used (1 - alpha_bar_t) instead of its square root for
    the noise term, which does not match q(x_t | x_0) (the variance no longer
    sums to 1) nor the reverse step in `denoise_add_noise`.

    Args:
        x: batch-first tensor of clean images, e.g. (B, C, H, W).
        t: integer timestep index/indices into the schedule — a (B,) tensor
           (one per sample) or a scalar.
        noise: standard-normal noise with the same shape as x.
        alpha_bar: cumulative-product schedule indexed by t; defaults to the
            notebook's global `ab_t`.

    Returns:
        Tensor with the same shape as x.
    """
    if alpha_bar is None:
        alpha_bar = ab_t
    # One coefficient per sample, broadcast over every non-batch dimension.
    ab = alpha_bar[t].reshape(-1, *([1] * (x.dim() - 1)))
    return ab.sqrt() * x + (1 - ab).sqrt() * noise

In [None]:
# training without context code: the dataloader's labels are discarded, so the
# model is always called without `c` (ContextUnet then uses an all-zero context).

# set into train mode
nn_model.train()

for ep in range(n_epoch):
    print(f'epoch {ep}')

    # linearly decay the learning rate from lrate towards 0 across epochs
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:  # x: images; labels unused in this unconditional run
        optim.zero_grad()
        x = x.to(device)

        # perturb data: one random timestep in [1, timesteps] per sample,
        # created directly on the target device
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],), device=device)
        x_pert = perturb_input(x, t, noise)

        # use network to recover noise (timestep normalised to (0, 1])
        pred_noise = nn_model(x_pert, t / timesteps)

        # loss is mean squared error between the predicted and true noise
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()

        optim.step()

    # save model every 4 epochs and after the final epoch
    if ep % 4 == 0 or ep == n_epoch - 1:
        # makedirs also creates missing parents and tolerates an existing dir
        os.makedirs(save_dir, exist_ok=True)
        ckpt_path = os.path.join(save_dir, f"model_{ep}.pth")
        torch.save(nn_model.state_dict(), ckpt_path)
        print('saved model at ' + ckpt_path)

# Sampling

In [None]:
# Reverse (denoising) step
def denoise_add_noise(x, t, pred_noise, z=None):
    """Take one DDPM reverse step from x_t towards x_{t-1}.

        x_{t-1} = (x_t - (1 - a_t) / sqrt(1 - ab_t) * eps) / sqrt(a_t) + sqrt(b_t) * z

    Args:
        x: current sample x_t.
        t: timestep index into the global schedules a_t, ab_t and b_t.
        pred_noise: predicted noise eps for x.
        z: noise to inject; drawn as standard normal shaped like x when None
           (callers pass 0 to inject nothing).

    Returns:
        The sample for the previous timestep.
    """
    z = torch.randn_like(x) if z is None else z
    alpha = a_t[t]
    eps_coef = (1 - alpha) / (1 - ab_t[t]).sqrt()
    mean = (x - pred_noise * eps_coef) / alpha.sqrt()
    sigma = b_t.sqrt()[t]
    return mean + sigma * z

In [11]:
# sample using standard algorithm
@torch.no_grad()
def sample_ddpm(n_sample, save_rate=20):
    """Generate images with the standard DDPM ancestral sampling loop.

    Starts from Gaussian noise x_T and applies `denoise_add_noise` for
    i = timesteps, ..., 1, using the global `nn_model` as the noise predictor.
    The model is put in eval mode for the duration of sampling (it is left in
    train mode by the training cell) and its previous mode is restored after.

    Args:
        n_sample: number of images to generate.
        save_rate: keep a snapshot every `save_rate` steps, plus the first
            step and every step with i < 8.

    Returns:
        samples: (n_sample, 3, height, height) tensor on `device`.
        intermediate: numpy array stacking the snapshots along a new leading axis.
    """
    was_training = nn_model.training
    # Inference mode: any normalisation/dropout layers inside the helper blocks
    # must not use per-batch statistics or update running stats while sampling.
    nn_model.eval()
    try:
        # x_T ~ N(0, I): the initial pure-noise images
        samples = torch.randn(n_sample, 3, height, height).to(device)

        # snapshots for visualisation
        intermediate = []
        for i in range(timesteps, 0, -1):
            print(f'sampling timestep {i:3d}', end='\r')

            # normalised timestep with shape (1, 1, 1, 1), shared by the whole batch
            t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

            # sample some random noise to inject back in. For i = 1, don't add back in noise
            z = torch.randn_like(samples) if i > 1 else 0

            eps = nn_model(samples, t)  # predicted noise eps(x_t, t)
            samples = denoise_add_noise(samples, i, eps, z)
            if i % save_rate == 0 or i == timesteps or i < 8:
                intermediate.append(samples.detach().cpu().numpy())
    finally:
        nn_model.train(was_training)  # restore the caller's train/eval mode

    intermediate = np.stack(intermediate)
    return samples, intermediate

In [None]:
# Draw 32 samples with the DDPM sampler and animate the saved denoising
# snapshots (plot_sample comes from diffusion_utilities; nothing is written
# to disk because save=False).
samples, intermediate_ddpm = sample_ddpm(32)
plt.clf()  # clear the current figure before plot_sample draws
animation_ddpm = plot_sample(
    intermediate_ddpm,
    32,        # number of samples (matches the sample_ddpm call above)
    4,         # presumably the number of grid rows — verify against plot_sample
    save_dir,
    "ani_run",
    None,
    save=False,
)
HTML(animation_ddpm.to_jshtml())