In [1]:
import torch
import torch.nn as nn
import sys
import os

# Local diffusers checkout. This path is environment-specific, so it can be
# overridden through the DIFFUSERS_SRC environment variable instead of editing
# the notebook.
DIFFUSERS_SRC = os.environ.get('DIFFUSERS_SRC', '/home/aiteam/tykim/cubox/diffusers/src')
if DIFFUSERS_SRC not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(DIFFUSERS_SRC)

# Resolve the repo root (two levels above the directory the kernel started in)
# exactly once. A bare os.chdir('../../') climbs two more levels on every re-run
# of this cell, which would break the relative `src.diffusers...` imports that follow.
try:
    REPO_ROOT
except NameError:
    REPO_ROOT = os.path.abspath('../../')
os.chdir(REPO_ROOT)

In [2]:
from src.diffusers.models.unet_2d_blocks import *

In [38]:
from src.diffusers.models.activations import get_activation

In [42]:
# Explicit imports: neither name is imported by any earlier visible cell, so this
# cell previously only worked through the star import or leftover kernel state.
from dataclasses import dataclass

from src.diffusers.utils import BaseOutput


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Output container for a decoder forward pass (mirrors diffusers' `UNet2DConditionOutput`).

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Output of the last layer of the model.

    NOTE(review): `AuxDecoder.forward` currently returns the bare tensor rather than this wrapper.
    """

    sample: torch.FloatTensor

In [49]:
class AuxDecoder(nn.Module):
    """Auxiliary decoder that maps a deep feature map back to image space.

    The decoder stacks one up block per entry of ``up_block_types``, each built by the
    project's ``get_up_block``. It then applies an optional GroupNorm + activation and a
    final convolution that produces ``out_channels`` channels.

    Channel plan: ``block_out_channels`` is walked in reverse. Block ``i`` consumes
    ``reversed(block_out_channels)[i]`` channels and emits the next entry. The input
    to ``forward`` is therefore fed to a block built for ``block_out_channels[-1]``
    channels (512 with the defaults). The last up block emits ``block_out_channels[0]``
    channels, which is what ``conv_norm_out`` / ``conv_out`` consume.

    Upsampling: a block skips its upsampler only when ``i == len(block_out_channels) - 1``.
    With the defaults (3 up blocks, 4 channel levels) that index is never reached, so
    every block upsamples. Presumably each block upsamples 2x, matching the 8x growth
    observed in the demo (32x32 -> 256x256). Passing as many up blocks as channel levels
    makes the last block non-upsampling, as in diffusers' own decoders.

    Args:
        up_block_types: Up-block names resolved by ``get_up_block``.
        block_out_channels: Channel widths from shallowest to deepest.
        layers_per_block: Each up block is built with ``layers_per_block + 1`` layers.
        norm_eps: Epsilon for the up blocks' norms and for the output GroupNorm.
        act_fn: Activation name, used in the up blocks and after the output norm.
        norm_num_groups: GroupNorm group count; ``None`` disables the output norm and activation.
        attention_head_dim: Attention head dim for the up blocks; ``None`` falls back to the
            block's output channel count.
        out_channels: Channels produced by the final convolution.
        conv_out_kernel: Kernel size of the final convolution. Padding keeps the spatial
            size unchanged for odd kernel sizes.

    Raises:
        ValueError: If ``block_out_channels`` is empty, or if the number of up blocks is not
            ``len(block_out_channels) - 1`` or ``len(block_out_channels)``.
    """

    def __init__(self,
                 up_block_types=("UpDecoderBlockTimeless2D", "AttnUpDecoderBlockTimeless2D", "UpDecoderBlockTimeless2D"),
                 block_out_channels=(64, 128, 256, 512),
                 layers_per_block: int = 2,
                 norm_eps: float = 1e-5,
                 act_fn: str = "silu",
                 norm_num_groups: Optional[int] = 32,
                 attention_head_dim: Optional[int] = 8,
                 out_channels: int = 3,
                 conv_out_kernel: int = 3,
                 ):
        super().__init__()
        num_levels = len(block_out_channels)
        num_up_blocks = len(up_block_types)
        # Index reversed_block_out_channels[i] must stay in range, and the last up block
        # must emit block_out_channels[0] channels for the output head to match.
        if num_levels == 0 or not (num_levels - 1 <= num_up_blocks <= num_levels):
            raise ValueError(
                "AuxDecoder needs len(block_out_channels) - 1 <= len(up_block_types) <= "
                f"len(block_out_channels); got {num_up_blocks} up blocks for {num_levels} channel levels"
            )

        reversed_block_out_channels = list(reversed(block_out_channels))  # deepest first, e.g. [512, 256, 128, 64]
        self.up_blocks = nn.ModuleList([])
        for i, up_block_type in enumerate(up_block_types):
            input_channel = reversed_block_out_channels[i]
            output_channel = reversed_block_out_channels[min(i + 1, num_levels - 1)]

            # Only reachable when there are as many up blocks as channel levels; with the
            # defaults every block keeps its upsampler (see class docstring).
            is_final_block = i == num_levels - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                temb_channels=None,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
            )
            self.up_blocks.append(up_block)

        # Output head: optional GroupNorm + activation, then projection to out_channels.
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = get_activation(act_fn)
        else:
            self.conv_norm_out = None
            self.conv_act = None

        # "Same" padding for odd kernels, so the head does not change spatial size.
        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

    def forward(self,
                sample: torch.FloatTensor,
                return_dict: bool = True):
        """Decode ``sample`` through the up blocks and the output head.

        Args:
            sample: Feature map fed to the first up block. Presumably of shape
                ``(batch, block_out_channels[-1], H, W)`` — TODO confirm against the up blocks.
            return_dict: Kept for diffusers-style call compatibility. NOTE: ``True`` returns
                the bare tensor (the ``UNet2DConditionOutput`` wrapper is not used);
                ``False`` returns a 1-tuple ``(sample,)``.

        Returns:
            The decoded tensor with ``out_channels`` channels, or a 1-tuple holding it.
        """
        for upsample_block in self.up_blocks:
            sample = upsample_block(sample)
        # Explicit None check: the head is disabled when norm_num_groups is None.
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        # Bare tensor on purpose: wrapping it in UNet2DConditionOutput would change what
        # existing callers (e.g. `decoder(x).shape`) receive.
        return sample

In [50]:
# Seed so that weight initialisation (and any later torch.randn probe) is reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
ad = AuxDecoder()

In [52]:
# Shape probe only: no_grad avoids building an autograd graph for this throwaway pass.
with torch.no_grad():
    probe_out = ad(torch.randn(1, 512, 32, 32))
# With the default config all three up blocks upsample: 32x32 -> 256x256 (8x), 3 output channels.
assert probe_out.shape == (1, 3, 256, 256), probe_out.shape
probe_out.shape

torch.Size([1, 3, 256, 256])