In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Put 'denoising-diffusion-pytorch/' first on the import search path.
# The path is relative, so it is resolved against the kernel's current
# working directory. NOTE(review): presumably a local checkout that provides
# the `denoising_diffusion_pytorch` package imported later — confirm.
sys.path.insert(0, 'denoising-diffusion-pytorch/')

In [3]:
from sinfusion.convnext import Block as ConvNextBlock

In [4]:
from torchinfo import summary

In [5]:
# Build one ConvNeXt block (64 is its only constructor argument) and print a
# torchinfo summary for a single (1, 64, 32, 32) input.
block = ConvNextBlock(64)
summary(block, input_size=(1, 64, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
Block                                    [1, 64, 32, 32]           64
├─Conv2d: 1-1                            [1, 64, 32, 32]           3,200
├─LayerNorm: 1-2                         [1, 32, 32, 64]           128
├─Linear: 1-3                            [1, 32, 32, 256]          16,640
├─GELU: 1-4                              [1, 32, 32, 256]          --
├─Linear: 1-5                            [1, 32, 32, 64]           16,448
├─Identity: 1-6                          [1, 64, 32, 32]           --
Total params: 36,480
Trainable params: 36,480
Non-trainable params: 0
Total mult-adds (M): 3.31
Input size (MB): 0.26
Forward/backward pass size (MB): 3.67
Params size (MB): 0.15
Estimated Total Size (MB): 4.08

In [34]:
# Scratch check: tile a per-channel vector over a feature map and concatenate
# it with that feature map along the channel axis.

# `torch` is otherwise only imported in a later cell (`In [25]`), so import it
# here to keep this cell runnable top-to-bottom on a fresh kernel.
import torch

x = torch.randn(2, 3, 10, 10)

# Repeat counts mirror x's shape, except the channel axis (dim 1) which is
# kept as-is because the vector already has one value per channel.
repeats = list(x.shape)
repeats[1] = 1

# (3,) -> (1, 3, 1, 1) -> tiled to x's shape.
y = torch.randn(3).view(1, -1, 1, 1).repeat(repeats)
assert y.shape == x.shape, "tiled vector must match the feature map's shape"

# Channel-wise concatenation; the last expression is the cell's output.
torch.cat([x, y], dim=1)

tensor([[[[ 3.4300e-01,  6.7156e-01, -5.4269e-01,  ...,  4.2047e-01,
            7.6600e-01,  1.0213e+00],
          [ 5.9349e-01, -1.2021e+00, -2.3169e-01,  ...,  4.9485e-01,
            7.1561e-01,  6.8871e-01],
          [ 5.0260e-02, -1.0741e+00, -1.5700e+00,  ...,  1.3138e+00,
            8.3342e-02,  5.7727e-01],
          ...,
          [-1.7141e+00, -1.7083e+00, -1.0414e+00,  ...,  9.0297e-02,
           -1.0634e+00, -2.0310e-01],
          [ 1.3243e+00, -8.9916e-01,  3.5622e-01,  ...,  2.3185e+00,
            1.0020e+00,  9.6531e-01],
          [-6.8497e-01,  1.0265e+00,  2.1010e-01,  ...,  4.8368e-01,
           -9.2023e-02,  1.4771e-01]],

         [[-1.1475e+00,  9.4339e-01, -5.0664e-02,  ..., -5.3286e-01,
            1.0466e+00, -1.6284e+00],
          [-5.2903e-02, -5.4705e-01,  1.3678e+00,  ...,  1.0006e+00,
            1.3034e-01,  5.8654e-02],
          [-1.3611e+00,  2.0250e-01, -2.8736e-01,  ...,  2.6533e-01,
           -8.4345e-01,  4.5877e-01],
          ...,
     

In [25]:
import torch
import math

from torch import nn

from sinfusion.utils import default
from einops import rearrange, reduce

from functools import partial

In [26]:
class SinusoidalPosEmb(nn.Module):
    """Fixed (parameter-free) sinusoidal embedding of a 1-D batch of scalars.

    Every scalar in the input is multiplied by ``dim // 2`` geometrically
    spaced frequencies (from 1 down to 1/10000); the sines of those products
    followed by their cosines are concatenated along the last axis.

    Args:
        dim: width of the embedding returned by ``forward``.

    Edge cases:
        * ``dim`` of 2 or 3 leaves a single frequency; the spacing divisor
          ``half_dim - 1`` would be zero, so it is clamped to 1 (the single
          frequency is ``exp(0) == 1`` whatever the spacing).
        * Odd ``dim``: the sin/cos halves only provide ``dim - 1`` columns, so
          one zero column is appended to make the output exactly ``dim`` wide.
        For even ``dim >= 4`` the output is unchanged by these guards.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` of shape ``(b,)`` into a tensor of shape ``(b, dim)``."""
        device = x.device
        half_dim = self.dim // 2

        # Log-spacing between consecutive frequencies; clamp the divisor so a
        # single-frequency embedding does not raise ZeroDivisionError.
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)

        # (b, 1) * (1, half_dim) -> (b, half_dim)
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)

        # Pad odd widths so the result always has `dim` columns.
        if self.dim % 2 == 1:
            emb = torch.cat((emb, emb.new_zeros(emb.shape[0], 1)), dim=-1)
        return emb

class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding whose frequencies are drawn at random.

    Follows @crowsonkb's random (optionally learned) sinusoidal position
    embedding:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    Args:
        dim: number of sin/cos features; must be even. ``dim // 2``
            frequencies are sampled from a standard normal.
        is_random: when True the frequencies are frozen
            (``requires_grad=False``); otherwise they are trainable.

    ``forward`` takes a 1-D batch ``(b,)`` and returns ``(b, dim + 1)``: the
    raw input column, then the sines, then the cosines.
    """

    def __init__(self, dim, is_random = False):
        super().__init__()
        assert (dim % 2) == 0
        num_freqs = dim // 2
        trainable = not is_random
        self.weights = nn.Parameter(torch.randn(num_freqs), requires_grad = trainable)

    def forward(self, x):
        # (b,) -> (b, 1) so it broadcasts against the (1, d) frequency row.
        x = rearrange(x, 'b -> b 1')
        weight_row = rearrange(self.weights, 'd -> 1 d')
        freqs = x * weight_row * 2 * math.pi

        # Raw input first, then sin features, then cos features.
        return torch.cat((x, freqs.sin(), freqs.cos()), dim = -1)
    
# Smoke-test both embedding modules on a random batch of 64 scalars.
x = torch.randn(64)
# This result is discarded: a cell only displays its last expression, so the
# call merely checks that the forward pass runs.
SinusoidalPosEmb(64)(x)
# Last expression -> this is the tensor shown as the cell's output.
RandomOrLearnedSinusoidalPosEmb(64)(x)

tensor([[-0.1732, -0.3698,  0.6612,  ...,  0.6055, -0.3772, -0.0862],
        [-0.9253, -0.8990, -0.6584,  ...,  0.2041, -0.5104, -0.8417],
        [ 0.6976,  0.9990, -0.2290,  ..., -0.8438, -0.0327,  0.9238],
        ...,
        [ 0.6408,  0.9857, -0.4513,  ..., -0.9653,  0.5729,  0.9886],
        [-0.7342, -0.9994,  0.0785,  ..., -0.7244, -0.4311,  0.7367],
        [ 0.1889,  0.4016, -0.7090,  ...,  0.5369, -0.5351, -0.2346]],
       grad_fn=<CatBackward0>)

In [27]:
class Unet(nn.Module):
    """Work-in-progress U-Net-style network assembled from ConvNeXt blocks.

    `forward` takes an input batch `x`, a `time` tensor and an optional
    self-conditioning tensor, and returns the output of a final 1x1 conv with
    `out_dim` channels.

    NOTE(review): this class does not currently construct — the notebook
    records ``TypeError: __init__() got multiple values for argument
    'drop_path'`` when it is instantiated. Several other inconsistencies
    visible in this cell are flagged inline with NOTE(review) comments.
    """
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        self_condition = False,
        # resnet_block_groups = 8,
        learned_variance = False,
        learned_sinusoidal_cond = False,
        random_fourier_features = False,
        learned_sinusoidal_dim = 16,
        drop_path=0.0,
        layer_scale_init_value=1e-6
    ):
        """Build the layers.

        Args:
            dim: base width; per-level widths are ``dim * m`` for each ``m``
                in ``dim_mults`` and the time-embedding width is ``dim * 4``.
            init_dim: output channels of the first conv; falls back to ``dim``
                via ``default`` when None.
            out_dim: channels of the final conv; falls back to ``channels``
                (doubled when ``learned_variance`` is set).
            dim_mults: multipliers applied to ``dim`` to get per-level widths.
            channels: number of input channels.
            self_condition: when True the first conv expects twice as many
                input channels (see ``forward``).
            learned_variance: when True the default ``out_dim`` is doubled.
            learned_sinusoidal_cond / random_fourier_features: if either is
                set, the time embedding uses RandomOrLearnedSinusoidalPosEmb
                (frozen when ``random_fourier_features`` is True) instead of
                SinusoidalPosEmb.
            learned_sinusoidal_dim: ``dim`` passed to that learned/random
                embedding.
            drop_path, layer_scale_init_value: forwarded to ConvNextBlock as
                keyword arguments through ``partial``.
        """
        super().__init__()

        # determine dimensions

        self.channels = channels
        self.self_condition = self_condition
        # Self-conditioning concatenates a second tensor on the channel axis.
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        # 7x7 conv; padding 3 keeps the spatial size unchanged.
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

        # Channel widths per level, and consecutive (in, out) pairs of them.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # block_klass = partial(ResnetBlock, groups = resnet_block_groups)
        # NOTE(review): `drop_path` is bound by keyword here, yet every call
        # below passes two positional arguments plus `time_emb_dim`. The
        # recorded TypeError suggests ConvNextBlock's second positional
        # parameter is `drop_path` — confirm against sinfusion.convnext.
        block_klass = partial(ConvNextBlock, drop_path=drop_path, layer_scale_init_value=layer_scale_init_value)

        # time embeddings

        time_dim = dim * 4

        self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features

        if self.random_or_learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            # +1 because that embedding prepends the raw input to its features.
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim

        # embedding -> Linear -> GELU -> Linear, producing `time_dim` features.
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            # Currently unused: the only consumer is in the commented-out code.
            is_last = ind >= (num_resolutions - 1)

            # self.downs.append(nn.ModuleList([
            #     block_klass(dim_in, dim_in, time_emb_dim = time_dim),
            #     block_klass(dim_in, dim_in, time_emb_dim = time_dim),
            #     Residual(PreNorm(dim_in, LinearAttention(dim_in))),
            #     Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            # ]))
            # NOTE(review): each entry holds a single block, but `forward`
            # unpacks two (`block1, block2`) per entry. `dim_out` is unused.
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        # NOTE(review): Residual, PreNorm and Attention are neither defined nor
        # imported in the visible part of this notebook — confirm their source.
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        # NOTE(review): this loop only computes `is_last`; nothing is appended,
        # so `self.ups` stays empty and the up path in `forward` never runs.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            # self.ups.append(nn.ModuleList([
            #     block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
            #     block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
            #     Residual(PreNorm(dim_out, LinearAttention(dim_out))),
            #     Upsample(dim_out, dim_in) if not is_last else  nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            # ]))

        # Twice the channels when the variance is learned as well.
        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        # Called in `forward` on the concatenation of the features with the
        # saved `init_conv` output, hence `dim * 2` as the first argument.
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        # 1x1 conv mapping `dim` channels to `out_dim` channels.
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond = None):
        """Run the network.

        Args:
            x: input batch; channels are on dim 1 (it is fed to a Conv2d and
                concatenated along dim 1).
            time: tensor passed to ``self.time_mlp``.
            x_self_cond: optional tensor concatenated with ``x`` on the
                channel axis; only used when ``self.self_condition`` is True.

        Returns:
            The output of ``self.final_conv``.
        """
        if self.self_condition:
            # presumably `default` calls the lambda when x_self_cond is None,
            # substituting zeros shaped like x — verify in sinfusion.utils.
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim = 1)

        x = self.init_conv(x)
        # Copy kept for the final skip concatenation below.
        r = x.clone()

        t = self.time_mlp(time)

        # Stack of skip activations: pushed on the way down, popped on the way up.
        h = []

        for block1, block2 in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            h.append(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2 in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)

        # Skip connection from the very first conv's output.
        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)

    
# Instantiate the notebook-defined Unet and display its module repr.
# NOTE(review): as this cell's recorded output shows, construction currently
# fails with "TypeError: __init__() got multiple values for argument 'drop_path'".
net = Unet(64, dim_mults=(1, 2, 4, 8))
net

TypeError: __init__() got multiple values for argument 'drop_path'

In [None]:
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import Unet

import torch


# `Unet` here is the class imported at the top of this cell from
# denoising_diffusion_pytorch; that import rebinds the name previously used by
# the notebook-defined Unet, and `net` is rebound as well.
net = Unet(64)
# torchinfo passes extra keyword arguments through to the model's forward, so
# `time` is supplied as a one-element tensor matching the batch size of 1.
summary(net, input_size=(1, 3, 32, 32), time=torch.randn(1))