### U-net Legacy

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch as th
import torch.nn as nn
from util import (
    get_torch_size_string
)
from diffusion import (
    get_ddpm_constants,
    plot_ddpm_constants,
    DiffusionUNet,
    DiffusionUNetLegacy
)
from dataset import mnist
np.set_printoptions(precision=3)
th.set_printoptions(precision=3)
%matplotlib inline
%config InlineBackend.figure_format='retina'
print ("PyTorch version:[%s]."%(th.__version__))

PyTorch version:[2.0.1].


In [2]:
# Select the compute device at runtime instead of hard-coding 'mps':
# prefer the Apple-silicon MPS backend, then CUDA, and fall back to the CPU so
# the notebook still runs top to bottom on machines without an accelerator.
if hasattr(th.backends, 'mps') and th.backends.mps.is_available():
    device = 'mps'
elif th.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print ("device:[%s]"%(device))

device:[mps]


In [3]:
# Build the DDPM constants once; later cells read the number of steps via dc['T'].
# `schedule_name` accepts 'linear' or 'cosine'.
dc = get_ddpm_constants(schedule_name='cosine', T=1000, np_type=np.float32)
print("Ready.")

Ready.


### Guided U-net
<img src="../img/unet.jpg" width="500" />

### 1-D case: `[B x C x L]` with attention

In [4]:
# Build the legacy U-Net for 1-D inputs with attention enabled.
unet = DiffusionUNetLegacy(
    name             = 'unet',
    dims             = 1,
    n_in_channels    = 3,
    n_base_channels  = 32,
    n_emb_dim        = 128,
    n_enc_blocks     = 4, # number of encoder blocks
    n_dec_blocks     = 4, # number of decoder blocks
    n_groups         = 16, # group norm parameter
    use_attention    = True,
    skip_connection  = True, # additional skip connection
    chnnel_multiples = (1,2,4,8),
    updown_rates     = (2,2,2,2),
    device           = device,
)
# Inputs, timesteps:[B] and x:[B x C x L]
batch_size = 2
# Sample the input directly on the target device (no CPU allocation + copy)
x = th.randn(batch_size,3,256,device=device) # [B x C x L]
# Evenly spaced values between 1 and dc['T'], cast to integer timesteps
timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]
# The forward pass is only used to inspect shapes, so do not record an autograd graph
with th.no_grad():
    out,intermediate_output_dict = unet(x,timesteps)
print ("Input: x:[%s] timesteps:[%s]"%(
    get_torch_size_string(x),get_torch_size_string(timesteps)
))
print ("Output: out:[%s]"%(get_torch_size_string(out)))
# Print intermediate layers
for k_idx,(key,z) in enumerate(intermediate_output_dict.items()):
    print ("[%2d] key:[%12s] shape:[%12s]"%(k_idx,key,get_torch_size_string(z)))

Input: x:[2x3x256] timesteps:[2]
Output: out:[2x3x256]
[ 0] key:[           x] shape:[     2x3x256]
[ 1] key:[    x_lifted] shape:[    2x32x256]
[ 2] key:[h_enc_res_00] shape:[    2x32x128]
[ 3] key:[h_enc_att_01] shape:[    2x32x128]
[ 4] key:[h_enc_res_02] shape:[     2x64x64]
[ 5] key:[h_enc_att_03] shape:[     2x64x64]
[ 6] key:[h_enc_res_04] shape:[    2x128x32]
[ 7] key:[h_enc_att_05] shape:[    2x128x32]
[ 8] key:[h_enc_res_06] shape:[    2x256x16]
[ 9] key:[h_enc_att_07] shape:[    2x256x16]
[10] key:[h_dec_res_00] shape:[    2x256x32]
[11] key:[h_dec_att_01] shape:[    2x256x32]
[12] key:[h_dec_res_02] shape:[    2x128x64]
[13] key:[h_dec_att_03] shape:[    2x128x64]
[14] key:[h_dec_res_04] shape:[    2x64x128]
[15] key:[h_dec_att_05] shape:[    2x64x128]
[16] key:[h_dec_res_06] shape:[    2x32x256]
[17] key:[h_dec_att_07] shape:[    2x32x256]
[18] key:[         out] shape:[     2x3x256]


### 1-D case: `[B x C x L]` without attention

In [5]:
# Build the legacy U-Net for 1-D inputs with attention disabled.
unet = DiffusionUNetLegacy(
    name             = 'unet',
    dims             = 1,
    n_in_channels    = 3,
    n_base_channels  = 32,
    n_emb_dim        = 128,
    n_enc_blocks     = 4, # number of encoder blocks
    n_dec_blocks     = 4, # number of decoder blocks
    n_groups         = 16, # group norm parameter
    use_attention    = False,
    skip_connection  = True, # additional skip connection
    chnnel_multiples = (1,2,4,8),
    updown_rates     = (2,2,2,2),
    device           = device,
)
# Inputs, timesteps:[B] and x:[B x C x L]
batch_size = 2
# Sample the input directly on the target device (no CPU allocation + copy)
x = th.randn(batch_size,3,256,device=device) # [B x C x L]
# Evenly spaced values between 1 and dc['T'], cast to integer timesteps
timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]
# The forward pass is only used to inspect shapes, so do not record an autograd graph
with th.no_grad():
    out,intermediate_output_dict = unet(x,timesteps)
print ("Input: x:[%s] timesteps:[%s]"%(
    get_torch_size_string(x),get_torch_size_string(timesteps)
))
print ("Output: out:[%s]"%(get_torch_size_string(out)))
# Print intermediate layers
for k_idx,(key,z) in enumerate(intermediate_output_dict.items()):
    print ("[%2d] key:[%12s] shape:[%12s]"%(k_idx,key,get_torch_size_string(z)))

Input: x:[2x3x256] timesteps:[2]
Output: out:[2x3x256]
[ 0] key:[           x] shape:[     2x3x256]
[ 1] key:[    x_lifted] shape:[    2x32x256]
[ 2] key:[h_enc_res_00] shape:[    2x32x128]
[ 3] key:[h_enc_res_01] shape:[     2x64x64]
[ 4] key:[h_enc_res_02] shape:[    2x128x32]
[ 5] key:[h_enc_res_03] shape:[    2x256x16]
[ 6] key:[h_dec_res_00] shape:[    2x256x32]
[ 7] key:[h_dec_res_01] shape:[    2x128x64]
[ 8] key:[h_dec_res_02] shape:[    2x64x128]
[ 9] key:[h_dec_res_03] shape:[    2x32x256]
[10] key:[         out] shape:[     2x3x256]


### 2-D case: `[B x C x W x H]` without attention

In [6]:
# Build the legacy U-Net for 2-D inputs with attention disabled and
# updown_rates of 1 at every level.
unet = DiffusionUNetLegacy(
    name             = 'unet',
    dims             = 2,
    n_in_channels    = 3,
    n_base_channels  = 32,
    n_emb_dim        = 128,
    n_enc_blocks     = 3, # number of encoder blocks
    n_dec_blocks     = 3, # number of decoder blocks
    n_groups         = 16, # group norm parameter
    use_attention    = False,
    skip_connection  = True, # additional skip connection
    chnnel_multiples = (1,2,4),
    updown_rates     = (1,1,1),
    device           = device,
)
# Inputs, timesteps:[B] and x:[B x C x W x H]
batch_size = 2
# Sample the input directly on the target device (no CPU allocation + copy)
x = th.randn(batch_size,3,256,256,device=device) # [B x C x W x H]
# Evenly spaced values between 1 and dc['T'], cast to integer timesteps
timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]
# The forward pass is only used to inspect shapes, so do not record an autograd
# graph (avoids retaining the large full-resolution activations for backward)
with th.no_grad():
    out,intermediate_output_dict = unet(x,timesteps)
print ("Input: x:[%s] timesteps:[%s]"%(
    get_torch_size_string(x),get_torch_size_string(timesteps)
))
print ("Output: out:[%s]"%(get_torch_size_string(out)))
# Print intermediate layers
for k_idx,(key,z) in enumerate(intermediate_output_dict.items()):
    print ("[%2d] key:[%12s] shape:[%12s]"%(k_idx,key,get_torch_size_string(z)))

Input: x:[2x3x256x256] timesteps:[2]
Output: out:[2x3x256x256]
[ 0] key:[           x] shape:[ 2x3x256x256]
[ 1] key:[    x_lifted] shape:[2x32x256x256]
[ 2] key:[h_enc_res_00] shape:[2x32x256x256]
[ 3] key:[h_enc_res_01] shape:[2x64x256x256]
[ 4] key:[h_enc_res_02] shape:[2x128x256x256]
[ 5] key:[h_dec_res_00] shape:[2x128x256x256]
[ 6] key:[h_dec_res_01] shape:[2x64x256x256]
[ 7] key:[h_dec_res_02] shape:[2x32x256x256]
[ 8] key:[         out] shape:[ 2x3x256x256]


### 2-D case: `[B x C x W x H]` without attention + updown pooling

In [7]:
# Build the legacy U-Net for 2-D inputs with attention disabled and
# updown_rates of (1,2,2) across the three levels.
unet = DiffusionUNetLegacy(
    name             = 'unet',
    dims             = 2,
    n_in_channels    = 3,
    n_base_channels  = 32,
    n_emb_dim        = 128,
    n_enc_blocks     = 3, # number of encoder blocks
    n_dec_blocks     = 3, # number of decoder blocks
    n_groups         = 16, # group norm parameter
    use_attention    = False,
    skip_connection  = True, # additional skip connection
    chnnel_multiples = (1,2,4),
    updown_rates     = (1,2,2),
    device           = device,
)
# Inputs, timesteps:[B] and x:[B x C x W x H]
batch_size = 2
# Sample the input directly on the target device (no CPU allocation + copy)
x = th.randn(batch_size,3,256,256,device=device) # [B x C x W x H]
# Evenly spaced values between 1 and dc['T'], cast to integer timesteps
timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]
# The forward pass is only used to inspect shapes, so do not record an autograd graph
with th.no_grad():
    out,intermediate_output_dict = unet(x,timesteps)
print ("Input: x:[%s] timesteps:[%s]"%(
    get_torch_size_string(x),get_torch_size_string(timesteps)
))
print ("Output: out:[%s]"%(get_torch_size_string(out)))
# Print intermediate layers
for k_idx,(key,z) in enumerate(intermediate_output_dict.items()):
    print ("[%2d] key:[%12s] shape:[%12s]"%(k_idx,key,get_torch_size_string(z)))

Input: x:[2x3x256x256] timesteps:[2]
Output: out:[2x3x256x256]
[ 0] key:[           x] shape:[ 2x3x256x256]
[ 1] key:[    x_lifted] shape:[2x32x256x256]
[ 2] key:[h_enc_res_00] shape:[2x32x256x256]
[ 3] key:[h_enc_res_01] shape:[2x64x128x128]
[ 4] key:[h_enc_res_02] shape:[ 2x128x64x64]
[ 5] key:[h_dec_res_00] shape:[2x128x128x128]
[ 6] key:[h_dec_res_01] shape:[2x64x256x256]
[ 7] key:[h_dec_res_02] shape:[2x32x256x256]
[ 8] key:[         out] shape:[ 2x3x256x256]
