In [1]:
import math
from dataclasses import dataclass
from numbers import Number
from typing import Any, NamedTuple, Optional, Tuple, Union

import numpy as np
import torch as th
from torch import nn

from config_base import BaseConfig

In [5]:
import torch

# Unpack the four dimensions of a dummy 1x3x32x32 tensor into separate ints
# and display the first one.
a, b, c, d = tuple(torch.randn(1, 3, 32, 32).shape)
a

1

In [2]:
from nn import (conv_nd, linear, normalization, timestep_embedding,
                 torch_checkpoint, zero_module)

from blocks import *

In [3]:
@dataclass
class BeatGANsUNetConfig(BaseConfig):
    """Hyper-parameters of a BeatGANs-style UNet.

    Several fields use ``None`` (or ``-1``) to mean "fall back to another
    field". Those fallbacks are resolved by the code that consumes the config
    (e.g. ``BeatGANsEncoder.__init__``), not by this dataclass itself.
    """

    image_size: int = 64
    in_channels: int = 3
    # base channels, multiplied per resolution level by ``channel_mult``
    model_channels: int = 64
    # output channels of the unet
    # suggest: 3
    # you only need 6 if you also model the variance of the noise prediction
    # (usually we use an analytical variance, hence 3)
    out_channels: int = 3
    # how many repeating resblocks per resolution
    # the decoding side would have "one more" resblock
    # default: 2
    num_res_blocks: int = 2
    # number of resblocks per resolution for the input (encoding) blocks only
    # default: None = fall back to ``num_res_blocks``
    num_input_res_blocks: Optional[int] = None
    # number of time embed channels and style channels
    embed_channels: int = 512
    # at what resolutions you want to do self-attention of the feature maps
    # attentions generally improve performance
    # default: [16]
    # beatgans: [32, 16, 8]
    attention_resolutions: Tuple[int, ...] = (16, )
    # input width of the time-embedding MLP
    # default: None = fall back to ``model_channels``
    time_embed_channels: Optional[int] = None
    # dropout applies to the resblocks (on feature maps)
    dropout: float = 0.1
    # channel multiplier per resolution level; its length sets the number of levels
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)
    # per-level multipliers for the input blocks
    # default: None = use ``channel_mult``
    input_channel_mult: Optional[Tuple[int, ...]] = None
    # passed to ``Downsample`` when ``resblock_updown`` is False
    # (presumably: conv-based rather than pooling resampling — confirm)
    conv_resample: bool = True
    # always 2 = 2d conv
    dims: int = 2
    # don't use this, legacy from BeatGANs
    num_classes: Optional[int] = None
    use_checkpoint: bool = False
    # number of attention heads
    num_heads: int = 1
    # or specify the number of channels per attention head
    num_head_channels: int = -1
    # number of attention heads presumably used on the upsampling path;
    # -1 = use ``num_heads``
    num_heads_upsample: int = -1
    # use resblock for upscale/downscale blocks (expensive);
    # when False a plain ``Downsample`` layer is used instead
    # default: True (BeatGANs)
    resblock_updown: bool = True
    # never tried
    use_new_attention_order: bool = False
    # forwarded to the resblocks as ``two_cond``
    resnet_two_cond: bool = False
    # forwarded to the resblocks as ``cond_emb_channels`` (style channels)
    resnet_cond_channels: Optional[int] = None
    # init the decoding conv layers with zero weights, this speeds up training
    # default: True (BeatGANs)
    resnet_use_zero_module: bool = True
    # gradient checkpoint the attention operation
    attn_checkpoint: bool = False

    def make_model(self):
        """Build the UNet described by this config.

        NOTE(review): ``BeatGANsUNetModel`` is not defined in this notebook; it
        presumably comes from the ``from blocks import *`` wildcard — confirm.
        """
        return BeatGANsUNetModel(self)

In [4]:

class BeatGANsEncoder(nn.Module):
    """Encoder half of a BeatGANs UNet that returns every intermediate feature map.

    The network is a 3x3 stem convolution followed, for each resolution level,
    by ``num_input_res_blocks`` (or ``num_res_blocks``) residual blocks and, on
    every level except the last, one downsampling block. ``forward`` returns the
    output of each of these blocks, in order.

    Args:
        conf: UNet hyper-parameters (see ``BeatGANsUNetConfig``).
    """

    def __init__(self, conf: BeatGANsUNetConfig):
        super().__init__()
        self.conf = conf

        # -1 means "same as num_heads"; otherwise keep the explicit value so the
        # attribute always exists (previously it was only set for -1).
        if conf.num_heads_upsample == -1:
            self.num_heads_upsample = conf.num_heads
        else:
            self.num_heads_upsample = conf.num_heads_upsample

        self.dtype = th.float32

        # The time/label embeddings are built (so their parameters exist in the
        # state dict), but ``forward`` below does not apply them.
        self.time_emb_channels = conf.time_embed_channels or conf.model_channels
        self.time_embed = nn.Sequential(
            linear(self.time_emb_channels, conf.embed_channels),
            nn.SiLU(),
            linear(conf.embed_channels, conf.embed_channels),
        )

        if conf.num_classes is not None:
            self.label_emb = nn.Embedding(conf.num_classes,
                                          conf.embed_channels)

        num_levels = len(conf.channel_mult)
        ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])

        # channel count produced by each entry of ``input_blocks``
        self._feature_size = [ch]

        # number of input blocks per resolution level (the stem counts toward level 0)
        self.input_num_blocks = [0] * num_levels
        self.input_num_blocks[0] = 1
        # kept for interface parity with the full UNet; never populated here
        self.output_num_blocks = [0] * num_levels

        # Note: unlike the full UNet, no AttentionBlock is inserted at
        # ``conf.attention_resolutions``; attention is disabled in this encoder.
        blocks_per_level = conf.num_input_res_blocks or conf.num_res_blocks
        for level, mult in enumerate(conf.input_channel_mult
                                     or conf.channel_mult):
            out_ch = int(mult * conf.model_channels)
            for _ in range(blocks_per_level):
                self.input_blocks.append(
                    TimestepEmbedSequential(self._make_res_block(ch, out_ch)))
                ch = out_ch
                self._feature_size.append(ch)
                self.input_num_blocks[level] += 1
            if level != num_levels - 1:
                # The downsampling block is counted toward the next (coarser) level.
                self.input_blocks.append(
                    TimestepEmbedSequential(self._make_downsample(ch)))
                self.input_num_blocks[level + 1] += 1
                self._feature_size.append(ch)

    def _make_res_block(self, in_ch: int, out_ch: int, down: bool = False):
        """Build one conditioned residual block mapping ``in_ch`` -> ``out_ch`` channels."""
        conf = self.conf
        # Only pass ``down`` when requested, exactly as the original construction did.
        extra = {'down': True} if down else {}
        return ResBlockConfig(
            in_ch,
            conf.embed_channels,
            conf.dropout,
            out_channels=out_ch,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint,
            use_condition=True,
            two_cond=conf.resnet_two_cond,
            use_zero_module=conf.resnet_use_zero_module,
            # style channels for the resnet block
            cond_emb_channels=conf.resnet_cond_channels,
            **extra,
        ).make_model()

    def _make_downsample(self, ch: int):
        """Build the block that halves spatial resolution while keeping ``ch`` channels."""
        conf = self.conf
        if conf.resblock_updown:
            return self._make_res_block(ch, ch, down=True)
        return Downsample(ch, conf.conv_resample, dims=conf.dims, out_channels=ch)

    def forward(self, x, t=None, y=None, **kwargs):
        """
        Run the encoder and collect the output of every input block.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: unused; accepted for signature compatibility with the UNet
            (the time embedding is not applied — every block gets ``emb=None``).
        :param y: unused; class conditioning is not implemented.
        :return: a list with one Tensor per entry of ``self.input_blocks``,
            ordered from the stem convolution to the last (coarsest) block.
        :raises NotImplementedError: if the config is class-conditional.
        """
        if self.conf.num_classes is not None:
            raise NotImplementedError()

        h = x.type(self.dtype)
        results = []
        for block in self.input_blocks:
            h = block(h, emb=None)
            results.append(h)
        return results

In [6]:
@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    """UNet config extended with settings for a separate (style) encoder."""

    # number of style channels
    enc_out_channels: int = 512
    enc_attn_resolutions: Optional[Tuple[int, ...]] = None
    # name of the encoder's pooling head; values used in this notebook:
    # 'depthconv' (default) and 'adaptivenonzero'
    enc_pool: str = 'depthconv'
    enc_num_res_block: int = 2
    enc_channel_mult: Optional[Tuple[int, ...]] = None
    enc_grad_checkpoint: bool = False
    # Must be annotated: without a type annotation this was a plain class
    # attribute rather than a dataclass field, so it could not be passed to
    # the constructor. Presumably holds a latent-network config object — the
    # concrete type is not visible here, hence ``Any``.
    latent_net_conf: Any = None


def get_model_conf():
    """Return the 256x256 autoencoder configuration used in this notebook.

    UNet settings and encoder (``enc_*``) settings are assembled separately
    and merged into a single ``BeatGANsAutoencConfig``; ``latent_net_conf``
    is left at its default.
    """
    unet_kwargs = dict(
        image_size=256,
        in_channels=3,
        model_channels=128,
        out_channels=3 * 2,  # also learns sigma
        num_res_blocks=2,
        num_input_res_blocks=None,
        embed_channels=512,
        attention_resolutions=(32, 16, 8),
        time_embed_channels=None,
        dropout=0.1,
        channel_mult=(1, 1, 2, 2, 4, 4),
        input_channel_mult=None,
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        resblock_updown=True,
        use_new_attention_order=False,
        resnet_two_cond=True,
        resnet_cond_channels=None,
        resnet_use_zero_module=True,
        attn_checkpoint=False,
    )
    encoder_kwargs = dict(
        enc_out_channels=512,
        enc_attn_resolutions=None,
        enc_pool='adaptivenonzero',
        enc_num_res_block=2,
        enc_channel_mult=(1, 1, 2, 2, 4, 4, 4),
        enc_grad_checkpoint=False,
    )
    return BeatGANsAutoencConfig(**unet_kwargs, **encoder_kwargs)

In [7]:
# Inspect the default ``latent_net_conf`` of a fresh config; the default is None,
# so this cell displays nothing.
BeatGANsAutoencConfig().latent_net_conf

In [8]:
# Build the 256x256 autoencoder config returned by ``get_model_conf``.
cfg = get_model_conf()

In [9]:
# Instantiate the feature-map encoder from the config above.
be = BeatGANsEncoder(conf=cfg)

In [10]:
import torch

In [11]:
# Run a dummy 512x512 batch through the encoder to inspect its feature-map shapes.
# NOTE: cfg.image_size is 256, but the encoder is fully convolutional and accepts
# this larger input (see the shapes printed below).
# no_grad: inspection only, so skip building the autograd graph (saves memory).
with torch.no_grad():
    a = be(torch.randn(1, 3, 512, 512))

In [12]:
# List every returned feature map with its position and shape.
for idx, feat in enumerate(a):
    print(idx, feat.shape)

0 torch.Size([1, 128, 512, 512])
1 torch.Size([1, 128, 512, 512])
2 torch.Size([1, 128, 512, 512])
3 torch.Size([1, 128, 256, 256])
4 torch.Size([1, 128, 256, 256])
5 torch.Size([1, 128, 256, 256])
6 torch.Size([1, 128, 128, 128])
7 torch.Size([1, 256, 128, 128])
8 torch.Size([1, 256, 128, 128])
9 torch.Size([1, 256, 64, 64])
10 torch.Size([1, 256, 64, 64])
11 torch.Size([1, 256, 64, 64])
12 torch.Size([1, 256, 32, 32])
13 torch.Size([1, 512, 32, 32])
14 torch.Size([1, 512, 32, 32])
15 torch.Size([1, 512, 16, 16])
16 torch.Size([1, 512, 16, 16])
17 torch.Size([1, 512, 16, 16])


In [13]:
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
import os
# Move to the repository root (two levels up) so the ``src.`` imports below resolve.
# Guarded so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir('src'):
    os.chdir('../../')

In [14]:
from src.diffusers.models.unet_2d_base import UNet2DBaseModel

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
# Instantiate the diffusers-based UNet with its default constructor arguments.
u = UNet2DBaseModel()

In [16]:
# Smoke-test one forward pass on a random 1x3x256x256 batch. The second argument (3)
# is presumably the diffusion timestep — confirm against UNet2DBaseModel.forward.
# no_grad: inference only, so skip building the autograd graph.
with torch.no_grad():
    unet_out = u(torch.randn(1, 3, 256, 256), 3)
unet_out