In [6]:
# cross attention + multi-head attention 구현
# transformer 블럭 구현 - self > cross > FeedForward
import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange

class CrossAttention(nn.Module):
    """
    This is used for both self attention and cross attention mechanism
    """
    def __init__(self, in_c, num_heads=8, is_cross=False, context_dim=None):
        self.num_heads = num_heads
        self.is_cross = is_cross
        self.context_dim = context_dim

        self.query = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.key = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.value = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)

    def forward(self, x, context=None):
        B, C, H, W = x.shape
        if self.is_cross and context:
            Q = self.query(context)
        else:
            Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        Q = Q.permute(0, 2, 3, 1).view(B, H*W, C)
        K = K.view(B, C, H*W)
        V = V.permute(0, 2, 3, 1).view(B, H*W, C)

        attention_weight = F.softmax(torch.bmm(Q, K)/(int(C) ** 0.5))
        output = torch.bmm(attention_weight, V)
        output = output.view(B, H, W, C).permute(0, 3, 1, 2)

        return output

class CrossAttention2(nn.Module):
    def __init__(self, in_channels, dim_head=64, n_heads=8, context_dim=None):
        super(CrossAttention2, self).__init__()
        self.n_heads = n_heads
        self.dim_head = dim_head
        self.in_channels = in_channels

        self.context_dim = context_dim if context_dim is not None else in_channels
        inner_dim = dim_head * n_heads

        self.to_q = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        self.to_k = nn.Conv2d(self.context_dim, inner_dim, kernel_size=1, stride=1, padding=0)
        self.to_v = nn.Conv2d(self.context_dim, inner_dim, kernel_size=1, stride=1, padding=0)

        self.w = nn.Sequential(
            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0),
            nn.Dropout(0.)
        )
        self.scale = dim_head ** 0.5


    def forward(self, x, context=None):
        bs, c, h, w = x.shape
        q = self.to_q(x)
        context = context if context != None else x
        k = self.to_k(context)
        v = self.to_v(context) # bs, dim_head * n_heads, h, w

        q = q.view(bs, self.n_heads, self.dim_head, h*w).permute(0, 1, 3, 2) # h*w는 sequence length, token 수로 해석된다. pixel간의 관계가 중요함.
        k = k.view(bs, self.n_heads, self.dim_head, h*w)
        v = v.view(bs, self.n_heads, self.dim_head, h*w)

        attention_weight = F.softmax(torch.matmul(q, k)/self.scale, dim=-1) # bs, n_heads, h*w, h*w
        output = torch.matmul(v, attention_weight) # bs, n_heads, dim_head, h*w 
        # torch.bmm은 batch단위 matnul. 따라서 입력이 3차원 이어야함

        output = rearrange(output, 'b n d (h w) -> b (n d) h w', h=h, w=w)
        output = self.w(output)

        return output


class TransformerBlock(nn.Module):
    """
    Transformer block
    """
    def __init__(self, in_channels, resolution, n_heads=8, context_dim=None, mult=2):
        super(TransformerBlock, self).__init__()
        # self attn
        # cross attn
        # feedforward
        inner_dim = in_channels * mult
        
        self.norm1 = nn.LayerNorm([in_channels, resolution, resolution])
        self.norm2 = nn.LayerNorm([in_channels, resolution, resolution])
        self.norm3 = nn.LayerNorm([in_channels, resolution, resolution])
        self.self_attn = CrossAttention2(in_channels, n_heads=8)
        self.cross_attn = CrossAttention2(in_channels, n_heads=8, context_dim=context_dim)
        self.feed_forward = nn.Sequential(
            nn.Conv2d(in_channels, inner_dim, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(inner_dim, in_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x, temb=None, context=None):
        x = self.self_attn(self.norm1(x)) + x
        x = self.cross_attn(self.norm2(x), context=context) + x
        x = self.feed_forward(self.norm3(x)) + x

        return x

In [7]:
device = 'cuda'

ctxt = torch.randn((16, 512, 32, 32)).to(device)
inp = torch.randn((16, 320, 32, 32)).to(device) # b, c, h, w

cr1 = CrossAttention2(320, n_heads=8, context_dim=512).to(device)
trans_block = TransformerBlock(320, resolution=32, n_heads=8, mult=2, context_dim=512).to(device)

In [8]:
attn_out = trans_block(inp, None, ctxt)
print(attn_out.shape)

torch.Size([16, 320, 32, 32])


In [9]:
from block import ResBlock
import torch.nn as nn

inc = 320
time_dim = inc * 4
mc = 320

middleblocks = nn.ModuleList([
    ResBlock(inc, mc, time_emb_dim=time_dim),
    TransformerBlock(mc, resolution=32, n_heads=8, context_dim=512, mult=2),
    ResBlock(mc, mc, time_emb_dim=time_dim)
])

In [10]:
device = 'cuda'

middleblocks = middleblocks.to(device)
inp = torch.randn((16, 320, 32, 32)).to(device) # b, c, h, w
temb = torch.randn((16, time_dim)).to(device)
ctxt = torch.randn((16, 512, 32, 32)).to(device)

for module in middleblocks:
    out = module(inp, temb, ctxt)
    print(out.shape)

torch.Size([16, 320, 32, 32])
torch.Size([16, 320, 32, 32])
torch.Size([16, 320, 32, 32])


In [1]:
from block import ResBlock
from module import UpSample, DownSample, SinusoidalPositionalEmbedding
import torch.nn.functional as F
import torch.nn as nn
import torch
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from attention import TransformerBlock

class Unet(nn.Module):
    def __init__(
        self,
        mults=[1, 2, 4, 4],
        init_resolution=32,
        in_channels=4,
        model_channels=256,
        num_res_blocks=2,
        attention_resolutions=[0, 1, 2],
        context_dim=512,
    ):
        super(Unet, self).__init__()
        
        self.init_conv = nn.Conv2d(in_channels, model_channels, kernel_size=3, stride=1, padding=1)
        self.last_conv = nn.Conv2d(model_channels, in_channels, kernel_size=3, stride=1, padding=1)
        
        time_dim = model_channels*4
        sinu_pos_embedding = SinusoidalPositionalEmbedding(model_channels, 10000)
        
        self.time_mlp = nn.Sequential(
            sinu_pos_embedding,
            nn.Linear(model_channels, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        chn = model_channels
        resolution = init_resolution
        self.downblocks = nn.ModuleList()
        in_channels = model_channels
        for level, mult in enumerate(mults):
            current_block = []
            out_channels = model_channels * mult # it's not cumulative
            for _ in range(num_res_blocks):
                self.downblocks.append(
                    ResBlock(in_channels, out_channels, time_emb_dim=time_dim)
                )
                if level in attention_resolutions:
                    self.downblocks.append(
                        TransformerBlock(out_channels, resolution=resolution, n_heads=8, context_dim=context_dim, mult=2)
                    )
                in_channels = out_channels
            if level != len(mults)-1:
                self.downblocks.append(
                    DownSample(out_channels)
                )
                resolution = int(resolution / 2) # Downsampled
            print("resolution : ", resolution)
            # self.downblocks.append(nn.Sequential(*current_block))
        
        middle_channel = out_channels*2
        self.middleblocks = nn.ModuleList([
            ResBlock(out_channels, middle_channel, time_emb_dim=time_dim),
            TransformerBlock(middle_channel, resolution=resolution, n_heads=8, context_dim=context_dim, mult=2),
            ResBlock(middle_channel, middle_channel, time_emb_dim=time_dim)
        ])

        in_channels=middle_channel
        self.upsamples = nn.ModuleList()
        for level, mult in enumerate(mults[::-1]):
            if level != len(mults)-1:
                out_channels = model_channels*mult
                self.upsamples.append(
                    UpSample(in_channels, out_channels)
                )
                in_channels = out_channels
        
        self.upblocks = nn.ModuleList()
        for level, mult in enumerate(mults[::-1]):
            out_channels = model_channels*mult
            current_block = nn.ModuleList()
            for _ in range(num_res_blocks):
                current_block.append(
                    ResBlock(out_channels*2, out_channels, time_emb_dim=time_dim)
                )
                if len(mults) - level - 1 in attention_resolutions:
                    current_block.append(
                        TransformerBlock(out_channels, resolution=resolution, n_heads=8, context_dim=context_dim, mult=2)
                    )
                in_channels = out_channels
            if level != 0:
                resolution = int(resolution * 2) # Downsampled
            self.upblocks.append(current_block)
            print("resolution : ", resolution)
        
        print(len(self.downblocks))
        print(len(self.upblocks))
        print(len(self.upsamples))

    def forward(self, x, t, context=None):
        initial = self.init_conv(x)
        t_emb = self.time_mlp(t) # 같은 t_emb가 각 ResNet Block에 들어간다.
        
        x = self.init_conv(x)

        connections = []
        for i, layer in enumerate(self.downblocks):
            x = layer(x, t_emb, context)
            if layer.__class__.__name__ == "DownSample":
                connections.append(x)

        print("start middle : ", x.shape)
        for layer in self.middleblocks:
            x = layer(x, t_emb, context)
            print(f"middle - {i}, x : {x.shape}")

        print("start up : ", x.shape)
        for i in range(len(self.upblocks)):
            print("ii  : ", i)
            if i != 0:
                print("upsample -> ")
                x = self.upsamples[i](x, t_emb)
                print("upsample : ", x.shape)
            x = torch.concat((x, connections[::-1][i]), dim=1)
            for j in range(len(self.upblocks[i])):
                x = self.upblocks[i][j](x, t_emb, context)
            print(f"up - {i}, x : {x.shape}")

        x = self.last_conv(x)
        
        return x

In [2]:
unet = Unet()

resolution :  16
resolution :  8
resolution :  4
resolution :  4
resolution :  4
resolution :  8
resolution :  16
resolution :  32
17
4
3


In [3]:
device = 'cuda'
inp = torch.randn((16, 4, 32, 32)).to(device)
ctxt = torch.randn((16, 512)).to(device)
unet = unet.to(device)

In [4]:
t = torch.randint(1, 1001, (16,)).to(device)
print(t.shape)

torch.Size([16])


In [5]:
unet(inp, t, ctxt)

start middle :  torch.Size([16, 1024, 4, 4])
middle - 16, x : torch.Size([16, 2048, 4, 4])
middle - 16, x : torch.Size([16, 2048, 4, 4])
middle - 16, x : torch.Size([16, 2048, 4, 4])
start up :  torch.Size([16, 2048, 4, 4])


RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [2048] and input of shape [16, 3072, 4, 4]

In [7]:
ts = unet.time_mlp(t)

In [13]:
unet.downblocks[6].__class__.__name__

'TransformerBlock'

In [9]:
inp = torch.randn((16, 512, 16, 16)).to(device)
unet.downblocks[6](inp, ts, ctxt)

RuntimeError: shape '[16, 8, 64, 256]' is invalid for input of size 8388608