In [1]:
import collections
import copy
import sys
import time
from random import seed

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from torch import optim

import dataset
import evaluation
from GaussianDiffusion import GaussianDiffusionModel, get_beta_schedule
from helpers import *
# from UNet import UNetModel, update_ema_params

torch.cuda.empty_cache()

ROOT_DIR = "./"
from PIL import Image

In [2]:
# Notebook-wide configuration dictionary (key order kept as originally written).
args = {
    "img_size": [256, 256],
    "Batch_Size": 16,
    "EPOCHS": 3000,
    "T": 1000,
    "base_channels": 128,
    "beta_schedule": "linear",
    "channel_mults": "",
    "loss-type": "l2",
    "loss_weight": "none",
    "train_start": True,
    "lr": 1e-4,
    "random_slice": False,
    "sample_distance": 800,
    "weight_decay": 0.0,
    "save_imgs": False,
    "save_vids": True,
    "dropout": 0,
    "attention_resolutions": "16,8",
    "num_heads": 2,
    "num_head_channels": -1,
    "noise_fn": "simplex",
    "dataset": "mri",
}

In [3]:
import math
from abc import abstractmethod

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class TimestepBlock(nn.Module):
    """
    Base class for modules whose forward() receives the timestep embeddings
    as its second positional argument.

    TimestepEmbedSequential uses an isinstance check against this class to
    decide whether a child layer is called with the embedding or without it.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Run the module on `x`, conditioned on the timestep embeddings `emb`.
        Subclasses are expected to override this.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    nn.Sequential variant that hands the timestep embeddings to every child
    that is a TimestepBlock and calls all other children with the tensor only.
    """

    def forward(self, x, emb):
        out = x
        for child in self:
            takes_embedding = isinstance(child, TimestepBlock)
            out = child(out, emb) if takes_embedding else child(out)
        return out


class PositionalEmbedding(nn.Module):
    """
    Sinusoidal embedding of a 1-D batch of timesteps.

    forward() maps a tensor of N timesteps to an (N, dim) tensor whose first
    dim/2 columns are sines and whose last dim/2 columns are cosines of
    `timestep * scale * frequency`, with frequencies decaying geometrically
    from 1 as exp(-k * log(10000) / (dim/2)).
    """

    def __init__(self, dim, scale=1):
        super().__init__()
        # sin and cos halves must have the same width
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        n_freqs = self.dim // 2
        log_step = np.log(10000) / n_freqs
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -log_step)
        # torch.outer requires 1-D inputs: result is (len(x), n_freqs)
        angles = torch.outer(x * self.scale, freqs)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Downsample(nn.Module):
    """
    Halves the spatial size of a (N, C, H, W) feature map.

    :param in_channels: channel count the input must have (asserted in forward).
    :param use_conv: if truthy, use a stride-2 3x3 convolution; otherwise 2x2 average pooling.
    :param out_channels: output channels for the convolution; falls back to
                         `in_channels` when falsy. Pooling cannot change the channel count.
    """

    def __init__(self, in_channels, use_conv, out_channels=None):
        super().__init__()
        self.channels = in_channels
        target_channels = in_channels if not out_channels else out_channels
        if not use_conv:
            assert in_channels == target_channels
            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            # stride 2 with padding 1 halves H and W
            self.downsample = nn.Conv2d(in_channels, target_channels, 3, stride=2, padding=1)

    def forward(self, x, time_embed=None):
        # `time_embed` is accepted for call-site uniformity and is not used
        assert x.shape[1] == self.channels
        return self.downsample(x)


class Upsample(nn.Module):
    """
    Doubles the spatial size of a (N, C, H, W) feature map with nearest-neighbour
    interpolation, optionally followed by a 3x3 convolution.

    :param in_channels: channel count the input must have (asserted in forward).
    :param use_conv: if truthy, apply a 3x3 convolution after interpolation.
    :param out_channels: output channels of that convolution; falls back to
                         `in_channels` when omitted (same rule as Downsample).
    """

    def __init__(self, in_channels, use_conv, out_channels=None):
        super().__init__()
        self.channels = in_channels
        self.use_conv = use_conv
        # Without this fallback, use_conv=True with the default out_channels=None
        # would hand None to nn.Conv2d and fail at construction time.
        out_channels = out_channels or in_channels
        # uses upsample then conv to avoid checkerboard artifacts
        if use_conv:
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)

    def forward(self, x, time_embed=None):
        # `time_embed` is accepted for call-site uniformity and is not used
        assert x.shape[1] == self.channels
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

In [4]:
class AttentionBlock(nn.Module):
    """
    Residual self-attention over all spatial positions of a feature map.

    The input is flattened to (batch, channels, positions), normalised, projected
    to packed q/k/v with a 1x1 Conv1d, attended, projected back through a
    zero-initialised 1x1 Conv1d and added to the flattened input before being
    reshaped to the original layout.

    Ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    and adapted to an arbitrary number of spatial dimensions.

    :param in_channels: number of input (and output) channels.
    :param n_heads: head count, used only when `n_head_channels` is -1.
    :param n_head_channels: channels per head; when not -1 the head count is
                            `in_channels // n_head_channels`.
    """

    def __init__(self, in_channels, n_heads=1, n_head_channels=-1):
        super().__init__()
        self.in_channels = in_channels
        self.norm = GroupNorm32(32, self.in_channels)
        if n_head_channels != -1:
            assert (
                    in_channels % n_head_channels == 0
            ), f"q,k,v channels {in_channels} is not divisible by num_head_channels {n_head_channels}"
            self.num_heads = in_channels // n_head_channels
        else:
            self.num_heads = n_heads

        # one projection emits q, k and v stacked along the channel axis
        self.to_qkv = nn.Conv1d(in_channels, in_channels * 3, 1)
        self.attention = QKVAttention(self.num_heads)
        self.proj_out = zero_module(nn.Conv1d(in_channels, in_channels, 1))

    def forward(self, x, time=None):
        # `time` is accepted for call-site uniformity and is not used
        batch, channels, *spatial_dims = x.shape
        tokens = x.reshape(batch, channels, -1)
        attended = self.attention(self.to_qkv(self.norm(tokens)))
        residual = self.proj_out(attended)
        return (tokens + residual).reshape(batch, channels, *spatial_dims)

In [6]:
# AB = AttentionBlock(512)

In [7]:
x_test = torch.rand(1,  512, 7, 7)

In [8]:
# NOTE(review): `AB` is only created in a commented-out cell above, so on a fresh
# kernel this cell raises NameError (see the recorded output below).
y_test = AB(x_test)

NameError: name 'AB' is not defined

In [10]:
# y_test.shape

In [63]:



class QKVAttention(nn.Module):
    """
    Multi-head scaled dot-product attention over a packed q/k/v tensor.
    Matches the legacy QKVAttention including input/output head shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, time=None):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param time: unused; accepted for call-site uniformity.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, packed_width, seq_len = qkv.shape
        heads = self.n_heads
        assert packed_width % (3 * heads) == 0
        head_dim = packed_width // (3 * heads)

        # fold heads into the batch axis, then cut each head into its q, k, v parts
        per_head = qkv.reshape(batch * heads, head_dim * 3, seq_len)
        queries, keys, values = per_head.split(head_dim, dim=1)

        # Scaling q and k by ch**-0.25 each is more stable with f16 than
        # dividing the product by sqrt(ch) afterwards.
        root_scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = torch.einsum("bct,bcs->bts", queries * root_scale, keys * root_scale)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = torch.einsum("bts,bcs->bct", probs, values)
        return mixed.reshape(batch, -1, seq_len)


class ResBlock(TimestepBlock):
    """
    Residual block conditioned on a timestep embedding, with optional 2x
    up- or down-sampling applied to both the residual branch and the skip path.

    :param in_channels: channels of the input feature map.
    :param time_embed_dim: width of the timestep embedding passed to forward().
    :param dropout: dropout probability used before the final convolution.
    :param out_channels: channels of the output; falls back to `in_channels` when falsy.
    :param use_conv: when channel counts differ, use a 3x3 instead of a 1x1 conv on the skip path.
    :param up: resample both paths with Upsample (no conv) before the first convolution.
    :param down: resample both paths with Downsample (no conv); ignored when `up` is truthy.
    """

    def __init__(
            self,
            in_channels,
            time_embed_dim,
            dropout,
            out_channels=None,
            use_conv=False,
            up=False,
            down=False
            ):
        super().__init__()
        if not out_channels:
            out_channels = in_channels

        self.in_layers = nn.Sequential(
                GroupNorm32(32, in_channels),
                nn.SiLU(),
                nn.Conv2d(in_channels, out_channels, 3, padding=1)
                )
        self.updown = up or down

        # `up` takes precedence over `down`
        if up:
            resampler = Upsample
        elif down:
            resampler = Downsample
        else:
            resampler = None

        if resampler is None:
            # a single Identity instance is shared by both attributes
            self.h_upd = self.x_upd = nn.Identity()
        else:
            self.h_upd = resampler(in_channels, False)
            self.x_upd = resampler(in_channels, False)

        self.embed_layers = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_embed_dim, out_channels)
                )
        self.out_layers = nn.Sequential(
                GroupNorm32(32, out_channels),
                nn.SiLU(),
                nn.Dropout(p=dropout),
                # zero-initialised last conv: the residual branch contributes nothing at init
                zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
                )

        if out_channels == in_channels:
            self.skip_connection = nn.Identity()
        else:
            skip_kernel, skip_padding = (3, 1) if use_conv else (1, 0)
            self.skip_connection = nn.Conv2d(
                    in_channels, out_channels, skip_kernel, padding=skip_padding
                    )

    def forward(self, x, time_embed):
        if not self.updown:
            h = self.in_layers(x)
        else:
            # norm + activation first, resample both paths, then the convolution
            h = self.in_layers[:-1](x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = self.in_layers[-1](h)

        emb_out = self.embed_layers(time_embed).type(h.dtype)
        # append singleton axes so the embedding broadcasts over the trailing dims of h
        missing_axes = len(h.shape) - len(emb_out.shape)
        if missing_axes > 0:
            emb_out = emb_out[(...,) + (None,) * missing_axes]

        h = self.out_layers(h + emb_out)
        return self.skip_connection(x) + h


class UNetModel(nn.Module):
    """
    Time-conditioned U-Net built from ResBlocks, with AttentionBlocks at selected
    resolutions and an SSPCAB block between the two bottleneck ResBlocks.

    :param img_size: side length used to pick the default channel multipliers
                     (512, 256, 128, 64 or 32) and to place attention.
    :param base_channels: base channel width; each level uses
                          `int(base_channels * mult)` channels.
    :param conv_resample: passed as `use_conv` to Downsample/Upsample when
                          `biggan_updown` is falsy.
    :param n_heads: head count for AttentionBlock (used when `n_head_channels` is -1).
    :param n_head_channels: channels per attention head, or -1.
    :param channel_mults: sequence of per-level multipliers, or "" to use the
                          default for `img_size`.
    :param num_res_blocks: ResBlocks per level on the way down (one more per level on the way up).
    :param dropout: dropout probability for every ResBlock.
    :param attention_resolutions: comma-separated sizes; attention is added where the
                                  running downsample factor equals `img_size // size`.
    :param biggan_updown: resample with ResBlock(up/down=True) instead of Upsample/Downsample.
    :param in_channels: channels of the input; the output has the same channel count.
    :raises ValueError: if `channel_mults` is "" and `img_size` has no default.
    """

    def __init__(
            self,
            img_size,
            base_channels,
            conv_resample=True,
            n_heads=1,
            n_head_channels=-1,
            channel_mults="",
            num_res_blocks=2,
            dropout=0,
            attention_resolutions="32,16,8",
            biggan_updown=True,
            in_channels=1
            ):
        super().__init__()

        if channel_mults == "":
            if img_size == 512:
                channel_mults = (0.5, 1, 1, 2, 2, 4, 4)
            elif img_size == 256:
                channel_mults = (1, 1, 2, 2, 4, 4)
            elif img_size == 128:
                channel_mults = (1, 1, 2, 3, 4)
            elif img_size == 64:
                channel_mults = (1, 2, 3, 4)
            elif img_size == 32:
                channel_mults = (1, 2, 3, 4)
            else:
                raise ValueError(f"unsupported image size: {img_size}")
        # downsample factors at which attention blocks are inserted
        attention_ds = [img_size // int(res) for res in attention_resolutions.split(",")]

        self.image_size = img_size
        self.in_channels = in_channels
        self.model_channels = base_channels
        self.out_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mults
        self.conv_resample = conv_resample

        self.dtype = torch.float32
        self.num_heads = n_heads
        self.num_head_channels = n_head_channels

        time_embed_dim = base_channels * 4
        self.time_embedding = nn.Sequential(
                PositionalEmbedding(base_channels, 1),
                nn.Linear(base_channels, time_embed_dim),
                nn.SiLU(),
                nn.Linear(time_embed_dim, time_embed_dim),
                )

        # Channel counts are forced to int: the 512 preset uses a 0.5 multiplier
        # and nn.Conv2d does not accept float channel counts.
        ch = int(channel_mults[0] * base_channels)
        # The stem must emit `ch` channels because the first ResBlock below is
        # built with `ch` input channels (identical to base_channels when the
        # first multiplier is 1).
        self.down = nn.ModuleList(
                [TimestepEmbedSequential(nn.Conv2d(self.in_channels, ch, 3, padding=1))]
                )
        channels = [ch]  # channel count of every skip tensor, in push order
        ds = 1
        for i, mult in enumerate(channel_mults):
            level_channels = int(base_channels * mult)

            for _ in range(num_res_blocks):
                layers = [ResBlock(
                        ch,
                        time_embed_dim=time_embed_dim,
                        out_channels=level_channels,
                        dropout=dropout,
                        )]
                ch = level_channels

                if ds in attention_ds:
                    layers.append(
                            AttentionBlock(
                                    ch,
                                    n_heads=n_heads,
                                    n_head_channels=n_head_channels,
                                    )
                            )
                self.down.append(TimestepEmbedSequential(*layers))
                channels.append(ch)
            # every level except the last ends with a 2x downsample
            if i != len(channel_mults) - 1:
                out_channels = ch
                self.down.append(
                        TimestepEmbedSequential(
                                ResBlock(
                                        ch,
                                        time_embed_dim=time_embed_dim,
                                        out_channels=out_channels,
                                        dropout=dropout,
                                        down=True
                                        )
                                if biggan_updown
                                else
                                Downsample(ch, conv_resample, out_channels=out_channels)
                                )
                        )
                ds *= 2
                ch = out_channels
                channels.append(ch)

        self.middle = TimestepEmbedSequential(
                ResBlock(
                        ch,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout
                        ),
                SSPCAB(
                      ch,
                      reduction_ratio=8
                      ),
                ResBlock(
                        ch,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout
                        )
                )
        self.up = nn.ModuleList([])

        for i, mult in reversed(list(enumerate(channel_mults))):
            level_channels = int(base_channels * mult)
            for j in range(num_res_blocks + 1):
                # each up block consumes one skip tensor concatenated on the channel axis
                inp_chs = channels.pop()
                layers = [
                    ResBlock(
                            ch + inp_chs,
                            time_embed_dim=time_embed_dim,
                            out_channels=level_channels,
                            dropout=dropout
                            )
                    ]
                ch = level_channels
                if ds in attention_ds:
                    layers.append(
                            AttentionBlock(
                                    ch,
                                    n_heads=n_heads,
                                    n_head_channels=n_head_channels
                                    ),
                            )

                # last block of every level except the outermost ends with a 2x upsample
                if i and j == num_res_blocks:
                    out_channels = ch
                    layers.append(
                            ResBlock(
                                    ch,
                                    time_embed_dim=time_embed_dim,
                                    out_channels=out_channels,
                                    dropout=dropout,
                                    up=True
                                    )
                            if biggan_updown
                            else
                            Upsample(ch, conv_resample, out_channels=out_channels)
                            )
                    ds //= 2
                self.up.append(TimestepEmbedSequential(*layers))

        # `ch` is the width leaving the last up block, so norm and conv agree
        self.out = nn.Sequential(
                GroupNorm32(32, ch),
                nn.SiLU(),
                zero_module(nn.Conv2d(ch, self.out_channels, 3, padding=1))
                )

    def forward(self, x, time):
        """
        Run the U-Net.

        :param x: input feature map with `in_channels` channels.
        :param time: 1-D tensor of timesteps, one per batch element.
        :return: tuple `(output, fm, h_hat)` where `output` is the network output,
                 `fm` is the feature map leaving the down path and `h_hat` is the
                 feature map leaving the middle block.
        """
        time_embed = self.time_embedding(time)
        skips = []

        h = x.type(self.dtype)
        for module in self.down:
            h = module(h, time_embed)
            skips.append(h)

        fm = h
        h_hat = self.middle(fm, time_embed)
        H = h_hat
        for module in self.up:
            # skips are consumed in reverse order of creation
            H = torch.cat([H, skips.pop()], dim=1)
            H = module(H, time_embed)
        H = H.type(x.dtype)
        H = self.out(H)
        return H, fm, h_hat


class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm that always normalises in float32 and casts the result back to
    the dtype of its input.
    """

    def forward(self, x):
        caller_dtype = x.dtype
        normalised = super().forward(x.float())
        return normalised.type(caller_dtype)


def zero_module(module):
    """
    Set every parameter of `module` to zero in place and return the same module.
    """
    # no_grad keeps the in-place writes out of autograd
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


def update_ema_params(target, source, decay_rate=0.9999):
    """
    In-place exponential-moving-average update of `target`'s parameters:
    target = decay_rate * target + (1 - decay_rate) * source, matched by
    parameter name. Only entries of named_parameters() are touched.
    """
    source_lookup = dict(source.named_parameters())
    for name, ema_param in dict(target.named_parameters()).items():
        ema_param.data.mul_(decay_rate).add_(source_lookup[name].data, alpha=1 - decay_rate)

In [64]:
# Squeeze and Excitation block
class SELayer(nn.Module):
    """
    Squeeze-and-Excitation channel gate for (N, C, H, W) feature maps.

    The input is averaged over its spatial positions, passed through a
    two-layer bottleneck (Linear -> ReLU -> Linear -> Sigmoid) and the
    resulting per-channel gate multiplies the input.
    """

    def __init__(self, num_channels, reduction_ratio=8):
        '''
            num_channels: The number of input channels
            reduction_ratio: The reduction ratio 'r' from the paper
        '''
        super().__init__()
        # Keep at least one hidden unit: with num_channels < reduction_ratio the
        # floor division gives 0, which builds an empty bottleneck whose gate
        # no longer depends on the input.
        num_channels_reduced = max(1, num_channels // reduction_ratio)
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        batch_size, num_channels, H, W = input_tensor.size()

        # squeeze: mean over all spatial positions -> (N, C);
        # reshape (not view) so non-contiguous inputs are accepted too
        squeeze_tensor = input_tensor.reshape(batch_size, num_channels, -1).mean(dim=2)

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        # broadcast the (N, C) gate over H and W
        gate = fc_out_2.view(batch_size, num_channels, 1, 1)
        output_tensor = torch.mul(input_tensor, gate)
        return output_tensor


# SSPCAB implementation
class SSPCAB(nn.Module):
    """
    SSPCAB block: four convolutions applied to four shifted crops of the
    zero-padded input, summed, passed through ReLU and then through an
    SELayer channel gate.
    """

    def __init__(self, channels, kernel_dim=1, dilation=1, reduction_ratio=4):
        '''
            channels: The number of filter at the output (usually the same with the number of filter from the input)
            kernel_dim: The dimension of the sub-kernels ' k' ' from the paper
            dilation: The dilation dimension 'd' from the paper
            reduction_ratio: The reduction ratio for the SE block ('r' from the paper)
        '''
        super().__init__()
        self.pad = kernel_dim + dilation
        self.border_input = kernel_dim + 2*dilation + 1

        self.relu = nn.ReLU()
        self.se = SELayer(channels, reduction_ratio=reduction_ratio)

        # conv1 .. conv4: one sub-kernel convolution per shifted crop
        for index in range(1, 5):
            setattr(
                self,
                f"conv{index}",
                nn.Conv2d(in_channels=channels,
                          out_channels=channels,
                          kernel_size=kernel_dim),
            )

    def forward(self, x):
        margin = self.border_input
        padded = F.pad(x, (self.pad,) * 4, "constant", 0)

        # each crop drops `margin` entries from one end of each of the last two axes
        head = slice(None, -margin)
        tail = slice(margin, None)
        sub1 = self.conv1(padded[:, :, head, head])
        sub2 = self.conv2(padded[:, :, tail, head])
        sub3 = self.conv3(padded[:, :, head, tail])
        sub4 = self.conv4(padded[:, :, tail, tail])

        activated = self.relu(sub1 + sub2 + sub3 + sub4)
        return self.se(activated)


In [65]:
x_test = torch.rand(1, 3, 448, 448)

In [66]:
y_test = SSPCAB(3)(x_test)

In [67]:
y_test.shape

torch.Size([1, 3, 448, 448])

In [78]:
if __name__ == "__main__":
    # Smoke-test configuration. Kept in its own name so it does not overwrite the
    # notebook-wide `args` dict that later cells index as args['img_size'][0].
    smoke_args = {
        'img_size':              256,
        'base_channels':         64,
        'dropout':               0.3,
        'num_heads':             4,
        'attention_resolutions': '32,16,8',
        'lr':                    1e-4,
        'Batch_Size':            64
        }
    model = UNetModel(
            smoke_args['img_size'], smoke_args['base_channels'], dropout=smoke_args["dropout"],
            n_heads=smoke_args["num_heads"],
            attention_resolutions=smoke_args["attention_resolutions"],
            in_channels=3
            )

    # NOTE(review): the input is 512x512 while the model was configured with img_size=256.
    x = torch.randn(1, 3, 512, 512)
    t_batch = torch.tensor([1], device=x.device).repeat(x.shape[0])
    # forward() returns a tuple (output, fm, h_hat), so index it before reading .shape:
#     print(model(x, t_batch)[0].shape)

torch.Size([1, 256, 16, 16])
torch.Size([1, 256, 16, 16])


AttributeError: 'tuple' object has no attribute 'shape'

In [69]:
# MIDDLE = model.middle(x, t_batch)

In [70]:
args

{'img_size': [256, 256],
 'Batch_Size': 16,
 'EPOCHS': 3000,
 'T': 1000,
 'base_channels': 128,
 'beta_schedule': 'linear',
 'channel_mults': '',
 'loss-type': 'l2',
 'loss_weight': 'none',
 'train_start': True,
 'lr': 0.0001,
 'random_slice': False,
 'sample_distance': 800,
 'weight_decay': 0.0,
 'save_imgs': False,
 'save_vids': True,
 'dropout': 0,
 'attention_resolutions': '16,8',
 'num_heads': 2,
 'num_head_channels': -1,
 'noise_fn': 'simplex',
 'dataset': 'mri'}

In [71]:
in_channels=3
unet = UNetModel(
        args['img_size'][0], args['base_channels'], channel_mults=args['channel_mults'], in_channels=in_channels
        )

In [72]:
unet

UNetModel(
  (time_embedding): Sequential(
    (0): PositionalEmbedding()
    (1): Linear(in_features=128, out_features=512, bias=True)
    (2): SiLU()
    (3): Linear(in_features=512, out_features=512, bias=True)
  )
  (down): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (embed_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=512, out_features=128, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv2d(128, 12

In [73]:
x_test = torch.rand(10, 3, 256, 256)
t_batch = torch.tensor([1], device=x_test.device).repeat(x_test.shape[0])

In [74]:
# unet.middle

In [75]:
y_test = unet(x_test, t_batch)

torch.Size([10, 512, 8, 8])
torch.Size([10, 512, 8, 8])


In [77]:
(y_test[-1]-y_test[-2]).shape

torch.Size([10, 512, 8, 8])

In [42]:
l=[1, 2, 3, 4, 5, 6, 8]
l.pop()

8

In [18]:
TE = unet.time_embedding(t_batch)

In [19]:
TE.shape

torch.Size([10, 512])

In [21]:
# `unet.down` is an nn.ModuleList; it is read as an attribute, not called.
down = unet.down

In [26]:
time_embed = unet.time_embedding(t_batch)
print(time_embed.shape)
skips = []
h = x_test.type(unet.dtype)
print("shape h0", h.shape)
for i, module in enumerate(unet.down):
    h = module(h, time_embed)
    print(f"shape h{i+1}", h.shape)
    skips.append(h)


torch.Size([10, 512])
shape h0 torch.Size([10, 3, 224, 224])
shape h1 torch.Size([10, 128, 224, 224])
shape h2 torch.Size([10, 128, 224, 224])
shape h3 torch.Size([10, 128, 224, 224])
shape h4 torch.Size([10, 128, 112, 112])
shape h5 torch.Size([10, 128, 112, 112])
shape h6 torch.Size([10, 128, 112, 112])
shape h7 torch.Size([10, 128, 56, 56])
shape h8 torch.Size([10, 256, 56, 56])
shape h9 torch.Size([10, 256, 56, 56])
shape h10 torch.Size([10, 256, 28, 28])
shape h11 torch.Size([10, 256, 28, 28])
shape h12 torch.Size([10, 256, 28, 28])
shape h13 torch.Size([10, 256, 14, 14])
shape h14 torch.Size([10, 512, 14, 14])
shape h15 torch.Size([10, 512, 14, 14])
shape h16 torch.Size([10, 512, 7, 7])
shape h17 torch.Size([10, 512, 7, 7])
shape h18 torch.Size([10, 512, 7, 7])


In [28]:
for i in skips:
    print(i.shape)

torch.Size([10, 128, 224, 224])
torch.Size([10, 128, 224, 224])
torch.Size([10, 128, 224, 224])
torch.Size([10, 128, 112, 112])
torch.Size([10, 128, 112, 112])
torch.Size([10, 128, 112, 112])
torch.Size([10, 128, 56, 56])
torch.Size([10, 256, 56, 56])
torch.Size([10, 256, 56, 56])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 28, 28])
torch.Size([10, 256, 14, 14])
torch.Size([10, 512, 14, 14])
torch.Size([10, 512, 14, 14])
torch.Size([10, 512, 7, 7])
torch.Size([10, 512, 7, 7])
torch.Size([10, 512, 7, 7])


In [29]:
h = unet.middle(h, time_embed)
print(h.shape)

torch.Size([10, 512, 7, 7])


In [30]:
for i, module in enumerate(unet.up):
    h = torch.cat([h, skips.pop()], dim=1)
    print(f"shape h{i}", h.shape)
    h = module(h, time_embed)
    print(f"shape H{i}", h.shape)


shape h0 torch.Size([10, 1024, 7, 7])
shape H0 torch.Size([10, 512, 7, 7])
shape h1 torch.Size([10, 1024, 7, 7])
shape H1 torch.Size([10, 512, 7, 7])
shape h2 torch.Size([10, 1024, 7, 7])
shape H2 torch.Size([10, 512, 14, 14])
shape h3 torch.Size([10, 1024, 14, 14])
shape H3 torch.Size([10, 512, 14, 14])
shape h4 torch.Size([10, 1024, 14, 14])
shape H4 torch.Size([10, 512, 14, 14])
shape h5 torch.Size([10, 768, 14, 14])
shape H5 torch.Size([10, 512, 28, 28])
shape h6 torch.Size([10, 768, 28, 28])
shape H6 torch.Size([10, 256, 28, 28])
shape h7 torch.Size([10, 512, 28, 28])
shape H7 torch.Size([10, 256, 28, 28])
shape h8 torch.Size([10, 512, 28, 28])
shape H8 torch.Size([10, 256, 56, 56])
shape h9 torch.Size([10, 512, 56, 56])
shape H9 torch.Size([10, 256, 56, 56])
shape h10 torch.Size([10, 512, 56, 56])
shape H10 torch.Size([10, 256, 56, 56])
shape h11 torch.Size([10, 384, 56, 56])
shape H11 torch.Size([10, 256, 112, 112])
shape h12 torch.Size([10, 384, 112, 112])
shape H12 torch.Size(

In [31]:
h = h.type(x_test.dtype)
print(h.shape)
h = unet.out(h)
print(h.shape)

torch.Size([10, 128, 224, 224])
torch.Size([10, 3, 224, 224])


In [15]:
y_test = unet(x_test, t_batch)

In [16]:
# UNetModel.forward returns a tuple (output, fm, h_hat); inspect the output tensor.
y_test[0].shape

torch.Size([10, 3, 224, 224])

In [14]:
t_batch

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [33]:
# unet.down

In [34]:
unet.down

ModuleList(
  (0): TimestepEmbedSequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (1): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (embed_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=512, out_features=128, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
  )
  (2): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
        (1): SiLU()
  

In [35]:
unet.middle

TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 512, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (embed_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=512, out_features=512, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 512, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): AttentionBlock(
    (norm): GroupNorm32(32, 512, eps=1e-05, affine=True)
    (to_qkv): Conv1d(512, 1536, kernel_size=(1,), stride=(1,))
    (attention): QKVAttention()
    (proj_out): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
  )
  (2): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 512, eps=1e-05, affi

In [36]:
unet.out

Sequential(
  (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
  (1): SiLU()
  (2): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [None]:
    def forward(self, x, time):
        """
        Single-output variant of the U-Net forward pass: down path, middle
        block, up path with skip concatenation, then the output head.

        NOTE(review): this cell contains only this indented method with no
        enclosing class, and it differs from UNetModel.forward defined earlier
        in the notebook, which returns a 3-tuple; presumably a leftover copy
        of an older version - confirm before relying on it.

        :param x: input feature map.
        :param time: timesteps fed to `self.time_embedding`.
        :return: the tensor produced by `self.out`.
        """

        time_embed = self.time_embedding(time)

        # outputs of every down module, consumed in reverse by the up path
        skips = []

        h = x.type(self.dtype)
        for i, module in enumerate(self.down):
            h = module(h, time_embed)
            skips.append(h)
        h = self.middle(h, time_embed)
        for i, module in enumerate(self.up):
            # concatenate the most recent unused skip along the channel axis
            h = torch.cat([h, skips.pop()], dim=1)
            h = module(h, time_embed)
        # cast back to the caller's dtype before the output head
        h = h.type(x.dtype)
        h = self.out(h)
        return h