VAE 图片编码/解码模型

残差连接层

In [1]:
import math

import torch


class Resnet(torch.nn.Module):
    """Pre-activation residual block used throughout the VAE.

    Main branch ``s``:
        GroupNorm(32, dim_in) -> SiLU -> Conv3x3(dim_in -> dim_out)
        -> GroupNorm(32, dim_out) -> SiLU -> Conv3x3(dim_out -> dim_out)
    Both convolutions use stride 1 and padding 1, so H and W are preserved.
    GroupNorm normalises per sample and per channel group, so it does not
    depend on the batch size; ``eps=1e-6`` guards against division by zero.

    Shortcut ``res``:
        ``None`` (identity) when ``dim_in == dim_out``; otherwise a 1x1 conv
        projecting ``dim_in -> dim_out`` so both branches can be added.

    The submodule layout (``s[0]``, ``s[2]``, ``s[3]``, ``s[5]`` and ``res``)
    is what the pretrained-weight loader indexes into, so keep it stable.

    Args:
        dim_in: number of input channels (GroupNorm needs it divisible by 32).
        dim_out: number of output channels (GroupNorm needs it divisible by 32).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        self.s = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=32,
                               num_channels=dim_in,
                               eps=1e-6,
                               affine=True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_in,
                            dim_out,
                            kernel_size=3,
                            stride=1,
                            padding=1),
            torch.nn.GroupNorm(num_groups=32,
                               num_channels=dim_out,
                               eps=1e-6,
                               affine=True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_out,
                            dim_out,
                            kernel_size=3,
                            stride=1,
                            padding=1),
        )

        # A projection is only needed when the channel count changes;
        # otherwise the shortcut is the identity.
        self.res = None
        if dim_in != dim_out:
            self.res = torch.nn.Conv2d(dim_in,
                                       dim_out,
                                       kernel_size=1,
                                       stride=1,
                                       padding=0)

    def forward(self, x):
        """Return ``shortcut(x) + s(x)``.

        Shape: [B, dim_in, H, W] -> [B, dim_out, H, W].
        """
        # Explicit None check: a Module's truthiness is not a reliable
        # "is present" test (e.g. an empty Sequential is falsy via __len__).
        res = x if self.res is None else self.res(x)
        return res + self.s(x)


# Smoke test: 128 -> 256 channels, so the 1x1 shortcut is used; H and W stay 10.
Resnet(dim_in=128, dim_out=256)(torch.randn(1, 128, 10, 10)).shape

torch.Size([1, 256, 10, 10])

VAE 注意力层：典型的单头自注意力（无 mask），计算过程为 softmax(Q·Kᵀ / √d)·V，其中 d 为通道数

In [2]:
class Atten(torch.nn.Module):
    """Single-head spatial self-attention (no mask) with a residual add.

    Every spatial position of a [B, C, H, W] map is treated as one token of
    dimension C. Attention runs over all L = H*W tokens:
    ``out(softmax(q k^T / sqrt(C)) v)`` is reshaped back to [B, C, H, W] and
    added to the input.

    Args:
        dim: channel count C (default 512, the VAE mid-block width).
        num_groups: GroupNorm groups (default 32); ``dim`` must be divisible
            by it.
    """

    def __init__(self, dim=512, num_groups=32):
        super().__init__()
        # Per-sample normalisation over channel groups; affine=True learns a
        # per-channel scale and shift.
        self.norm = torch.nn.GroupNorm(num_channels=dim,
                                       num_groups=num_groups,
                                       eps=1e-6,
                                       affine=True)

        # Token-wise projections to query/key/value and back (single head,
        # so all are dim -> dim).
        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)
        self.out = torch.nn.Linear(dim, dim)

    def forward(self, x):
        """Shape-preserving: [B, C, H, W] -> [B, C, H, W]."""
        res = x
        b, c, h, w = x.shape
        seq = h * w

        x = self.norm(x)

        # [B, C, H, W] -> [B, C, L] -> [B, L, C]: one token per pixel.
        x = x.flatten(start_dim=2).transpose(1, 2)

        q = self.q(x)                   # [B, L, C]
        k = self.k(x).transpose(1, 2)   # [B, C, L], i.e. K^T for the bmm
        v = self.v(x)                   # [B, L, C]

        # Scaled logits alpha * (q @ k^T), alpha = 1/sqrt(C). With beta=0 the
        # contents of the first (empty) tensor are ignored. The fused baddbmm
        # form is kept instead of q.bmm(k) * scale because the two differ
        # slightly, and the bit-exact parity check against diffusers relies on
        # this form. The buffer follows q's batch/length/dtype/device so any
        # input size and half precision work.
        atten = torch.baddbmm(
            torch.empty(b, seq, seq, dtype=q.dtype, device=q.device),
            q,
            k,
            beta=0,
            alpha=1 / math.sqrt(c))

        # Normalise over the key axis: each query's weights sum to 1.
        atten = torch.softmax(atten, dim=2)

        # [B, L, L] @ [B, L, C] -> [B, L, C]
        atten = atten.bmm(v)
        atten = self.out(atten)

        # [B, L, C] -> [B, C, L] -> [B, C, H, W]
        atten = atten.transpose(1, 2).reshape(b, c, h, w)

        return atten + res


# Smoke test at the VAE mid-block size: 512 channels on a 64x64 map.
Atten()(torch.randn(size=(1, 512, 64, 64))).size()

torch.Size([1, 512, 64, 64])

工具层：在右侧补一列 0、在底部补一行 0

F.pad(x, (0, 1, 0, 1), mode='constant', value=0) 对 4D 张量 [N, C, H, W] 的 宽、高 维做填充：

参数顺序（2D）是 (left, right, top, bottom)：F.pad 从最后一维开始成对指定，先是 W 的左/右，再是 H 的上/下

这里是 (0, 1, 0, 1)：左 0、右 +1、上 0、下 +1

结果：[1, 2, 5, 5] → [1, 2, 6, 6]

因为输入是全 1，填充后最后一列和最后一行是 0，其余仍为 1。

In [4]:
class Pad(torch.nn.Module):
    """Append one zero column on the right and one zero row at the bottom.

    Placed before a stride-2, padding-0 3x3 convolution so that the
    down-sampling pads asymmetrically (right/bottom only).
    [N, C, H, W] -> [N, C, H + 1, W + 1].
    """

    def forward(self, x):
        # F.pad reads the tuple pairwise from the last dim backwards:
        # (width_left, width_right, height_top, height_bottom).
        widths = (0, 1, 0, 1)
        padded = torch.nn.functional.pad(x, widths, 'constant', 0)
        return padded


# Demo on an all-ones map: the appended last row and column come out as zeros.
Pad()(torch.ones(size=(1, 2, 5, 5)))

tensor([[[[1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0., 0.]],

         [[1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0., 0.]]]])

VAE 模型：计算过程为 编码 → 采样（重参数化）→ 解码

In [5]:
class VAE(torch.nn.Module):
    """Stable-Diffusion-style KL autoencoder: encode -> sample -> decode.

    The encoder maps an image to 8 channels (4 mean + 4 log-variance) at 1/8
    of the input resolution (three stride-2 down-samplings). ``sample`` draws
    a 4-channel latent from that Gaussian, and the decoder maps the latent
    back to a 3-channel image (three 2x up-samplings).
    """

    def __init__(self):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            # input conv: 3 -> 128 channels
            torch.nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),

            # down stages: 2 resnets each, the first three end in
            # Pad + stride-2 conv (halving H and W)
            torch.nn.Sequential(
                Resnet(128, 128),
                Resnet(128, 128),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(128, 128, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                Resnet(128, 256),
                Resnet(256, 256),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(256, 256, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                Resnet(256, 512),
                Resnet(512, 512),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(512, 512, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                Resnet(512, 512),
                Resnet(512, 512),
            ),

            # mid block
            torch.nn.Sequential(
                Resnet(512, 512),
                Atten(),
                Resnet(512, 512),
            ),

            # output head: 512 -> 8 channels
            torch.nn.Sequential(
                torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6),
                torch.nn.SiLU(),
                torch.nn.Conv2d(512, 8, 3, padding=1),
            ),

            # 1x1 conv over the 8 mean/log-variance channels
            torch.nn.Conv2d(8, 8, 1),
        )

        self.decoder = torch.nn.Sequential(
            # 1x1 conv over the sampled 4-channel latent
            torch.nn.Conv2d(4, 4, 1),

            # input conv: 4 -> 512 channels
            torch.nn.Conv2d(4, 512, kernel_size=3, stride=1, padding=1),

            # mid block
            torch.nn.Sequential(Resnet(512, 512), Atten(), Resnet(512, 512)),

            # up stages: 3 resnets each, the first three end in a 2x
            # nearest-neighbour upsample followed by a 3x3 conv
            torch.nn.Sequential(
                Resnet(512, 512),
                Resnet(512, 512),
                Resnet(512, 512),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ),
            torch.nn.Sequential(
                Resnet(512, 512),
                Resnet(512, 512),
                Resnet(512, 512),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ),
            torch.nn.Sequential(
                Resnet(512, 256),
                Resnet(256, 256),
                Resnet(256, 256),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ),
            torch.nn.Sequential(
                Resnet(256, 128),
                Resnet(128, 128),
                Resnet(128, 128),
            ),

            # output head: 128 -> 3 channels
            torch.nn.Sequential(
                torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6),
                torch.nn.SiLU(),
                torch.nn.Conv2d(128, 3, 3, padding=1),
            ),
        )

    def sample(self, h):
        """Draw a latent z ~ N(mean, exp(logvar)) with the reparameterization trick.

        Args:
            h: encoder output [B, 8, H, W]; channels 0-3 are the mean and
               channels 4-7 the log-variance.

        Returns:
            z = mean + std * eps with shape [B, 4, H, W], on the same device
            and with the same dtype as ``h``. Gradients flow through mean
            and std.
        """
        mean = h[:, :4]
        logvar = h[:, 4:]

        # Clamp before exponentiating so exp() cannot overflow to inf (or
        # underflow to an exactly-zero std); same bounds as diffusers'
        # DiagonalGaussianDistribution.
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)

        # randn_like keeps mean's dtype/device. torch.randn(shape) would
        # always be float32 and silently upcast half-precision latents.
        eps = torch.randn_like(mean)
        return mean + std * eps

    def forward(self, x):
        """Reconstruct ``x``: [B, 3, H, W] -> [B, 3, H, W].

        H and W are expected to be multiples of 8, e.g. 512 -> latent 64.
        """
        # [B, 3, H, W] -> [B, 8, H/8, W/8]: 4 mean + 4 log-variance channels
        h = self.encoder(x)

        # [B, 8, H/8, W/8] -> [B, 4, H/8, W/8]: sampled latent z
        h = self.sample(h)

        # [B, 4, H/8, W/8] -> [B, 3, H, W]: decoded reconstruction
        h = self.decoder(h)

        return h


# End-to-end smoke test (encode -> sample -> decode) on a random 512x512 input.
# no_grad: only the output shape is needed, so don't record an autograd graph
# through the whole model (it would keep every intermediate activation alive).
with torch.no_grad():
    vae_demo_out = VAE()(torch.randn(1, 3, 512, 512))
vae_demo_out.shape

torch.Size([1, 3, 512, 512])

准备加载预训练参数

In [6]:
from diffusers import AutoencoderKL

# Pretrained diffusers AutoencoderKL; its weights are copied into `vae` below.
params = AutoencoderKL.from_pretrained(
    'lansinuote/diffsion_from_scratch.params',
    subfolder='vae',
)

# Our re-implementation, randomly initialised until the weights are copied in.
vae = VAE()


def load_res(model, param):
    """Copy a pretrained resnet block's weights into a ``Resnet`` module.

    ``param`` must expose ``norm1``/``conv1``/``norm2``/``conv2`` and, when
    ``model.res`` is a module, ``conv_shortcut``. Indices 0/2/3/5 of
    ``model.s`` are the GroupNorm/Conv/GroupNorm/Conv layers; indices 1 and 4
    are parameter-free SiLU activations and are skipped.
    """
    for index, source_name in ((0, 'norm1'), (2, 'conv1'),
                               (3, 'norm2'), (5, 'conv2')):
        source = getattr(param, source_name)
        model.s[index].load_state_dict(source.state_dict())

    # The shortcut projection only exists when the block changes channels.
    if isinstance(model.res, torch.nn.Module):
        model.res.load_state_dict(param.conv_shortcut.state_dict())


def conv1x1_to_linear_(linear: torch.nn.Linear, conv: torch.nn.Conv2d):
    """Copy a 1x1 ``Conv2d`` (or an equally-shaped ``Linear``) into ``linear`` in place.

    A 1x1 convolution applied at every spatial position is the same map as a
    Linear layer over the channel dimension, so the conv weight
    [out, in, 1, 1] is reshaped to [out, in]. A 2-D source weight is copied
    as-is.

    Bias handling keeps the two layers equivalent:
      * both have a bias: the bias is copied;
      * only ``linear`` has one: it is zeroed (a bias-free conv adds nothing);
      * only ``conv`` has one: ValueError (``linear`` cannot represent it).

    Raises:
        ValueError: for a conv kernel other than 1x1, or the bias mismatch
            described above. Nothing is modified when this is raised.
    """
    weight = conv.weight
    if weight.ndim == 4 and tuple(weight.shape[-2:]) != (1, 1):
        raise ValueError(
            f"expected a 1x1 kernel, got {tuple(weight.shape[-2:])}")
    if conv.bias is not None and linear.bias is None:
        raise ValueError("conv has a bias but linear does not")

    with torch.no_grad():
        if weight.ndim == 4:
            weight = weight[:, :, 0, 0]    # [O, I, 1, 1] -> [O, I]
        linear.weight.copy_(weight)
        if linear.bias is not None:
            if conv.bias is not None:
                linear.bias.copy_(conv.bias)
            else:
                # Leaving the random init here would make linear != conv.
                linear.bias.zero_()

def load_atten(model, param):
    """Copy a pretrained attention block's weights into an ``Atten`` module.

    Args:
        model: an ``Atten`` instance (``norm`` is a GroupNorm; ``q``/``k``/
            ``v``/``out`` are Linear layers).
        param: the source attention module, exposing ``group_norm`` and the
            projections ``query``/``key``/``value``/``proj_attn`` (1x1
            Conv2d in the diffusers version used here -- confirm for others).
    """
    model.norm.load_state_dict(param.group_norm.state_dict())

    # (destination attribute on model, source attribute on param), copied in
    # q, k, v, out order.
    for target_name, source_name in (('q', 'query'),
                                     ('k', 'key'),
                                     ('v', 'value'),
                                     ('out', 'proj_attn')):
        conv1x1_to_linear_(getattr(model, target_name),
                           getattr(param, source_name))


# Copy every pretrained tensor from the diffusers model `params` into our
# hand-built `vae`. The integer indices follow the nn.Sequential layout built
# in VAE.__init__, so this mapping must change whenever that layout does.

# encoder: input conv (3 -> 128)
vae.encoder[0].load_state_dict(params.encoder.conv_in.state_dict())

# encoder: 4 down stages with 2 resnets each. The first three also end in
# Sequential(Pad, stride-2 conv); sub-index [2][1] is that conv.
for i in range(4):
    load_res(vae.encoder[i + 1][0], params.encoder.down_blocks[i].resnets[0])
    load_res(vae.encoder[i + 1][1], params.encoder.down_blocks[i].resnets[1])

    # The last stage (i == 3) has no down-sampling conv.
    if i != 3:
        vae.encoder[i + 1][2][1].load_state_dict(
            params.encoder.down_blocks[i].downsamplers[0].conv.state_dict())

# encoder: mid block = Resnet [0], Atten [1], Resnet [2]
load_res(vae.encoder[5][0], params.encoder.mid_block.resnets[0])
load_res(vae.encoder[5][2], params.encoder.mid_block.resnets[1])
load_atten(vae.encoder[5][1], params.encoder.mid_block.attentions[0])

# encoder: output head = GroupNorm [0], SiLU [1] (no params), conv [2]
vae.encoder[6][0].load_state_dict(params.encoder.conv_norm_out.state_dict())
vae.encoder[6][2].load_state_dict(params.encoder.conv_out.state_dict())

# encoder: final 1x1 conv over the 8 mean/log-variance channels (quant_conv)
vae.encoder[7].load_state_dict(params.quant_conv.state_dict())

# decoder: leading 1x1 conv over the 4-channel latent (post_quant_conv)
vae.decoder[0].load_state_dict(params.post_quant_conv.state_dict())

# decoder: input conv (4 -> 512)
vae.decoder[1].load_state_dict(params.decoder.conv_in.state_dict())

# decoder: mid block = Resnet [0], Atten [1], Resnet [2]
load_res(vae.decoder[2][0], params.decoder.mid_block.resnets[0])
load_res(vae.decoder[2][2], params.decoder.mid_block.resnets[1])
load_atten(vae.decoder[2][1], params.decoder.mid_block.attentions[0])

# decoder: 4 up stages with 3 resnets each. The first three end in
# Upsample [3] (no params) followed by a 3x3 conv [4].
for i in range(4):
    load_res(vae.decoder[i + 3][0], params.decoder.up_blocks[i].resnets[0])
    load_res(vae.decoder[i + 3][1], params.decoder.up_blocks[i].resnets[1])
    load_res(vae.decoder[i + 3][2], params.decoder.up_blocks[i].resnets[2])

    # The last stage (i == 3) has no up-sampling.
    if i != 3:
        vae.decoder[i + 3][4].load_state_dict(
            params.decoder.up_blocks[i].upsamplers[0].conv.state_dict())

# decoder: output head = GroupNorm [0], SiLU [1] (no params), conv [2]
vae.decoder[7][0].load_state_dict(params.decoder.conv_norm_out.state_dict())
vae.decoder[7][2].load_state_dict(params.decoder.conv_out.state_dict())

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


Downloading:   0%|          | 0.00/335M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading:   0%|          | 0.00/728 [00:00<?, ?B/s]

The config attributes {'scaling_factor': 0.18215} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.


<All keys matched successfully>

检查模型：与 diffusers 原版的输出逐元素完全一致，说明模型构建准确无误

In [9]:
# Encoder parity: our encoder output vs the diffusers latent-distribution
# parameters (the 8 mean/log-variance channels); expected to match exactly.
data = torch.randn(1, 3, 512, 512)

# Inference only: don't build autograd graphs for two full 512x512 passes.
with torch.no_grad():
    a = vae.encoder(data)
    b = params.encode(data).latent_dist.parameters

(a == b).all()

tensor(True)

In [10]:
# Decoder parity: decode the same random 4-channel latent with both models;
# expected to match exactly.
data = torch.randn(1, 4, 64, 64)

# Inference only: don't build autograd graphs for two full decoder passes.
with torch.no_grad():
    a = vae.decoder(data)
    b = params.decode(data).sample

(a == b).all()

tensor(True)