In [1]:
import os

# Set the HF_HOME and HF_ENDPOINT environment variables for this process.
os.environ.update({
    'HF_HOME': './cache',
    'HF_ENDPOINT': 'https://hf-mirror.com',
})

1.Encoder

In [2]:
import torch

class Embed(torch.nn.Module):
    """Token embedding plus learned absolute position embedding.

    Args:
        vocab_size: number of rows in the token embedding table.
        dim: width of every embedding vector.
        max_len: number of learned positions, i.e. the longest supported sequence.
    """

    def __init__(self, vocab_size=49408, dim=768, max_len=77):
        super().__init__()

        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.pos_embed = torch.nn.Embedding(max_len, dim)

        # Position indices 0..max_len-1, shape [1, max_len]. Registered as a
        # buffer: it is not trained but moves with the module across devices.
        self.register_buffer('pos_ids', torch.arange(max_len).unsqueeze(dim=0))

    def forward(self, input_ids):
        """Embed a batch of token ids.

        Args:
            input_ids: integer tensor [b, seq_len] with seq_len <= max_len.

        Returns:
            Float tensor [b, seq_len, dim]: token embedding + position embedding.

        Raises:
            ValueError: if seq_len exceeds the number of learned positions.
        """
        seq_len = input_ids.shape[1]
        max_len = self.pos_ids.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_len} learned positions")

        # [b, seq_len] -> [b, seq_len, dim]
        embed = self.embed(input_ids)

        # [1, seq_len] -> [1, seq_len, dim]; only the positions actually used
        pos_embeds = self.pos_embed(self.pos_ids[:, :seq_len])

        # The leading 1 of pos_embeds broadcasts over the batch dimension.
        return embed + pos_embeds

# Smoke test with token ids drawn from the whole vocabulary; expected shape [4, 77, 768].
# (torch.rand(...).long() truncates every value in [0, 1) to 0, so it only ever
# exercised token id 0.)
Embed()(torch.randint(0, 49408, (4, 77))).shape

torch.Size([4, 77, 768])

In [3]:
class Atten(torch.nn.Module):
    """Causal multi-head self-attention for the text encoder.

    Args:
        dim: model width (input and output feature size).
        heads: number of attention heads; must divide ``dim``.
    """

    def __init__(self, dim=768, heads=12):
        super().__init__()

        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")

        self.heads = heads
        self.head_dim = dim // heads

        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)
        self.out = torch.nn.Linear(dim, dim)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor [b, seq_len, dim].

        Returns:
            Tensor [b, seq_len, dim].
        """
        b, seq_len, dim = x.shape
        heads, head_dim = self.heads, self.head_dim

        # Queries are pre-scaled by 1/sqrt(head_dim); for 768/12 = 64-dim heads
        # this is exactly 0.125.
        q = self.q(x) * head_dim**-0.5
        k = self.k(x)
        v = self.v(x)

        def split_heads(t):
            # [b, L, dim] -> [b, L, heads, hd] -> [b, heads, L, hd] -> [b*heads, L, hd]
            t = t.reshape(b, seq_len, heads, head_dim).transpose(1, 2)
            return t.reshape(b * heads, seq_len, head_dim)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Attention scores: [b*heads, L, hd] @ [b*heads, hd, L] -> [b*heads, L, L]
        attn = torch.bmm(q, k.transpose(1, 2))

        # [b*heads, L, L] -> [b, heads, L, L]
        attn = attn.reshape(b, heads, seq_len, seq_len)

        # Causal mask: -inf strictly above the diagonal, 0 on and below it.
        # Built once per call on the input's device; it broadcasts over batch
        # and heads, so no per-sample copy is needed.
        mask = torch.full(
            (seq_len, seq_len), float('-inf'), device=x.device, dtype=attn.dtype
        ).triu_(diagonal=1)
        attn = attn + mask

        # [b, heads, L, L] -> [b*heads, L, L]; masked entries become 0 after softmax
        attn = attn.reshape(b * heads, seq_len, seq_len).softmax(dim=-1)

        # [b*heads, L, L] @ [b*heads, L, hd] -> [b*heads, L, hd]
        attn = torch.bmm(attn, v)

        # [b*heads, L, hd] -> [b, heads, L, hd] -> [b, L, heads, hd] -> [b, L, dim]
        attn = attn.reshape(b, heads, seq_len, head_dim).transpose(1, 2)
        attn = attn.reshape(b, seq_len, dim)

        # Output projection, shape unchanged.
        return self.out(attn)

# Smoke test: a random [2, 77, 768] input should keep its shape.
Atten()(torch.rand(2, 77, 768)).shape

torch.Size([2, 77, 768])

In [4]:
class ClipEncoder(torch.nn.Module):
    """One text-encoder layer: pre-norm self-attention, then a pre-norm
    feed-forward network, each wrapped in a residual connection."""

    def __init__(self):
        super().__init__()

        # Attention branch. LayerNorm here carries learnable affine parameters
        # (unlike a fixed data-preprocessing normalisation).
        attn_norm = torch.nn.LayerNorm(768)
        self.s1 = torch.nn.Sequential(attn_norm, Atten())

        # Feed-forward branch, first half: normalise, then expand 768 -> 3072
        # (4x the model width).
        ffn_norm = torch.nn.LayerNorm(768)
        ffn_expand = torch.nn.Linear(768, 3072)
        self.s2 = torch.nn.Sequential(ffn_norm, ffn_expand)

        # Feed-forward branch, second half: project 3072 -> 768.
        self.s3 = torch.nn.Linear(3072, 768)

    def forward(self, x):
        """Map x [b, 77, 768] to a tensor of the same shape."""
        # Residual around attention, shape unchanged.
        hidden = x + self.s1(x)

        # [b, 77, 768] -> [b, 77, 3072]
        expanded = self.s2(hidden)

        # Sigmoid-based GELU approximation: x * sigmoid(1.702 * x).
        gate = torch.sigmoid(expanded * 1.702)
        activated = expanded * gate

        # [b, 77, 3072] -> [b, 77, 768], plus the residual.
        return hidden + self.s3(activated)

# Smoke test: one encoder layer keeps the [2, 77, 768] shape.
ClipEncoder()(torch.rand(2, 77, 768)).shape

torch.Size([2, 77, 768])

In [5]:
# Full text encoder: embeddings at index 0, twelve transformer layers at
# indices 1..12, and a final LayerNorm at index 13. The weight-loading cell
# relies on exactly this index layout.
encoder = torch.nn.Sequential(
    Embed(),
    *(ClipEncoder() for _ in range(12)),
    torch.nn.LayerNorm(768),
)

# Smoke test with token ids drawn from the whole vocabulary; expected shape [2, 77, 768].
# (torch.rand(...).long() would feed token id 0 at every position.)
encoder(torch.randint(0, 49408, (2, 77))).shape

torch.Size([2, 77, 768])

In [6]:
# Print the module tree to inspect the layer layout.
print(encoder)

Sequential(
  (0): Embed(
    (embed): Embedding(49408, 768)
    (pos_embed): Embedding(77, 768)
  )
  (1): ClipEncoder(
    (s1): Sequential(
      (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (1): Atten(
        (q): Linear(in_features=768, out_features=768, bias=True)
        (k): Linear(in_features=768, out_features=768, bias=True)
        (v): Linear(in_features=768, out_features=768, bias=True)
        (out): Linear(in_features=768, out_features=768, bias=True)
      )
    )
    (s2): Sequential(
      (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=768, out_features=3072, bias=True)
    )
    (s3): Linear(in_features=3072, out_features=768, bias=True)
  )
  (2): ClipEncoder(
    (s1): Sequential(
      (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (1): Atten(
        (q): Linear(in_features=768, out_features=768, bias=True)
        (k): Linear(in_features=768, out_features=768, bias=True)
        (v)

In [7]:
from transformers import CLIPTextModel

# Load the pretrained CLIP text model; its weights are copied into `encoder` below.
params = CLIPTextModel.from_pretrained(
    "./Diffusion_model", subfolder="text_encoder"
)


def _copy_clip_layer(dst, src):
    """Copy one pretrained encoder layer `src` into the ClipEncoder `dst`."""
    attn = dst.s1[1]
    for target, source in (
        (dst.s1[0], src.layer_norm1),          # first LayerNorm
        (attn.q, src.self_attn.q_proj),        # attention q projection
        (attn.k, src.self_attn.k_proj),        # attention k projection
        (attn.v, src.self_attn.v_proj),        # attention v projection
        (attn.out, src.self_attn.out_proj),    # attention output projection
        (dst.s2[0], src.layer_norm2),          # second LayerNorm
        (dst.s2[1], src.mlp.fc1),              # MLP first linear
        (dst.s3, src.mlp.fc2),                 # MLP second linear
    ):
        target.load_state_dict(source.state_dict())


# Token embedding table.
encoder[0].embed.load_state_dict(params.text_model.embeddings.token_embedding.state_dict())

# Position embedding table.
encoder[0].pos_embed.load_state_dict(params.text_model.embeddings.position_embedding.state_dict())

# The 12 encoder layers live at encoder[1] .. encoder[12].
for i in range(12):
    _copy_clip_layer(encoder[i + 1], params.text_model.encoder.layers[i])

# Final LayerNorm; kept as the last expression so the cell shows the load result.
encoder[13].load_state_dict(params.text_model.final_layer_norm.state_dict())

<All keys matched successfully>

In [8]:
# Check that the hand-built encoder reproduces the pretrained model's
# last_hidden_state element for element on the same token ids.
# This is inference only, so run it under no_grad to avoid recording an
# autograd graph for both models.
with torch.no_grad():
    a = encoder(torch.arange(77).unsqueeze(0))
    b = params(torch.arange(77).unsqueeze(0)).last_hidden_state

(a == b).all()

tensor(True)

2.VAE

In [9]:
class Resnet(torch.nn.Module):
    """Residual block for the VAE: two (GroupNorm -> SiLU -> 3x3 conv) stages
    added to a skip connection.

    Args:
        dim_in: input channels (must be divisible by 32 for the GroupNorm).
        dim_out: output channels (must be divisible by 32 for the GroupNorm).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # Layout matters: the weight loader addresses s[0], s[2], s[3], s[5].
        self.s = torch.nn.Sequential(
            torch.nn.GroupNorm(
                num_groups=32, num_channels=dim_in, eps=1e-6, affine=True
            ),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
            torch.nn.GroupNorm(
                num_groups=32, num_channels=dim_out, eps=1e-6, affine=True
            ),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
        )

        # 1x1 conv on the skip path, only when the channel count changes.
        # Stays None otherwise (the weight loader checks for a Module here).
        self.res = None
        if dim_in != dim_out:
            self.res = torch.nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Map x [b, dim_in, h, w] to [b, dim_out, h, w]."""
        res = x
        # Explicit None check: module truthiness is not a reliable presence test
        # (container modules evaluate via __len__).
        if self.res is not None:
            # e.g. [1, 128, 10, 10] -> [1, 256, 10, 10]
            res = self.res(x)

        return res + self.s(x)

# Smoke test: 128 -> 256 channels, spatial size unchanged.
Resnet(128, 256)(torch.randn(1, 128, 10, 10)).shape


torch.Size([1, 256, 10, 10])

In [10]:
class Atten(torch.nn.Module):
    """Single-head spatial self-attention for the VAE mid block.

    Every pixel of the feature map is a token; the block adds the attended
    features back onto its input.

    Args:
        dim: channel count of the feature map (must be divisible by 32 for
            the GroupNorm).
    """

    def __init__(self, dim=512):
        super().__init__()
        self.norm = torch.nn.GroupNorm(
            num_channels=dim, num_groups=32, eps=1e-6, affine=True
        )

        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)
        self.out = torch.nn.Linear(dim, dim)

    def forward(self, x):
        """Map x [b, c, h, w] to a tensor of the same shape."""
        # Sizes come from the input instead of being fixed to [1, 512, 64, 64].
        b, c, h, w = x.shape
        res = x

        # Normalise, shape unchanged.
        x = self.norm(x)

        # [b, c, h, w] -> [b, c, h*w] -> [b, h*w, c]
        x = x.flatten(2).transpose(1, 2)

        # Linear projections, shape unchanged: [b, h*w, c]
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # [b, h*w, c] -> [b, c, h*w]
        k = k.transpose(1, 2)

        # Scaled scores: [b, h*w, c] @ [b, c, h*w] -> [b, h*w, h*w].
        # With beta=0 the uninitialised `empty` tensor is ignored by baddbmm;
        # it only supplies shape, dtype and device.
        atten = torch.baddbmm(
            torch.empty(b, h * w, h * w, device=q.device, dtype=q.dtype),
            q,
            k,
            beta=0,
            alpha=c**-0.5,
        )

        atten = torch.softmax(atten, dim=2)

        # [b, h*w, h*w] @ [b, h*w, c] -> [b, h*w, c]
        atten = atten.bmm(v)

        # Output projection, shape unchanged.
        atten = self.out(atten)

        # [b, h*w, c] -> [b, c, h*w] -> [b, c, h, w]
        atten = atten.transpose(1, 2).reshape(b, c, h, w)

        # Residual connection, shape unchanged.
        atten = atten + res

        return atten

# Smoke test: a [1, 512, 64, 64] feature map should keep its shape.
Atten()(torch.randn(1, 512, 64, 64)).shape

torch.Size([1, 512, 64, 64])

In [11]:
class Pad(torch.nn.Module):
    """Zero-pad a feature map by one column on the right and one row at the bottom."""

    def forward(self, x):
        # F.pad lists the pads for the last dimension first:
        # (left, right) for the width, then (top, bottom) for the height.
        left, right, top, bottom = 0, 1, 0, 1
        padded = torch.nn.functional.pad(
            x, (left, right, top, bottom), mode="constant", value=0
        )
        return padded

# Smoke test: the 5x5 maps of ones gain a zero column on the right and a zero row at the bottom.
Pad()(torch.ones(1, 2, 5, 5))

tensor([[[[1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0., 0.]],

         [[1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0., 0.]]]])

In [12]:
class VAE(torch.nn.Module):
    """Convolutional variational auto-encoder.

    ``encoder`` maps an image [b, 3, H, W] to [b, 8, H/8, W/8] (4 mean and
    4 log-variance channels); ``decoder`` maps a latent [b, 4, h, w] back to
    [b, 3, 8h, 8w]. The nesting and indices of both ``Sequential`` stacks are
    relied on by the weight-loading code and must not change.
    """

    @staticmethod
    def _down(dim_in, dim_out, downsample=True):
        """Encoder stage: two Resnets, optionally followed by (Pad, stride-2 conv)."""
        layers = [Resnet(dim_in, dim_out), Resnet(dim_out, dim_out)]
        if downsample:
            layers.append(
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(dim_out, dim_out, 3, stride=2, padding=0),
                ))
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _up(dim_in, dim_out, upsample=True):
        """Decoder stage: three Resnets, optionally followed by 2x upsample + conv."""
        layers = [
            Resnet(dim_in, dim_out),
            Resnet(dim_out, dim_out),
            Resnet(dim_out, dim_out),
        ]
        if upsample:
            layers.append(torch.nn.Upsample(scale_factor=2, mode="nearest"))
            layers.append(torch.nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1))
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _mid():
        """Middle block: Resnet -> spatial attention -> Resnet at 512 channels."""
        return torch.nn.Sequential(Resnet(512, 512), Atten(), Resnet(512, 512))

    @staticmethod
    def _head(dim_in, dim_out):
        """Output head: GroupNorm -> SiLU -> 3x3 conv."""
        return torch.nn.Sequential(
            torch.nn.GroupNorm(num_channels=dim_in, num_groups=32, eps=1e-6),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_in, dim_out, 3, padding=1),
        )

    def __init__(self):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            # [0] in
            torch.nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),

            # [1]-[4] down; the last stage has no downsampler
            self._down(128, 128),
            self._down(128, 256),
            self._down(256, 512),
            self._down(512, 512, downsample=False),

            # [5] mid
            self._mid(),

            # [6] out
            self._head(512, 8),

            # [7] 1x1 conv producing the distribution parameters (mean, logvar)
            torch.nn.Conv2d(8, 8, 1),
        )

        self.decoder = torch.nn.Sequential(
            # [0] 1x1 conv applied to the sampled latent
            torch.nn.Conv2d(4, 4, 1),

            # [1] in
            torch.nn.Conv2d(4, 512, kernel_size=3, stride=1, padding=1),

            # [2] middle
            self._mid(),

            # [3]-[6] up; the last stage has no upsampler
            self._up(512, 512),
            self._up(512, 512),
            self._up(512, 256),
            self._up(256, 128, upsample=False),

            # [7] out
            self._head(128, 3),
        )

    def sample(self, h):
        """Draw a latent from the diagonal Gaussian described by ``h``.

        Args:
            h: tensor [b, 8, ...]; channels 0-3 are the mean, 4-7 the
                log-variance.

        Returns:
            Tensor [b, 4, ...]: mean + std * standard normal noise.
        """
        mean = h[:, :4]
        # Bound the log-variance before exponentiating so an extreme encoder
        # output cannot overflow to inf/NaN; values inside [-30, 20] are
        # unaffected.
        logvar = h[:, 4:].clamp(-30.0, 20.0)
        std = logvar.exp()**0.5

        noise = torch.randn(mean.shape, device=mean.device, dtype=mean.dtype)
        return mean + std * noise

    def forward(self, x):
        """Encode, sample and decode: e.g. [1, 3, 512, 512] -> [1, 3, 512, 512]."""
        # [1, 3, 512, 512] -> [1, 8, 64, 64]
        h = self.encoder(x)

        # [1, 8, 64, 64] -> [1, 4, 64, 64]
        h = self.sample(h)

        # [1, 4, 64, 64] -> [1, 3, 512, 512]
        h = self.decoder(h)

        return h

# Smoke test: a 512x512 RGB image is reconstructed at the same shape.
VAE()(torch.randn(1, 3, 512, 512)).shape


torch.Size([1, 3, 512, 512])

In [13]:
from diffusers import AutoencoderKL

# Load the pretrained VAE whose weights will be copied into the hand-written one.
params = AutoencoderKL.from_pretrained("./Diffusion_model", subfolder="vae")

# Freshly initialised hand-written VAE that receives those weights.
vae = VAE()

def load_res(model, param):
    """Copy the weights of a pretrained resnet block `param` into `model`.

    `model.s` is addressed by index (0/3 are the norms, 2/5 the convs);
    the shortcut conv is copied only when `model.res` is a module.
    """
    pairs = (
        (model.s[0], param.norm1),
        (model.s[2], param.conv1),
        (model.s[3], param.norm2),
        (model.s[5], param.conv2),
    )
    for target, source in pairs:
        target.load_state_dict(source.state_dict())

    shortcut = model.res
    if isinstance(shortcut, torch.nn.Module):
        shortcut.load_state_dict(param.conv_shortcut.state_dict())

def load_attn(model, param):
    """Copy the weights of a pretrained attention block `param` into `model`
    (group norm, q/k/v projections and the first entry of `to_out`)."""
    sources = (
        ('norm', param.group_norm),
        ('q', param.to_q),
        ('k', param.to_k),
        ('v', param.to_v),
        ('out', param.to_out[0]),
    )
    for attr, source in sources:
        getattr(model, attr).load_state_dict(source.state_dict())

# Copy the pretrained weights in `params` into `vae`, module by module.
_enc_src = params.encoder
_dec_src = params.decoder

# encoder: input conv
vae.encoder[0].load_state_dict(_enc_src.conv_in.state_dict())

# encoder: four down stages at vae.encoder[1..4], two Resnets each; only the
# first three end with a (Pad, Conv2d) pair whose conv sits at index [2][1]
for i in range(4):
    for _j in range(2):
        load_res(vae.encoder[i + 1][_j], _enc_src.down_blocks[i].resnets[_j])

    if i < 3:
        vae.encoder[i + 1][2][1].load_state_dict(
            _enc_src.down_blocks[i].downsamplers[0].conv.state_dict())

# encoder: mid block (Resnet, Atten, Resnet)
load_res(vae.encoder[5][0], _enc_src.mid_block.resnets[0])
load_res(vae.encoder[5][2], _enc_src.mid_block.resnets[1])
load_attn(vae.encoder[5][1], _enc_src.mid_block.attentions[0])

# encoder: output head (norm at [0], conv at [2])
vae.encoder[6][0].load_state_dict(_enc_src.conv_norm_out.state_dict())
vae.encoder[6][2].load_state_dict(_enc_src.conv_out.state_dict())

# encoder: 1x1 conv producing the distribution parameters
vae.encoder[7].load_state_dict(params.quant_conv.state_dict())

# decoder: 1x1 conv applied to the latent
vae.decoder[0].load_state_dict(params.post_quant_conv.state_dict())

# decoder: input conv
vae.decoder[1].load_state_dict(_dec_src.conv_in.state_dict())

# decoder: mid block (Resnet, Atten, Resnet)
load_res(vae.decoder[2][0], _dec_src.mid_block.resnets[0])
load_res(vae.decoder[2][2], _dec_src.mid_block.resnets[1])
load_attn(vae.decoder[2][1], _dec_src.mid_block.attentions[0])

# decoder: four up stages at vae.decoder[3..6], three Resnets each; only the
# first three end with an upsampler whose conv sits at index [4]
for i in range(4):
    for _j in range(3):
        load_res(vae.decoder[i + 3][_j], _dec_src.up_blocks[i].resnets[_j])

    if i < 3:
        vae.decoder[i + 3][4].load_state_dict(
            _dec_src.up_blocks[i].upsamplers[0].conv.state_dict())

# decoder: output head; the last call stays the final expression of the cell
vae.decoder[7][0].load_state_dict(_dec_src.conv_norm_out.state_dict())
vae.decoder[7][2].load_state_dict(_dec_src.conv_out.state_dict())

    PyTorch 1.13.1+cu117 with CUDA 1107 (you have 1.13.1+cpu)
    Python  3.9.13 (you have 3.9.18)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


<All keys matched successfully>

In [14]:
# Check that the hand-written encoder reproduces the pretrained encoder's
# distribution parameters element for element on the same random image.
data = torch.randn(1, 3, 512, 512)

# Inference only: no_grad avoids recording an autograd graph for both models.
with torch.no_grad():
    a = vae.encoder(data)
    b = params.encode(data).latent_dist.parameters

(a == b).all()

tensor(True)

In [15]:
# Check that the hand-written decoder reproduces the pretrained decoder
# element for element on the same random latent.
data = torch.randn(1, 4, 64, 64)

# Inference only: no_grad avoids recording an autograd graph for both models.
with torch.no_grad():
    a = vae.decoder(data)
    b = params.decode(data).sample

(a == b).all()

tensor(True)

3.Unet

In [16]:
class Resnet(torch.nn.Module):
    """Time-conditioned residual block for the UNet.

    Args:
        dim_in: input channels (must be divisible by 32 for the GroupNorm).
        dim_out: output channels (must be divisible by 32 for the GroupNorm).
        dim_time: width of the time embedding fed to ``forward``.
    """

    def __init__(self, dim_in, dim_out, dim_time=1280):
        super().__init__()

        # Projects the time embedding to one value per output channel.
        # Unflatten reshapes [b, dim_out] into [b, dim_out, 1, 1] so the result
        # broadcasts over the spatial dimensions.
        self.time = torch.nn.Sequential(
            torch.nn.SiLU(),
            torch.nn.Linear(dim_time, dim_out),
            torch.nn.Unflatten(dim=1, unflattened_size=(dim_out, 1, 1)),
        )

        self.s0 = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=32, num_channels=dim_in, eps=1e-05, affine=True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
        )

        self.s1 = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=32, num_channels=dim_out, eps=1e-05, affine=True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
        )

        # 1x1 conv on the skip path, only when the channel count changes.
        self.res = None
        if dim_in != dim_out:
            self.res = torch.nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x, time):
        """Apply the block.

        Args:
            x: feature map, e.g. [1, 320, 64, 64].
            time: time embedding, e.g. [1, 1280]; its batch size must equal
                that of ``x`` or be 1.

        Returns:
            Feature map with ``dim_out`` channels and the spatial size of ``x``.
        """
        res = x

        # [1, 1280] -> [1, 640, 1, 1]
        time = self.time(time)

        # [1, 320, 64, 64] -> [1, 640, 64, 64], plus the per-channel time term
        x = self.s0(x) + time

        # Shape unchanged: [1, 640, 64, 64]
        x = self.s1(x)

        # [1, 320, 64, 64] -> [1, 640, 64, 64]
        # Explicit None check: module truthiness is not a reliable presence test.
        if self.res is not None:
            res = self.res(res)

        # Shape unchanged: [1, 640, 64, 64]
        x = res + x

        return x

# Smoke test: 320 -> 640 channels with a [1, 1280] time embedding, spatial size unchanged.
Resnet(320, 640)(torch.randn(1, 320, 64, 64), torch.randn(1, 1280)).shape

torch.Size([1, 640, 64, 64])

自注意力（Self-Attention）
自注意力模型中，注意力权重是在同一个序列的不同位置之间计算的。换句话说，序列内部的元素相互计算注意力分数。
它通常用于捕捉序列内的长距离依赖关系。例如，在Transformer模型中，自注意力允许模型在处理序列中的每个元素时考虑到序列中的其他元素。
在自注意力中，查询（Query）、键（Key）和值（Value）通常来自同一个输入源。

交叉注意力（Cross-Attention）
交叉注意力通常涉及两个不同的序列或信息源。在这种情况下，一个序列的元素（查询）会和另一个序列的元素（键和值）计算注意力分数。
这种类型的注意力机制常用于任务中需要对两个不同的输入进行交互和比较的情况，如在机器翻译中，源语言和目标语言之间的注意力，或者在图像文字处理任务中，图像特征和文字特征之间的注意力。
在交叉注意力中，查询（Query）来自于一个输入源，而键（Key）和值（Value）来自于另一个输入源。

操作 [1, 4096, 320] -> [8, 4096, 40] 将一个大的特征维度（在这里是320）分割成多个较小的特征空间（在这里是40）。每个头在自己的特征子空间中处理序列，然后最后的输出会被整合起来。这种分割方法如下：

为什么分散：原始的Transformer模型引入了多头注意力机制来使模型能够在不同的表示子空间中捕捉到不同类型的信息。这是基于这样的假设：不同的注意力头可以学习到不同方面的特征。

计算细节：在您的代码中，原始特征维度是320。当您分割成8个头时，每个头的特征维度变为 320 / 8 = 40。这意味着每个头可以关注输入序列特征的不同方面。

效果：多头注意力通常可以提高模型的性能，因为它允许模型在不同的表示子空间中并行地捕捉到更加复杂和细粒度的信息。

最后输出：最后，经过注意力计算后的多个头的输出会被重新组合（concatenated）或相加（summed），以形成单个输出张量，这个张量的特征维度通常会恢复到原始的维度大小。在您的代码中，这一步是通过 reshape(x) 函数的第二个定义完成的，将 [8, 4096, 40] 重新整合为 [1, 4096, 320]。

在深度学习中，mask_attention 或一般的注意力掩码（attention masking）被用于多种情况，主要是为了防止模型在计算注意力时考虑某些不应该被考虑的信息。以下是使用注意力掩码的几种常见场景：

填充（Padding）处理：在处理不等长的序列数据时，通常会使用填充来使所有序列达到相同的长度。这样做可以让批处理变得可行。然而，在计算注意力时，我们不希望模型将填充的部分考虑进去，因为它们不包含有用的信息。因此，我们使用掩码来指示模型忽略这些填充位置。

因果（Causal）或序列（Sequential）掩码：在生成文本或处理时间序列数据时，模型在预测位置 i 的输出时，只应该使用位置 i 之前的信息。为了确保模型不会"看到未来"，我们使用一个因果掩码，该掩码会覆盖序列中位置 i 之后的所有位置。

解码器-编码器注意力掩码：在序列到序列模型（如Transformer模型）中，解码器的每一步都需要关注编码器的输出。如果编码器的输出包含了填充，我们需要确保解码器不会将这些填充考虑进注意力计算中。

特定任务的掩码：有时，根据特定任务的需求，可能需要设计特殊的掩码。例如，在一些对齐或匹配任务中，可能只希望模型关注特定的输入对。

在实际实现中，掩码通常是一个与注意力分数矩阵形状相同（或可以广播到该形状）的张量，而不是与输入序列形状相同；掩码张量的元素通常由0和负无穷（或非常小的值）组成。在应用softmax函数之前，掩码张量会与注意力分数张量相加。由于softmax函数的性质，值为负无穷的位置在softmax之后的权重恰好为0（使用非常小的有限值时则接近于0），从而有效地"掩盖"了这些位置的影响。

In [17]:
class CrossAttention(torch.nn.Module):
    """Multi-head attention whose queries and keys/values may come from
    different sources (pass the same tensor twice for self-attention).

    Args:
        dim_q: feature size of the query input and of the output (e.g. 320).
        dim_kv: feature size of the key/value input (e.g. 768).
        heads: number of attention heads; must divide ``dim_q``.
    """

    def __init__(self, dim_q, dim_kv, heads=8):
        super().__init__()

        if dim_q % heads != 0:
            raise ValueError(f"dim_q ({dim_q}) must be divisible by heads ({heads})")

        self.dim_q = dim_q
        self.heads = heads
        self.q = torch.nn.Linear(dim_q, dim_q, bias=False)
        self.k = torch.nn.Linear(dim_kv, dim_q, bias=False)
        self.v = torch.nn.Linear(dim_kv, dim_q, bias=False)

        self.out = torch.nn.Linear(dim_q, dim_q)

    def forward(self, q, kv):
        """Attend from ``q`` over ``kv``.

        Args:
            q: query source, e.g. [1, 4096, 320].
            kv: key/value source, e.g. [1, 77, 768].

        Returns:
            Tensor with the shape of ``q``.
        """
        heads = self.heads

        # [1, 4096, 320] -> [1, 4096, 320]
        q = self.q(q)
        # [1, 77, 768] -> [1, 77, 320]
        k = self.k(kv)
        # [1, 77, 768] -> [1, 77, 320]
        v = self.v(kv)

        def split_heads(x):
            # [1, 4096, 320] -> [1, 4096, 8, 40] -> [1, 8, 4096, 40] -> [8, 4096, 40]
            b, lens, dim = x.shape
            x = x.reshape(b, lens, heads, dim // heads).transpose(1, 2)
            return x.reshape(b * heads, lens, dim // heads)

        def merge_heads(x):
            # [8, 4096, 40] -> [1, 8, 4096, 40] -> [1, 4096, 8, 40] -> [1, 4096, 320]
            b, lens, dim = x.shape
            x = x.reshape(b // heads, heads, lens, dim).transpose(1, 2)
            return x.reshape(b // heads, lens, dim * heads)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # [8, 4096, 40] @ [8, 40, 77] -> [8, 4096, 77]
        # Equivalent in exact arithmetic to
        #   q.bmm(k.transpose(1, 2)) * (dim_q // heads) ** -0.5
        # but the two forms differ by tiny floating-point errors, so the fused
        # baddbmm call is kept. With beta=0 the uninitialised `empty` tensor is
        # ignored; it only supplies shape, dtype and device.
        atten = torch.baddbmm(
            torch.empty(q.shape[0], q.shape[1], k.shape[1],
                        device=q.device, dtype=q.dtype),
            q,
            k.transpose(1, 2),
            beta=0,
            alpha=(self.dim_q // heads)**-0.5,
        )

        atten = atten.softmax(dim=-1)

        # [8, 4096, 77] @ [8, 77, 40] -> [8, 4096, 40]
        atten = atten.bmm(v)

        # [8, 4096, 40] -> [1, 4096, 320]
        atten = merge_heads(atten)

        # [1, 4096, 320] -> [1, 4096, 320]
        atten = self.out(atten)

        return atten

# Smoke test: 4096 query tokens of width 320 attend over 77 context tokens of width 768.
CrossAttention(320, 768)(torch.randn(1, 4096, 320), torch.randn(1, 77, 768)).shape

torch.Size([1, 4096, 320])

In [18]:
class Transformer(torch.nn.Module):
    """Spatial transformer block for the UNet: self-attention over the pixels,
    cross-attention to the text context, a gated feed-forward net, and a
    residual around the whole block.

    Args:
        dim: channel count of the feature map (must be divisible by 32 for
            the GroupNorm and by the attention head count).
        dim_kv: feature size of the context tokens.
    """

    def __init__(self, dim, dim_kv=768):
        super().__init__()

        self.dim = dim

        # in
        self.norm_in = torch.nn.GroupNorm(
            num_groups=32, num_channels=dim, eps=1e-6, affine=True
        )

        self.cnn_in = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)

        # atten: atten1 is self-attention, atten2 attends to the context
        self.norm_atten0 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.atten1 = CrossAttention(dim, dim)
        self.norm_atten1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.atten2 = CrossAttention(dim, dim_kv)

        # act: fc0 produces 8*dim features, split into a 4*dim value half and
        # a 4*dim gate half, so fc1 consumes 4*dim
        self.norm_act = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.fc0 = torch.nn.Linear(dim, dim * 8)
        self.act = torch.nn.GELU()
        self.fc1 = torch.nn.Linear(dim * 4, dim)

        # out
        self.cnn_out = torch.nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)

    def forward(self, q, kv):
        """Apply the block.

        Args:
            q: feature map, e.g. [1, 320, 64, 64].
            kv: context tokens, e.g. [1, 77, 768].

        Returns:
            Feature map with the shape of ``q``.
        """
        b, _, h, w = q.shape
        res1 = q

        # ----in----
        # Shape unchanged: [1, 320, 64, 64]
        q = self.cnn_in(self.norm_in(q))

        # [1, 320, 64, 64] -> [1, 64, 64, 320] -> [1, 4096, 320]
        q = q.permute(0, 2, 3, 1).reshape(b, h * w, self.dim)

        # ----atten----
        # Shape unchanged: [1, 4096, 320]
        # Self-attention: queries and keys/values are the same normalised
        # tensor, so it is computed once and passed twice.
        normed = self.norm_atten0(q)
        q = self.atten1(q=normed, kv=normed) + q
        q = self.atten2(q=self.norm_atten1(q), kv=kv) + q

        # ----act----
        # [1, 4096, 320]
        res2 = q

        # [1, 4096, 320] -> [1, 4096, 2560]
        q = self.fc0(self.norm_act(q))

        # 1280
        d = q.shape[2] // 2

        # [1, 4096, 1280] * [1, 4096, 1280] -> [1, 4096, 1280]
        # Gated feed-forward: the first half carries the values, the second
        # half (after GELU) acts as a multiplicative gate on them.
        q = q[:, :, :d] * self.act(q[:, :, d:])

        # [1, 4096, 1280] -> [1, 4096, 320]
        q = self.fc1(q) + res2

        # ----out----
        # [1, 4096, 320] -> [1, 64, 64, 320] -> [1, 320, 64, 64]
        q = q.reshape(b, h, w, self.dim).permute(0, 3, 1, 2).contiguous()

        # Shape unchanged: [1, 320, 64, 64]
        q = self.cnn_out(q) + res1
        return q


# Smoke test: the block keeps the [1, 320, 64, 64] feature-map shape.
Transformer(320)(torch.randn(1, 320, 64, 64), torch.randn(1, 77, 768)).shape

torch.Size([1, 320, 64, 64])

In [19]:
class DownBlock(torch.nn.Module):
    """UNet down stage: two (Resnet -> Transformer) pairs followed by a
    stride-2 convolution. Every intermediate result is also returned so it
    can be used as a skip connection."""

    def __init__(self, dim_in, dim_out):
        super().__init__()

        self.tf0 = Transformer(dim_out)
        self.res0 = Resnet(dim_in, dim_out)

        self.tf1 = Transformer(dim_out)
        self.res1 = Resnet(dim_out, dim_out)

        # Halves the spatial resolution.
        self.out = torch.nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=2, padding=1)

    def forward(self, out_vae, out_encoder, time):
        """Return ``(downsampled, skips)`` where ``skips`` holds the output of
        each Resnet/Transformer pair followed by the downsampled map."""
        skips = []

        stages = ((self.res0, self.tf0), (self.res1, self.tf1))
        for resnet, transformer in stages:
            out_vae = transformer(resnet(out_vae, time), out_encoder)
            skips.append(out_vae)

        out_vae = self.out(out_vae)
        skips.append(out_vae)

        return out_vae, skips


# Smoke test: 320 -> 640 channels, spatial size halved from 32x32 to 16x16.
DownBlock(320, 640)(
    torch.randn(1, 320, 32, 32),
    torch.randn(1, 77, 768),
    torch.randn(1, 1280),
)[0].shape

torch.Size([1, 640, 16, 16])

In [20]:
class UpBlock(torch.nn.Module):
    """UNet up stage: three (concat skip -> Resnet -> Transformer) steps,
    optionally followed by 2x nearest upsampling and a 3x3 conv.

    Args:
        dim_in: channels of the third skip tensor consumed by this stage.
        dim_out: channels produced by every Resnet of this stage (also the
            channels of the first two skip tensors).
        dim_prev: channels of the feature map entering the stage.
        add_up: whether to append the upsample + conv.
    """

    def __init__(self, dim_in, dim_out, dim_prev, add_up):
        super().__init__()

        self.res0 = Resnet(dim_out + dim_prev, dim_out)
        self.res1 = Resnet(dim_out + dim_out, dim_out)
        self.res2 = Resnet(dim_in + dim_out, dim_out)

        self.tf0 = Transformer(dim_out)
        self.tf1 = Transformer(dim_out)
        self.tf2 = Transformer(dim_out)

        self.out = None
        if add_up:
            self.out = torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=2, mode="nearest"),
                torch.nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1),
            )

    def forward(self, out_vae, out_encoder, time, out_down):
        """Apply the stage.

        ``out_down`` is the caller's list of skip tensors; the last three
        entries are popped (the list is mutated in place) and concatenated,
        one per step, onto the running feature map along the channel axis.
        """
        out_vae = self.res0(torch.cat([out_vae, out_down.pop()], dim=1), time)
        out_vae = self.tf0(out_vae, out_encoder)

        out_vae = self.res1(torch.cat([out_vae, out_down.pop()], dim=1), time)
        out_vae = self.tf1(out_vae, out_encoder)

        out_vae = self.res2(torch.cat([out_vae, out_down.pop()], dim=1), time)
        out_vae = self.tf2(out_vae, out_encoder)

        # Explicit None check: `if self.out:` would evaluate the Sequential
        # through its __len__, which only happens to be non-zero here.
        if self.out is not None:
            out_vae = self.out(out_vae)

        return out_vae


# Smoke test: a 1280-channel 32x32 map plus three skips (popped from the end
# of the list: 640, 640, 320 channels) becomes a 640-channel 64x64 map.
UpBlock(320, 640, 1280, True)(
    torch.randn(1, 1280, 32, 32),
    torch.randn(1, 77, 768),
    torch.randn(1, 1280),
    [
        torch.randn(1, 320, 32, 32),
        torch.randn(1, 640, 32, 32),
        torch.randn(1, 640, 32, 32),
    ],
).shape

torch.Size([1, 640, 64, 64])

In [21]:
class UNet(torch.nn.Module):
    """Text- and time-conditioned UNet operating on 4-channel latents.

    ``forward(out_vae, out_encoder, time)`` takes a latent (e.g.
    [2, 4, 64, 64]), the text-encoder output (e.g. [2, 77, 768]) and a
    timestep tensor, and returns a tensor with the latent's shape.
    """

    def __init__(self):
        super().__init__()

        #in
        self.in_vae = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.in_time = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280),
        )

        #down
        self.down_block0 = DownBlock(320, 320)
        self.down_block1 = DownBlock(320, 640)
        self.down_block2 = DownBlock(640, 1280)

        self.down_res0 = Resnet(1280, 1280)
        self.down_res1 = Resnet(1280, 1280)

        #mid
        self.mid_res0 = Resnet(1280, 1280)
        self.mid_tf = Transformer(1280)
        self.mid_res1 = Resnet(1280, 1280)

        #up
        self.up_res0 = Resnet(2560, 1280)
        self.up_res1 = Resnet(2560, 1280)
        self.up_res2 = Resnet(2560, 1280)

        self.up_in = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.Conv2d(1280, 1280, kernel_size=3, padding=1),
        )

        self.up_block0 = UpBlock(640, 1280, 1280, True)
        self.up_block1 = UpBlock(320, 640, 1280, True)
        self.up_block2 = UpBlock(320, 320, 640, False)

        #out
        self.out = torch.nn.Sequential(
            torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5),
            torch.nn.SiLU(),
            torch.nn.Conv2d(320, 4, kernel_size=3, padding=1),
        )

    @staticmethod
    def _time_embed(t):
        """Sinusoidal timestep embedding.

        Args:
            t: timestep tensor with n elements (n = 1 for one shared timestep,
                or one timestep per sample).

        Returns:
            Float tensor [n, 320]: 160 cosines followed by 160 sines.
        """
        # -9.210340371976184 = -math.log(10000)
        freqs = torch.arange(160) * -9.210340371976184 / 160
        freqs = freqs.exp().to(t.device)

        # [160] * [n, 1] -> [n, 160]
        angles = freqs * t.reshape(-1, 1)

        # [n, 160+160] -> [n, 320]
        return torch.cat([angles.cos(), angles.sin()], dim=1)

    def forward(self, out_vae, out_encoder, time):
        #out_vae -> [2, 4, 64, 64]
        #out_encoder -> [2, 77, 768]
        #time -> [1] (or one timestep per sample)

        #----in----
        #[2, 4, 64, 64] -> [2, 320, 64, 64]
        out_vae = self.in_vae(out_vae)

        #[1] -> [1, 320]
        time = self._time_embed(time)
        #[1, 320] -> [1, 1280]
        time = self.in_time(time)

        #----down----
        # Skip tensors collected on the way down, in order:
        #[2, 320, 64, 64]
        #[2, 320, 64, 64]
        #[2, 320, 64, 64]
        #[2, 320, 32, 32]
        #[2, 640, 32, 32]
        #[2, 640, 32, 32]
        #[2, 640, 16, 16]
        #[2, 1280, 16, 16]
        #[2, 1280, 16, 16]
        #[2, 1280, 8, 8]
        #[2, 1280, 8, 8]
        #[2, 1280, 8, 8]
        out_down = [out_vae]

        #[2, 320, 64, 64],[2, 77, 768],[1, 1280] -> [2, 320, 32, 32]
        #out -> [2, 320, 64, 64],[2, 320, 64, 64],[2, 320, 32, 32]
        out_vae, out = self.down_block0(out_vae=out_vae,
                                        out_encoder=out_encoder,
                                        time=time)
        out_down.extend(out)

        #[2, 320, 32, 32],[2, 77, 768],[1, 1280] -> [2, 640, 16, 16]
        #out -> [2, 640, 32, 32],[2, 640, 32, 32],[2, 640, 16, 16]
        out_vae, out = self.down_block1(out_vae=out_vae,
                                        out_encoder=out_encoder,
                                        time=time)
        out_down.extend(out)

        #[2, 640, 16, 16],[2, 77, 768],[1, 1280] -> [2, 1280, 8, 8]
        #out -> [2, 1280, 16, 16],[2, 1280, 16, 16],[2, 1280, 8, 8]
        out_vae, out = self.down_block2(out_vae=out_vae,
                                        out_encoder=out_encoder,
                                        time=time)
        out_down.extend(out)

        #[2, 1280, 8, 8],[1, 1280] -> [2, 1280, 8, 8]
        out_vae = self.down_res0(out_vae, time)
        out_down.append(out_vae)

        #[2, 1280, 8, 8],[1, 1280] -> [2, 1280, 8, 8]
        out_vae = self.down_res1(out_vae, time)
        out_down.append(out_vae)

        #----mid----
        #[2, 1280, 8, 8],[1, 1280] -> [2, 1280, 8, 8]
        out_vae = self.mid_res0(out_vae, time)

        #[2, 1280, 8, 8],[2, 77, 768] -> [2, 1280, 8, 8]
        out_vae = self.mid_tf(out_vae, out_encoder)

        #[2, 1280, 8, 8],[1, 1280] -> [2, 1280, 8, 8]
        out_vae = self.mid_res1(out_vae, time)

        #----up----
        # Skips are consumed from the end of out_down (most recent first).
        #[2, 1280+1280, 8, 8],[1, 1280] -> [2, 1280, 8, 8]
        out_vae = self.up_res0(torch.cat([out_vae, out_down.pop()], dim=1),
                                time)

        #[2, 1280+1280, 8, 8],[1, 1280] -> [2, 1280, 8, 8]
        out_vae = self.up_res1(torch.cat([out_vae, out_down.pop()], dim=1),
                                time)

        #[2, 1280+1280, 8, 8],[1, 1280] -> [2, 1280, 8, 8]
        out_vae = self.up_res2(torch.cat([out_vae, out_down.pop()], dim=1),
                                time)

        #[2, 1280, 8, 8] -> [2, 1280, 16, 16]
        out_vae = self.up_in(out_vae)

        #[2, 1280, 16, 16],[2, 77, 768],[1, 1280] -> [2, 1280, 32, 32]
        #out_down -> [2, 640, 16, 16],[2, 1280, 16, 16],[2, 1280, 16, 16]
        out_vae = self.up_block0(out_vae=out_vae,
                                out_encoder=out_encoder,
                                time=time,
                                out_down=out_down)

        #[2, 1280, 32, 32],[2, 77, 768],[1, 1280] -> [2, 640, 64, 64]
        #out_down -> [2, 320, 32, 32],[2, 640, 32, 32],[2, 640, 32, 32]
        out_vae = self.up_block1(out_vae=out_vae,
                                out_encoder=out_encoder,
                                time=time,
                                out_down=out_down)

        #[2, 640, 64, 64],[2, 77, 768],[1, 1280] -> [2, 320, 64, 64]
        #out_down -> [2, 320, 64, 64],[2, 320, 64, 64],[2, 320, 64, 64]
        out_vae = self.up_block2(out_vae=out_vae,
                                out_encoder=out_encoder,
                                time=time,
                                out_down=out_down)

        #----out----
        #[2, 320, 64, 64] -> [2, 4, 64, 64]
        out_vae = self.out(out_vae)

        return out_vae


# Smoke test: batch of two latents with one shared timestep; the output keeps the latent shape.
UNet()(torch.randn(2, 4, 64, 64), torch.randn(2, 77, 768),
        torch.LongTensor([26])).shape

torch.Size([2, 4, 64, 64])

In [22]:
from diffusers import UNet2DConditionModel

# Load the pretrained reference UNet; its sub-modules are the weight source
# for everything copied below.
params = UNet2DConditionModel.from_pretrained('./Diffusion_model',
                                              subfolder='unet')

# Hand-written UNet that receives the pretrained weights.
unet = UNet()

# in: the input convolution, then the two linear layers of the time embedding
# (they live at positions 0 and 2 of unet.in_time).
unet.in_vae.load_state_dict(params.conv_in.state_dict())
for in_time_idx, time_linear in ((0, params.time_embedding.linear_1),
                                 (2, params.time_embedding.linear_2)):
    unet.in_time[in_time_idx].load_state_dict(time_linear.state_dict())


#down
def load_tf(model, param):
    """Copy the weights of a pretrained attention module into ``model``.

    Args:
        model: destination object exposing ``norm_in``, ``cnn_in``,
            ``atten1``/``atten2`` (each with ``q``, ``k``, ``v``, ``out``),
            ``fc0``, ``fc1``, ``norm_atten0``, ``norm_atten1``, ``norm_act``
            and ``cnn_out``.
        param: source object exposing ``norm``, ``proj_in``, ``proj_out`` and
            ``transformer_blocks``. In this file it is always an entry of a
            pretrained block's ``attentions`` list. Only
            ``transformer_blocks[0]`` is read.

    Returns:
        None. ``model`` is updated in place.
    """

    def copy_weights(target, source):
        # Every transfer is a plain state_dict copy between two sub-modules.
        target.load_state_dict(source.state_dict())

    # Input normalisation and projection.
    copy_weights(model.norm_in, param.norm)
    copy_weights(model.cnn_in, param.proj_in)

    block = param.transformer_blocks[0]

    # Both attention layers share the same q/k/v/out layout.
    for ours, theirs in ((model.atten1, block.attn1),
                         (model.atten2, block.attn2)):
        copy_weights(ours.q, theirs.to_q)
        copy_weights(ours.k, theirs.to_k)
        copy_weights(ours.v, theirs.to_v)
        # to_out is indexable; only its first entry carries the weights used.
        copy_weights(ours.out, theirs.to_out[0])

    # Feed-forward part: entries 0 (via its .proj) and 2 of ff.net.
    copy_weights(model.fc0, block.ff.net[0].proj)
    copy_weights(model.fc1, block.ff.net[2])

    # The three norms of the transformer block.
    copy_weights(model.norm_atten0, block.norm1)
    copy_weights(model.norm_atten1, block.norm2)
    copy_weights(model.norm_act, block.norm3)

    # Output projection.
    copy_weights(model.cnn_out, param.proj_out)


def load_res(model, param):
    """Copy the weights of a pretrained resnet block into ``model``.

    Args:
        model: destination object exposing indexable ``time``, ``s0`` and
            ``s1`` containers and a ``res`` attribute.
        param: source object exposing ``time_emb_proj``, ``norm1``, ``conv1``,
            ``norm2``, ``conv2`` and, when ``model.res`` is a module,
            ``conv_shortcut``.

    Returns:
        None. ``model`` is updated in place.
    """
    # Time-embedding projection sits at index 1 of model.time.
    time_proj = model.time[1]
    time_proj.load_state_dict(param.time_emb_proj.state_dict())

    # Each stage keeps its norm at index 0 and its conv at index 2.
    for stage, norm, conv in ((model.s0, param.norm1, param.conv1),
                              (model.s1, param.norm2, param.conv2)):
        stage[0].load_state_dict(norm.state_dict())
        stage[2].load_state_dict(conv.state_dict())

    # The shortcut is optional: anything that is not a torch module has no
    # weights to load, and param.conv_shortcut is then never touched.
    shortcut = model.res
    if not isinstance(shortcut, torch.nn.Module):
        return
    shortcut.load_state_dict(param.conv_shortcut.state_dict())


def load_down_block(model, param):
    """Copy one pretrained down block into ``model``.

    Loads, in this order: the two transformer modules
    (``tf0``/``tf1`` from ``param.attentions[0..1]``), the two resnets
    (``res0``/``res1`` from ``param.resnets[0..1]``) and finally
    ``model.out`` from ``param.downsamplers[0].conv``.

    Returns:
        None. ``model`` is updated in place.
    """
    for idx, transformer in enumerate((model.tf0, model.tf1)):
        load_tf(transformer, param.attentions[idx])

    for idx, resnet in enumerate((model.res0, model.res1)):
        load_res(resnet, param.resnets[idx])

    downsample_conv = param.downsamplers[0].conv
    model.out.load_state_dict(downsample_conv.state_dict())


# down: the first three pretrained down blocks are loaded as whole blocks
# (transformers, resnets and downsampler).
for down_idx, down_block in enumerate(
        (unet.down_block0, unet.down_block1, unet.down_block2)):
    load_down_block(down_block, params.down_blocks[down_idx])

# The fourth pretrained down block only supplies two resnets.
for res_idx, down_res in enumerate((unet.down_res0, unet.down_res1)):
    load_res(down_res, params.down_blocks[3].resnets[res_idx])

# mid: one transformer module followed by two resnets.
load_tf(unet.mid_tf, params.mid_block.attentions[0])
for res_idx, mid_res in enumerate((unet.mid_res0, unet.mid_res1)):
    load_res(mid_res, params.mid_block.resnets[res_idx])

# up: the first pretrained up block supplies three resnets plus the conv
# stored at index 1 of unet.up_in.
for res_idx, up_res in enumerate(
        (unet.up_res0, unet.up_res1, unet.up_res2)):
    load_res(up_res, params.up_blocks[0].resnets[res_idx])
unet.up_in[1].load_state_dict(
    params.up_blocks[0].upsamplers[0].conv.state_dict())


def load_up_block(model, param):
    """Copy one pretrained up block into ``model``.

    Loads, in this order: three transformer modules
    (``tf0``..``tf2`` from ``param.attentions[0..2]``) and three resnets
    (``res0``..``res2`` from ``param.resnets[0..2]``). When ``model.out`` is a
    torch module, its entry at index 1 is then loaded from
    ``param.upsamplers[0].conv``; otherwise ``param.upsamplers`` is not read.

    Returns:
        None. ``model`` is updated in place.
    """
    for idx, transformer in enumerate((model.tf0, model.tf1, model.tf2)):
        load_tf(transformer, param.attentions[idx])

    for idx, resnet in enumerate((model.res0, model.res1, model.res2)):
        load_res(resnet, param.resnets[idx])

    # Blocks whose `out` is not a module have nothing further to load.
    if not isinstance(model.out, torch.nn.Module):
        return
    upsample_conv = param.upsamplers[0].conv
    model.out[1].load_state_dict(upsample_conv.state_dict())


# The hand-written up_block0..2 take their weights from pretrained
# up_blocks[1..3], so the source index is shifted by one.
for up_idx, up_block in enumerate(
        (unet.up_block0, unet.up_block1, unet.up_block2), start=1):
    load_up_block(up_block, params.up_blocks[up_idx])

# out: the norm at index 0 and the conv at index 2 of unet.out.
unet.out[0].load_state_dict(params.conv_norm_out.state_dict())
# Kept as the cell's final expression so its load report is displayed.
unet.out[2].load_state_dict(
    params.conv_out.state_dict())

<All keys matched successfully>

In [23]:
# Equivalence check: feed identical random inputs to the hand-written UNet
# and to the pretrained reference model, then compare the outputs.
out_vae = torch.randn(1, 4, 64, 64)
out_encoder = torch.randn(1, 77, 768)
time = torch.LongTensor([26])

# This is inference only. Without no_grad both forward passes would record an
# autograd graph and keep every intermediate activation of two full UNets
# alive for no benefit. The numerical results are unaffected.
with torch.no_grad():
    a = unet(out_vae=out_vae, out_encoder=out_encoder, time=time)
    # The reference model takes (sample, timestep, encoder states) in this
    # positional order and wraps its result; `.sample` is the output tensor.
    b = params(out_vae, time, out_encoder).sample

# Displayed result: True only if every element matches exactly.
(a == b).all()

tensor(True)