In [1]:
import torch
from torch import nn
from typing import Union,Tuple,List
from utils.activate import Swish
from utils.TimeEmbedding import TimeEmbedding
from utils.DownBlock import DownBlock
from utils.DownSample_and_UpSample import DownSample,Upsample
from utils.MiddleBlock import MiddleBlock
from utils.UpBlock import UpBlock

In [2]:
class UNet(nn.Module):
    """
    Main U-Net architecture for DDPM.

    Layout: a 3x3 convolution projecting the image to ``n_channels``, an
    encoder of ``DownBlock``s separated by ``DownSample`` layers, a
    ``MiddleBlock``, a decoder of ``UpBlock``s separated by ``Upsample``
    layers that consumes the stored encoder activations as skip connections,
    and a GroupNorm -> Swish -> 3x3 convolution head mapping back to
    ``image_channels``.
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        :param image_channels: number of channels of the input image
        :param n_channels: number of channels produced by the initial
            convolution applied to the image before it enters the U-Net
        :param ch_mults: per-resolution multiplier applied to the channel
            count of the previous resolution; every entry must be >= 1
        :param is_attn: per-resolution flag forwarded to the Down/Up blocks
            of that resolution to request attention; must provide at least
            one entry per entry of ``ch_mults`` (extra entries are ignored)
        :param n_blocks: number of ``DownBlock``s per encoder resolution
            (the decoder uses ``n_blocks + 1`` ``UpBlock``s per resolution)
        :raises ValueError: if ``is_attn`` is shorter than ``ch_mults`` or a
            multiplier is smaller than 1
        """
        super().__init__()

        # Number of distinct feature-map resolutions, i.e. the number of
        # encoder / decoder levels.
        n_resolutions = len(ch_mults)

        # Fail early with a clear message instead of an IndexError or
        # ZeroDivisionError halfway through building the layers.
        if len(is_attn) < n_resolutions:
            raise ValueError(
                f"is_attn needs one entry per resolution: got {len(is_attn)} "
                f"entries for {n_resolutions} entries in ch_mults"
            )
        if any(mult < 1 for mult in ch_mults):
            raise ValueError(f"every entry of ch_mults must be >= 1, got {tuple(ch_mults)}")

        # Initial projection of the raw image into feature space.
        self.image_proj = nn.Conv2d(in_channels=image_channels, out_channels=n_channels,
                                    kernel_size=3, padding=1)

        # Time-step embedding; every block receives n_channels * 4 as its
        # time-channel argument.
        time_channels = n_channels * 4
        self.time_emb = TimeEmbedding(time_channels)

        # ---- Encoder ----
        down = []
        out_channels = in_channels = n_channels
        for i in range(n_resolutions):
            # Channel count of this level, relative to the previous level.
            out_channels = in_channels * ch_mults[i]

            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, time_channels, is_attn[i]))
                in_channels = out_channels

            # Down-sample after every level except the last one.
            if i < n_resolutions - 1:
                down.append(DownSample(in_channels))

        self.down = nn.ModuleList(down)

        # ---- Middle ----
        self.middle = MiddleBlock(out_channels, time_channels)

        # ---- Decoder ----
        up = []
        in_channels = out_channels
        for i in reversed(range(n_resolutions)):
            # n_blocks blocks that keep the channel count of this level ...
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, time_channels, is_attn[i]))

            # ... plus one extra block that undoes this level's multiplier.
            # The extra block is needed because forward() stores n_blocks + 1
            # skip tensors per level (the block outputs plus either the
            # DownSample output or the initial projection).
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, time_channels, is_attn[i]))
            in_channels = out_channels

            # Up-sample after every level except the top one.
            if i > 0:
                up.append(Upsample(in_channels))

        self.up = nn.ModuleList(up)

        # Output head. After the decoder in_channels is back to n_channels,
        # so the norm and the final convolution agree on the channel count.
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        :param x: input x_t of shape (batch_size, in_channels, height, width)
        :param t: time steps of shape (batch_size)
        :return: output of the final convolution, with ``image_channels``
            channels (presumably the same spatial size as ``x``; this depends
            on DownSample/Upsample, which are defined elsewhere)
        """
        t = self.time_emb(t)

        # Project the raw image into feature space.
        x = self.image_proj(x)

        # Encoder: keep every intermediate activation (including the initial
        # projection and the DownSample outputs) for the skip connections.
        h = [x]
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle block.
        x = self.middle(x, t)

        # Decoder: Upsample layers take no skip; every other block consumes
        # the most recently stored encoder activation, concatenated along the
        # channel dimension.
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)

        return self.final(self.act(self.norm(x)))

In [3]:
# Dummy batch for a quick shape check of the network.
# Seed the global RNG first so the sampled inputs (and the initial weights of
# any model constructed afterwards) are identical on every Restart & Run All.
torch.manual_seed(0)

x=torch.randn(size=[256,3,32,32])            # input batch: (batch_size, channels, height, width)
t=torch.randint(low=0,high=100,size=[256])   # one integer time step in [0, 100) per sample
print(t)

tensor([12, 93, 98, 80,  9,  2, 91, 19,  0, 43, 28, 60,  1, 33, 20, 15,  6, 67,
        59, 32, 20, 30, 92, 36, 25, 10, 80, 16, 90, 99, 40, 61, 82, 90,  1, 87,
        84, 31,  7, 23, 51, 25, 86,  2, 53, 18, 12, 92, 19, 98, 21, 21, 45, 52,
        42, 47, 37, 53, 73, 49, 23, 61, 34, 67, 60, 24, 53, 18, 87, 86, 95, 20,
        74, 14, 62, 58, 36, 25, 72, 53, 50, 35, 89, 10, 68, 86, 16, 59, 17, 18,
        34, 94, 85, 71, 67, 10, 15, 88, 28, 73, 36, 76,  4, 28, 95, 30, 46, 45,
        60, 97, 21, 75,  7, 45, 58, 73, 23,  4, 31, 51, 49,  2,  0,  0, 60, 52,
        16, 58,  6, 55, 82, 37, 93, 75, 62, 67, 69, 62,  3, 56, 49,  0, 95, 29,
        27,  3, 68,  5, 72, 43, 21, 82, 75, 72, 84, 52, 52, 35, 40,  8, 96, 62,
        56, 31, 42, 70, 10, 53, 79, 49, 21, 21, 13, 13, 92, 15, 44, 87, 95, 72,
        81, 42, 57, 57, 58, 31, 29, 16, 52, 22, 20, 18,  1, 91, 87, 26, 21, 31,
        28,  2, 44, 35, 58, 64, 67, 49, 26, 12, 31, 28,  9, 27, 66, 58, 99, 68,
        47, 23, 75, 56, 45, 56, 50, 99, 

In [6]:
# Smoke test: build the default UNet and push the dummy batch through it.
unet=UNet()

# This is only a shape check, so disable autograd to avoid recording a graph
# for the whole batch-256 forward pass.
with torch.no_grad():
    out=unet(x,t)

# With the default image_channels=3 the output has the same shape as the input.
assert out.shape==x.shape,f"unexpected output shape {tuple(out.shape)}"
print(out.shape)

torch.Size([256, 64, 1, 1])
torch.Size([256, 64, 1, 1])
torch.Size([256, 128, 1, 1])
torch.Size([256, 128, 1, 1])
torch.Size([256, 256, 1, 1])
torch.Size([256, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([256, 3, 32, 32])
