In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from inspect import isfunction
from functools import partial
from torch import einsum
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

In [2]:
class ResBlock(nn.Module):
    """Residual block with two GroupNorm -> SiLU -> 3x3 conv stages and optional
    time-embedding conditioning.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map. Defaults to
            ``in_channels`` when ``None``.
        time_emb_dim: Size of the time embedding vector. ``None`` builds a block
            without time conditioning.
        groups: Number of groups for both GroupNorm layers (must divide both
            ``in_channels`` and ``out_channels``).
        eps: Epsilon for both GroupNorm layers.
        dropout: Dropout probability applied before the second convolution.
        time_embedding_norm: ``"default"`` adds the projected embedding to the
            features before the second norm; ``"scale_shift"`` projects the
            embedding to a (scale, shift) pair applied after the second norm.
            Any other value disables time conditioning.
    """

    def __init__(
        self, 
        *, 
        in_channels, 
        out_channels=None, 
        time_emb_dim=None, 
        groups=8, 
        eps=1e-6, 
        dropout = 0.0,
        time_embedding_norm="scale_shift"
    ):
        super().__init__()

        # Without this, Conv2d/GroupNorm would be constructed with None.
        if out_channels is None:
            out_channels = in_channels

        self.time_embedding_norm = time_embedding_norm
        # A 1x1 projection on the skip path is only needed when the channel
        # count changes; otherwise the input is added back unchanged.
        self.use_shortcut = in_channels != out_channels

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1)

        # Always define the attribute so forward() can test it, even for an
        # unrecognised time_embedding_norm value.
        self.time_emb_proj = None
        if time_emb_dim is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(time_emb_dim, out_channels)
                )
            elif self.time_embedding_norm == "scale_shift":
                # Twice the channels: one half is the scale, the other the shift.
                self.time_emb_proj = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(time_emb_dim, out_channels * 2)
                )

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1)
        self.act = nn.SiLU()

        self.conv_shortcut = None
        if self.use_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                padding=0,
                stride=1,
                bias=True
            )

    def forward(self, x, temb=None):
        """
        x : [B, C, H, W]
        temb : [B, time_emb_dim]

        Returns a tensor of shape [B, out_channels, H, W].

        ``temb`` is required when the block was built with ``time_emb_dim``
        (ValueError otherwise) and is ignored when it was not.
        """
        hidden_states = self.norm1(x)
        hidden_states = self.act(hidden_states)
        hidden_states = self.conv1(hidden_states)

        time_scale = time_shift = None
        if self.time_emb_proj is not None:
            if temb is None:
                raise ValueError(
                    "temb must be provided when the block is built with time_emb_dim"
                )
            # [B, D] -> [B, D, 1, 1] so it broadcasts over the spatial dims.
            temb = self.time_emb_proj(temb)[:, :, None, None]
            if self.time_embedding_norm == "default":
                hidden_states = hidden_states + temb
            elif self.time_embedding_norm == "scale_shift":
                time_scale, time_shift = torch.chunk(temb, 2, dim=1)

        # The second norm is applied in every mode; scale/shift modulates the
        # normalised features rather than replacing the normalisation.
        hidden_states = self.norm2(hidden_states)
        if time_scale is not None:
            hidden_states = (1 + time_scale) * hidden_states + time_shift

        hidden_states = self.act(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x.contiguous())

        return x + hidden_states

In [None]:
class DownSample(nn.Module):
    """Halve the spatial resolution of a [B, C, H, W] feature map.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output. Defaults to ``in_channels``.
        down_sample: Strategy selector.
            ``"full"``    - fold each 2x2 patch into the channel axis, then mix
                            with a 1x1 convolution (H and W must be even).
            ``"padding"`` - 3x3 convolution with stride 2 and padding 1.
            anything else - 2x2 average pooling; it has no weights, so
                            ``out_channels`` must equal ``in_channels``.

    Raises:
        ValueError: if average pooling is selected with a channel change.
    """

    def __init__(self, in_channels, out_channels=None, down_sample="full"):
        super().__init__()

        if out_channels is None:
            out_channels = in_channels

        if down_sample == 'full':
            # No More Strided Convolutions or Pooling
            self.conv = nn.Sequential(
                Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
                nn.Conv2d(in_channels * 4, out_channels, kernel_size=1)
            )
        elif down_sample == 'padding':
            # stride=2 is what actually reduces the resolution; a stride-1
            # 3x3 conv with padding=1 would leave H and W unchanged.
            self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)
        else:
            if out_channels != in_channels:
                raise ValueError(
                    "average-pool downsampling cannot change the channel count: "
                    f"in_channels={in_channels}, out_channels={out_channels}"
                )
            self.conv = nn.AvgPool2d(kernel_size=2, stride=2)
    
    def forward(self, x, output_size=None):
        """Apply the selected downsampling op. ``output_size`` is accepted for
        signature parity with ``UpSample.forward`` and is not used."""
        return self.conv(x)

class UpSample(nn.Module):
    """Nearest-neighbour upsampling followed by a 3x3 convolution.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output. Defaults to ``in_channels``.
        interpolate: When True the resize happens in ``forward`` through
            ``F.interpolate``, which lets the caller request an exact
            ``output_size``. When False a fixed x2 ``nn.Upsample`` is baked
            into the module and ``output_size`` is ignored.
    """

    def __init__(self, in_channels, out_channels=None, interpolate=False):
        super().__init__()

        # The signature advertises None as the default, but Conv2d cannot be
        # built with out_channels=None.
        if out_channels is None:
            out_channels = in_channels

        self.interpolate = interpolate

        if interpolate:
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=True)
        else:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=True)
            )

    def forward(self, x, output_size=None):
        """Upsample ``x`` ([B, C, H, W]) and convolve.

        ``output_size`` (target (H, W)) is only honoured when the module was
        built with ``interpolate=True``; otherwise the output is always 2x.
        """
        if self.interpolate:
            if output_size is None:
                x = F.interpolate(x, scale_factor=2, mode="nearest")
            else:
                x = F.interpolate(x, size=output_size, mode="nearest")
        return self.conv(x)


In [8]:
# Scratch check: unpacking a tensor's .shape yields plain Python ints.
i = torch.randn(4, 5, 6)
x,y,z = i.shape
print(type(x))

<class 'int'>


In [10]:
# Scratch: a random 3x3 integer tensor with entries drawn from [0, 10).
a = torch.randint(0, 10, (3, 3))
a

tensor([[8, 8, 5],
        [2, 3, 8],
        [6, 1, 5]])

In [11]:
# Row-wise maximum; keepdim=True keeps the reduced axis with size 1 (3x3 -> 3x1).
a.amax(dim=-1, keepdim=True)

tensor([[8],
        [8],
        [6]])

In [None]:
class SelfAttention(nn.Module):
    """Multi-head softmax self-attention across the spatial positions of a
    [B, C, H, W] feature map. Output has the same shape as the input."""

    def __init__(self, in_channels, n_heads, dim_head):
        super().__init__()

        hidden_dim = dim_head * n_heads
        self.n_heads = n_heads
        # Standard 1/sqrt(d) scaling of the attention logits.
        self.scale = dim_head ** -0.5
        # One 1x1 conv produces queries, keys and values stacked on the channel axis.
        self.qkv = nn.Conv2d(in_channels, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, in_channels, 1)

    def forward(self, x):
        _, _, height, width = x.shape

        # Split the stacked projection and flatten H*W into one sequence axis:
        # each of query/key/value is [B, heads, dim_head, H*W].
        query, key, value = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.n_heads)
            for part in self.qkv(x).chunk(3, dim=1)
        )

        query = query * self.scale

        logits = einsum("b h d i, b h d j -> b h i j", query, key)
        # Subtract the per-row max (treated as a constant) before the softmax.
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)

        mixed = einsum("b h i j, b h d j -> b h i d", weights, value)
        mixed = rearrange(mixed, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(mixed)


In [13]:
# Scratch: softmax over the last axis of a 2x5 view normalises each row independently.
a = torch.randn(10)
a.view(2, 5).softmax(-1)

tensor([[0.0279, 0.0484, 0.7742, 0.0836, 0.0659],
        [0.3463, 0.4633, 0.0170, 0.1600, 0.0134]])

In [14]:
# IPython help lookup (trailing `?`): prints the GroupNorm constructor signature.
nn.GroupNorm.__init__?

[0;31mSignature:[0m
[0mnn[0m[0;34m.[0m[0mGroupNorm[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mself[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_groups[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_channels[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meps[0m[0;34m:[0m [0mfloat[0m [0;34m=[0m [0;36m1e-05[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0maffine[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdevice[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdtype[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m Initialize internal Module state, shared by both nn.Module and ScriptModule.
[0;31mFile:[0m      /opt/conda/lib/python3.10/site-packages/torch/nn/modules/normalization.py
[0;31mType:[0m      fun

In [None]:
class LinearSelfAttention(nn.Module):
    """Linear-complexity multi-head attention over the spatial positions of a
    [B, C, H, W] feature map.

    Instead of forming the (H*W) x (H*W) attention matrix, keys and values are
    first contracted into a per-head ``dim_head x dim_head`` context, which is
    then applied to the queries.

    Args:
        in_channels: Channels of the input (and output) feature map.
        n_heads: Number of attention heads.
        dim_head: Channels per head.
    """

    def __init__(self, in_channels, n_heads, dim_head):
        super().__init__()

        self.n_heads = n_heads
        hidden_dim = dim_head * n_heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Conv2d(in_channels, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, in_channels, 1),
            nn.GroupNorm(1, in_channels)
        )

    def forward(self, x):
        """Return the attended feature map, same shape as ``x`` ([B, C, H, W])."""
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # Each of q, k, v: [B, heads, dim_head, H*W].
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.n_heads), qkv
        )

        # Queries are normalised over the feature axis, keys over the positions.
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        # Contract over positions: [B, heads, d, n] x [B, heads, e, n] -> [B, heads, d, e].
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        # Apply the context to every query position: -> [B, heads, e, n].
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.n_heads, x=h, y=w)
        out = self.to_out(out)

        return out

In [None]:
class SelfAttnDownBlocak(nn.Module):
    """Encoder stage: ``num_layers`` (ResBlock -> SelfAttention) pairs followed
    by a single DownSample.

    Args:
        num_layers: Number of ResBlock/SelfAttention pairs.
        in_channels: Channels of the input feature map.
        time_emb_dim: Size of the time embedding passed to every ResBlock.
        out_channels: Channels produced by the stage. Defaults to ``in_channels``.
        down_sample: Strategy forwarded to ``DownSample``.
    """

    def __init__(
        self,
        num_layers,
        in_channels, 
        time_emb_dim,
        out_channels=None, 
        down_sample="full"
    ):
        super().__init__()

        if out_channels is None:
            out_channels = in_channels

        resnets = []
        attentions = []
        for i in range(num_layers):
            resnets.append(
                ResBlock(
                    # Only the first block sees the stage input; later blocks
                    # consume the previous block's out_channels.
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    time_emb_dim=time_emb_dim
                )
            )
            attentions.append(
                SelfAttention(
                    out_channels,
                    n_heads=4,
                    dim_head=32
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)
        self.downsample = DownSample(out_channels, out_channels, down_sample)

    def forward(self, x, temb):
        """Run all ResBlock/attention pairs on ``x`` with time embedding ``temb``
        and return the downsampled feature map."""
        hidden_states = x
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states)

        return self.downsample(hidden_states)


# Correctly spelled alias; the original (misspelled) name is kept for compatibility.
SelfAttnDownBlock = SelfAttnDownBlocak

