### Resdual block for diffusion models

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from module import (
    ResBlock
)
from dataset import mnist
from util import get_torch_size_string,plot_4x4_torch_tensor
np.set_printoptions(precision=3)
th.set_printoptions(precision=3)
%matplotlib inline
%config InlineBackend.figure_format='retina'
print ("PyTorch version:[%s]."%(th.__version__))

PyTorch version:[2.0.1].


### 1-D case `[B x C x L]`

In [5]:
# Input
x = th.randn(16,32,200) # [B x C x L]
emb = th.randn(16,128) # [B x n_emb_channels]

print ("1. No upsample nor downsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 1,
    upsample       = True,
    downsample     = False,
    down_rate      = 1
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

print ("2. Upsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 1,
    upsample       = True,
    downsample     = False,
    down_rate      = 2
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

print ("3. Downsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 1,
    upsample       = False,
    downsample     = True,
    down_rate      = 2
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

1. No upsample nor downsample
 Shape x:[16x32x200] => out:[16x32x400]
2. Upsample
 Shape x:[16x32x200] => out:[16x32x400]
3. Downsample
 Shape x:[16x32x200] => out:[16x32x100]


### 2-D case `[B x C x W x H]`

In [3]:
# Input
x = th.randn(16,32,28,28) # [B x C x W x H]
emb = th.randn(16,128) # [B x n_emb_channels]

print ("1. No upsample nor downsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 2,
    upsample       = False,
    downsample     = False
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

print ("2. Upsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 2,
    upsample       = True,
    downsample     = False
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

print ("3. Downsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 2,
    upsample       = False,
    downsample     = True
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

print ("4. (uneven) Upsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 2,
    upsample       = True,
    downsample     = False,
    up_rate        = (2,1)
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

print ("5. (uneven) Downsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 2,
    upsample       = False,
    downsample     = True,
    down_rate      = (2,1)
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

print ("6. (fake) Downsample")
resblock = ResBlock(
    n_channels     = 32,
    n_emb_channels = 128,
    n_out_channels = 32,
    n_groups       = 16,
    dims           = 2,
    upsample       = False,
    downsample     = True,
    down_rate      = (1,1)
)
out = resblock(x,emb)
print (" Shape x:[%s] => out:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

1. No upsample nor downsample
 Shape x:[16x32x28x28] => out:[16x32x28x28]
2. Upsample
 Shape x:[16x32x28x28] => out:[16x32x56x56]
3. Downsample
 Shape x:[16x32x28x28] => out:[16x32x14x14]
4. (uneven) Upsample
 Shape x:[16x32x28x28] => out:[16x32x56x28]
5. (uneven) Downsample
 Shape x:[16x32x28x28] => out:[16x32x14x28]
6. (fake) Downsample
 Shape x:[16x32x28x28] => out:[16x32x28x28]
