In [2]:
import torch
from torch import nn
from torchkeras import summary

from model.EmbeddingBlock import EmbeddingBlock
from model.Block import ConvBlock, DownSample, UpSample

# --- Hyper-parameters shared by every cell below ---------------------------
input_shape = (1, 28, 28)  # one sample, channels first, without the batch axis
init_features = 64         # channel count produced by the first ConvBlock
num_classes = 6            # third argument handed to EmbeddingBlock
embed_dim = 128            # width of the time/condition embedding vector

# Input and output channel counts are both read from the first entry of
# input_shape.
in_channels = input_shape[0]
out_channels = input_shape[0]
features = init_features

# Seed the global RNG so the dummy input below (and any module weights
# initialised afterwards) are identical on every Restart & Run All.
torch.manual_seed(0)

# --- Dummy inputs: a batch of exactly one sample ---------------------------
x = torch.randn(1, *input_shape)
# NOTE: `time` would shadow the stdlib module of the same name if that module
# were ever imported in this notebook; the name is kept because later cells
# refer to it.
time = torch.tensor([1])
condition = torch.tensor([0])

In [3]:
# Build the embedding module; embed_dim is supplied for both of its first two
# arguments and num_classes for the third.
embedder = EmbeddingBlock(
    embed_dim,
    embed_dim,
    num_classes,
)

# Run it once on the dummy time step and condition; `embed` is consumed by the
# down-/up-sampling cells further down.
embed = embedder(time, condition)

# Per-layer report for the same pair of inputs.
info = summary(embedder, time=time, condition=condition)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
PositionalEncoding-1                       [-1, 128]                    0
Linear-2                                   [-1, 128]               16,512
SiLU-3                                     [-1, 128]                    0
Linear-4                                   [-1, 128]               16,512
Embedding-5                                [-1, 128]                  768
Linear-6                                   [-1, 128]               16,512
SiLU-7                                     [-1, 128]                    0
Linear-8                                   [-1, 128]               16,512
Linear-9                                   [-1, 128]               32,896
Total params: 99,712
Trainable params: 99,712
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.000000
Forward/backwa

In [4]:
# Stem block: takes the raw input from in_channels to `features` channels.
head_conv = ConvBlock(in_channels, features)

# Keep the stem output around: it feeds encoder01 and is passed again to
# decoder01 as its third argument.
init_x = head_conv(x)

# Per-layer report for the stem on the dummy input.
info = summary(head_conv, input_data=x)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                            [-1, 64, 28, 28]                  640
GELU-2                              [-1, 64, 28, 28]                    0
GroupNorm-3                         [-1, 64, 28, 28]                  128
Conv2d-4                            [-1, 64, 28, 28]               36,928
GroupNorm-5                         [-1, 64, 28, 28]                  128
Total params: 37,824
Trainable params: 37,824
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.000076
Forward/backward pass size (MB): 1.914062
Params size (MB): 0.144287
Estimated Total Size (MB): 2.058426
--------------------------------------------------------------------------


In [5]:
# First encoder stage: features -> 2 * features channels, conditioned on an
# embedding of width embed_dim.
encoder01 = DownSample(
    features,
    features * 2,
    embed_dim,
)
enc01_x = encoder01(init_x, embed)

# Per-layer report using the same activation and embedding.
info = summary(encoder01, input_data=init_x, embed=embed)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
MaxPool2d-1                         [-1, 64, 14, 14]                    0
Conv2d-2                           [-1, 128, 14, 14]                8,320
GroupNorm-3                         [-1, 64, 14, 14]                  128
SiLU-4                              [-1, 64, 14, 14]                    0
Conv2d-5                           [-1, 128, 14, 14]               73,856
Linear-6                                   [-1, 256]               33,024
GroupNorm-7                        [-1, 128, 14, 14]                  256
SiLU-8                             [-1, 128, 14, 14]                    0
Conv2d-9                           [-1, 128, 14, 14]              147,584
Conv2d-10                          [-1, 128, 14, 14]              147,584
GELU-11                            [-1, 128, 14, 14]                    0
GroupNorm-12                       [-

In [7]:
# Second encoder stage: 2 * features -> 4 * features channels, again
# conditioned on the shared embedding.
encoder02 = DownSample(
    features * 2,
    features * 4,
    embed_dim,
)
enc02_x = encoder02(enc01_x, embed)

# Per-layer report using the encoder01 output as input.
info = summary(encoder02, input_data=enc01_x, embed=embed)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
MaxPool2d-1                          [-1, 128, 7, 7]                    0
Conv2d-2                             [-1, 256, 7, 7]               33,024
GroupNorm-3                          [-1, 128, 7, 7]                  256
SiLU-4                               [-1, 128, 7, 7]                    0
Conv2d-5                             [-1, 256, 7, 7]              295,168
Linear-6                                   [-1, 512]               66,048
GroupNorm-7                          [-1, 256, 7, 7]                  512
SiLU-8                               [-1, 256, 7, 7]                    0
Conv2d-9                             [-1, 256, 7, 7]              590,080
Conv2d-10                            [-1, 256, 7, 7]              590,080
GELU-11                              [-1, 256, 7, 7]                    0
GroupNorm-12                         

In [8]:
# Bottleneck: three ConvBlocks in sequence, each mapping 4 * features channels
# to 4 * features channels. The generator builds them in order, so the
# Sequential indices (0, 1, 2) match a hand-written list of three.
bottleneck = nn.Sequential(
    *(ConvBlock(features * 4, features * 4) for _ in range(3))
)

# Apply it to the encoder02 output; no embedding is passed here.
bot_x = bottleneck(enc02_x)

# Per-layer report for the whole Sequential.
info = summary(bottleneck, input_data=enc02_x)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                             [-1, 256, 7, 7]              590,080
GELU-2                               [-1, 256, 7, 7]                    0
GroupNorm-3                          [-1, 256, 7, 7]                  512
Conv2d-4                             [-1, 256, 7, 7]              590,080
GroupNorm-5                          [-1, 256, 7, 7]                  512
Conv2d-6                             [-1, 256, 7, 7]              590,080
GELU-7                               [-1, 256, 7, 7]                    0
GroupNorm-8                          [-1, 256, 7, 7]                  512
Conv2d-9                             [-1, 256, 7, 7]              590,080
GroupNorm-10                         [-1, 256, 7, 7]                  512
Conv2d-11                            [-1, 256, 7, 7]              590,080
GELU-12                              

In [9]:
# First decoder stage: 4 * features -> 2 * features channels.
decoder02 = UpSample(
    features * 4,
    features * 2,
    embed_dim,
)

# Positional order is (input, embedding, skip); the skip tensor here is the
# encoder01 output.
dec02_x = decoder02(bot_x, embed, enc01_x)

# Per-layer report with the same three tensors, passed by keyword.
info = summary(
    decoder02,
    input_data=bot_x,
    embed=embed,
    skip=enc01_x,
)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
ConvTranspose2d-1                  [-1, 128, 14, 14]              131,200
Conv2d-2                           [-1, 128, 14, 14]               32,896
GroupNorm-3                        [-1, 256, 14, 14]                  512
SiLU-4                             [-1, 256, 14, 14]                    0
Conv2d-5                           [-1, 128, 14, 14]              295,040
Linear-6                                   [-1, 256]               33,024
GroupNorm-7                        [-1, 128, 14, 14]                  256
SiLU-8                             [-1, 128, 14, 14]                    0
Conv2d-9                           [-1, 128, 14, 14]              147,584
Conv2d-10                          [-1, 128, 14, 14]              147,584
GELU-11                            [-1, 128, 14, 14]                    0
GroupNorm-12                       [-

In [10]:
# Second decoder stage: 2 * features -> features channels.
decoder01 = UpSample(
    features * 2,
    features,
    embed_dim,
)

# Positional order is (input, embedding, skip); the skip tensor here is the
# stem output init_x.
dec01_x = decoder01(dec02_x, embed, init_x)

# Per-layer report with the same three tensors, passed by keyword.
info = summary(
    decoder01,
    input_data=dec02_x,
    embed=embed,
    skip=init_x,
)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
ConvTranspose2d-1                   [-1, 64, 28, 28]               32,832
Conv2d-2                            [-1, 64, 28, 28]                8,256
GroupNorm-3                        [-1, 128, 28, 28]                  256
SiLU-4                             [-1, 128, 28, 28]                    0
Conv2d-5                            [-1, 64, 28, 28]               73,792
Linear-6                                   [-1, 128]               16,512
GroupNorm-7                         [-1, 64, 28, 28]                  128
SiLU-8                              [-1, 64, 28, 28]                    0
Conv2d-9                            [-1, 64, 28, 28]               36,928
Conv2d-10                           [-1, 64, 28, 28]               36,928
GELU-11                             [-1, 64, 28, 28]                    0
GroupNorm-12                        [

In [11]:
# Output projection: a 3x3 convolution with padding 1 (spatial size unchanged)
# that maps `features` channels to out_channels.
tail_conv = nn.Conv2d(
    features,
    out_channels,
    kernel_size=3,
    padding=1,
)
out = tail_conv(dec01_x)

# Per-layer report for the final convolution.
info = summary(tail_conv, input_data=dec01_x)

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                             [-1, 1, 28, 28]                  577
Total params: 577
Trainable params: 577
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.000076
Forward/backward pass size (MB): 0.005981
Params size (MB): 0.002201
Estimated Total Size (MB): 0.008259
--------------------------------------------------------------------------
