In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pytorch_wavelets import DWTInverse, DWTForward

from math import floor, ceil
from einops import rearrange
from typing import Optional, List, Tuple

from PIL import Image
import lovely_tensors as lt
lt.monkey_patch()
img = Image.open('/home/aiteam/tykim/generative/gan/2291_04.png')
from torchvision.transforms import ToTensor,ToPILImage

In [2]:
def conv_bn(in_channels, out_channels, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv3x3', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                                        kernel_size=3, stride=stride, padding=padding, groups=groups,bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    result.add_module('conv1x1', nn.Conv2d(in_channels=out_channels, out_channels=out_channels, 
                                        kernel_size=1, stride=stride, padding=padding, groups=groups,bias=False))

    return result

In [3]:
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

In [4]:
test = conv_bn(3, 10, 1, 1)

In [5]:
test(torch.randn(1, 3, 32, 32))

tensor[1, 10, 34, 34] n=11560 x∈[-2.491, 2.663] μ=-1.196e-09 σ=0.577 grad ConvolutionBackward0

In [6]:
class RepBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, 
                 inference_mode=False, stride=1, padding=0, dilation=1,
                 groups=1, num_conv_branches=1, zero_params=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inference_mode = inference_mode
        self.stride = stride
        self.groups = groups
        self.num_conv_branches = num_conv_branches

        
        self.se = nn.Identity()
        self.activation = nn.ReLU()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(in_channels=in_channels,
                                          out_channels=out_channels,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          padding=padding,
                                          dilation=dilation,
                                          groups=groups,
                                          bias=True)
        else:
            # Re-parameterizable skip connection
            self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) \
                if out_channels == in_channels and stride == 1 else None

            # Re-parameterizable conv branches
            rbr_conv = list()
            for _ in range(self.num_conv_branches):
                rbr_conv.append(self._conv_bn(kernel_size=kernel_size,
                                              padding=padding))
            self.rbr_conv = nn.ModuleList(rbr_conv)

            # Re-parameterizable scale branch
            self.rbr_scale = None
            if kernel_size > 1:
                self.rbr_scale = self._conv_bn(kernel_size=1,
                                               padding=0)

            if zero_params:
                zero_module(self.rbr_skip) if self.rbr_skip is not None else ...
                zero_module(self.rbr_conv) 
                zero_module(self.rbr_scale) if self.rbr_scale is not None else ...
    def forward(self, x):
        # Inference mode forward pass.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale branch output
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Other branches
        out = scale_out + identity_out
        for ix in range(self.num_conv_branches):
            out += self.rbr_conv[ix](x)

        return self.activation(self.se(out))
    def _conv_bn(self,
                 kernel_size: int,
                 padding: int) -> nn.Sequential:
        """ Helper method to construct conv-batchnorm layers.

        :param kernel_size: Size of the convolution kernel.
        :param padding: Zero-padding size.
        :return: Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module('conv', nn.Conv2d(in_channels=self.in_channels,
                                              out_channels=self.out_channels,
                                              kernel_size=kernel_size,
                                              stride=self.stride,
                                              padding=padding,
                                              groups=self.groups,
                                              bias=False))
        mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
    
    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to fuse batchnorm layer with preceeding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        :param branch:
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros((self.in_channels,
                                            input_dim,
                                            self.kernel_size,
                                            self.kernel_size),
                                           dtype=branch.weight.dtype,
                                           device=branch.weight.device)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim,
                                 self.kernel_size // 2,
                                 self.kernel_size // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

In [50]:
rep = RepBlock(in_channels=3, out_channels=16, kernel_size=3, inference_mode=False, padding=1)

In [51]:
rep(torch.randn(1, 3, 32, 32))

tensor[1, 16, 32, 32] n=16384 x∈[0., 5.577] μ=0.560 σ=0.820 grad ReluBackward0

In [52]:
rep2 = RepBlock(in_channels=3, out_channels=16, kernel_size=3, inference_mode=True, padding=1)

In [53]:
rep2(torch.randn(1, 3, 32, 32))

tensor[1, 16, 32, 32] n=16384 x∈[0., 2.215] μ=0.252 σ=0.346 grad ReluBackward0

In [54]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [56]:
count_parameters(rep)

544

In [55]:
count_parameters(rep2)

448

In [7]:
from pytorch_wavelets import DWTInverse, DWTForward

dwt = DWTForward(J=1, mode='zero', wave='db1')
idwt = DWTInverse(mode="zero", wave="db1")

def img_to_dwt(img):
    low, high = dwt(img)
    b, _, _, h, w = high[0].size()
    high = high[0].view(b, -1, h, w)
    freq = torch.cat([low, high], dim=1)
    return freq

def dwt_to_img(img):
    b, c, h, w = img.size()
    low = img[:, :3, :, :]
    high = img[:, 3:, :, :].view(b, 3, 3, h, w)
    return idwt((low, [high]))

In [8]:
class Downsample(nn.Module):
    def __init__(self, channels, out_channels=None, kernel=3, stride=None, padding=1): # use_conv,
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        # self.use_conv = use_conv
        stride = 2 if stride is None else stride

        # if use_conv:
        self.op = nn.Conv2d(channels, self.out_channels, kernel_size=kernel, stride=stride, padding=padding)
        # else:
        #     self.op = DWTForward(J=1, mode='zero', wave='db1')

    # def img_to_dwt(self, img):
    #     low, high = self.op(img)
    #     b, _, _, h, w = high[0].size()
    #     high = high[0].view(b, -1, h, w)
    #     freq = torch.cat([low, high], dim=1)
    #     return freq

    def forward(self, x):
        assert x.shape[1] == self.channels
        x = self.op(x)
        return x
        # assert x.shape[1] == self.channels
        # if self.use_conv:
        #     x = self.op(x)
        # else:
        #     x = self.img_to_dwt(x)
        # return x
            

class Upsample(nn.Module):
    def __init__(self, channels, out_channels=None, paddding=1): # use_conv,
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        # self.use_conv = use_conv
        # if use_conv:
        self.op = nn.Conv2d(self.channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        # else:
        #     self.op = DWTInverse(mode="zero", wave="db1")
            
    # def dwt_to_img(self, img):
    #     b, c, h, w = img.size()
    #     low = img[:, :3, :, :]
    #     high = img[:, 3:, :, :].view(b, 3, 3, h, w)
    #     return self.op((low, [high]))
    
    def forward(self, x):
        assert x.shape[1] == self.channels
        # if self.use_conv:
        #     x = F.interpolate(x, scale_factor=2, mode='bicubic')
        #     x = self.op(x)
        # else:
        #     x = self.dwt_to_img(x)
        x = F.interpolate(x, scale_factor=2, mode='bicubic')
        x = self.op(x)
        return x

In [178]:
test = DWTForward(J=1, mode='zero', wave='haar')

In [179]:
low, high = test(torch.randn(1, 13, 256, 256))

In [183]:
LH, HL, HH = tuple(rearrange(high[0] , 'b c n h w -> n b c h w ', n=3))

In [184]:
LH.shape

torch.Size([1, 13, 128, 128])

In [186]:
low.shape

torch.Size([1, 13, 128, 128])

In [189]:
high[0]

tensor[1, 13, 3, 128, 128] n=638976 x∈[-4.878, 4.881] μ=0.000 σ=1.000

In [192]:
torch.concat((low.view(1, 13, 1, 128, 128), high[0]), dim=2) * torch.randn(1, 4, 1, 1).view(1, -1, 4, 1, 1)

tensor[1, 13, 4, 128, 128] n=851968 x∈[-5.172, 4.949] μ=-0.001 σ=0.873

In [180]:
high[0]

tensor[1, 13, 3, 128, 128] n=638976 x∈[-4.878, 4.881] μ=0.000 σ=1.000

In [157]:
high[0].shape

torch.Size([1, 3, 3, 128, 128])

In [155]:
LH, HL, HH = tuple(rearrange(high[0] , 'b c n h w -> n b c h w ', n=3))

In [156]:
q.shape

torch.Size([1, 3, 128, 128])

In [118]:
b, _, _, h, w = high[0].size()
high = high[0].view(b, -1, h, w)
freq = torch.cat([low, high], dim=1)

In [119]:
freq.shape

torch.Size([1, 12, 128, 128])

In [120]:
op = DWTInverse(mode="zero", wave="haar")

In [123]:
freq.shape

torch.Size([1, 12, 128, 128])

In [121]:
def dwt_to_img(img):
    b, c, h, w = img.size()
    low = img[:, :3, :, :]
    high = img[:, 3:, :, :].view(b, 3, 3, h, w)
    return op((low, [high]))

In [122]:
dwt_to_img(freq)

tensor[1, 3, 256, 256] n=196608 x∈[-4.776, 4.699] μ=-0.003 σ=1.001

In [127]:
avgpool = nn.AvgPool2d(kernel_size=4)

In [129]:
avgpool(torch.randn(1, 3, 256, 256)).shape

torch.Size([1, 3, 64, 64])

In [126]:
avgpool = nn.AdaptiveAvgPool2d(2)
avgpool(torch.randn(1, 3, 128, 128)).view(1, 3)

RuntimeError: shape '[1, 3]' is invalid for input of size 12

In [9]:
class CostomAdaptiveAvgPool2D(nn.Module):
    
    def __init__(self, output_size):
        
        super().__init__()
        
        self.output_size = output_size
        
    def forward(self, x):
        
        H_in,  W_in  = x.shape[2:]
        H_out, W_out = [self.output_size, self.output_size] \
                       if isinstance(self.output_size, int) \
                       else self.output_size
        
        out_i = []
        for i in range(H_out):
            out_j = []
            for j in range(W_out):
                
                hs = int(floor(i * H_in / H_out))
                he = int(ceil((i+1) * H_in / H_out))
                
                ws = int(floor(j * W_in / W_out))
                we = int(ceil((j+1) * W_in / W_out))
                
                # print(hs, he, ws, we)
                kernel_size = [he-hs, we-ws]
                
                out = F.avg_pool2d(x[:, :, hs:he, ws:we], kernel_size) 
                out_j.append(out)
                
            out_j = torch.concat(out_j, -1)
            out_i.append(out_j)
            
        out_i = torch.concat(out_i, -2)
        return out_i

In [10]:
class WaveletGating(nn.Module):
    def __init__(self, channels, pool_size=4):
        super().__init__()
        self.pool_size = pool_size
        # self.avgpool = nn.AvgPool2d(self.pool_size)
        self.avgpool = CostomAdaptiveAvgPool2D(1)
        self.fc = nn.Sequential(nn.Conv2d(channels, channels //2, kernel_size=1, bias=False),
                                nn.ReLU(True),
                                nn.Conv2d(channels //2, 4, kernel_size=1, bias=False),
                                nn.Sigmoid())

    def forward(self, x):
        b, c, h, w = x.size()
        x = self.avgpool(x)
        x = self.fc(x)
        return x


class WGDown(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.dwt = DWTForward(J=1, wave='haar')
        self.wavelet_gating = WaveletGating(channels)
        
    def img_to_dwt(self, img):
        low, high = self.dwt(img)
        b, _, _, h, w = high[0].size()
        high = high[0].view(b, -1, h, w)
        freq = torch.cat([low, high], dim=1)
        return freq
    
    def forward(self, x):
        """
        Args:
            x : [B, C_in, H, W]
        Returns:
            [B, C_in, H//2, W//2]
        """
        # dwt를 하고 나오는 순서 : LL, LH, HL, HH
        # LL, high = self.dwt(x)
        # LH, HL, HH = tuple(rearrange(high[0] , 'b c n h w -> n b c h w ', n=3))
        # b, _, _, h, w = high[0].size()
        # [1, 4, 1, 1]
        # low, high = self.dwt
        # score = self.wavelet_gating(x)
        # [B, C_in, 4, H//2, W//2]
        # result = torch.concat((low.unsqueeze(2), high[0]), dim=2) * score.unsqueeze(1)
        # return result.flatten(start_dim=1, end_dim=2)
        LL, high = self.dwt(x)
        # LH, HL, HH = tuple(rearrange(high[0] , 'b c n h w -> n b c h w ', n=3))
        score = self.wavelet_gating(x)
        # LL = LL * score[:, 0, ...]
        # LH = LH * score[:, 1, ...]
        # HL = HL * score[:, 2, ...]
        # HH = HH * score[:, 3, ...]
        # result = LL + LH + HL + HH
        # Stack along a new dimension
        # tensors = torch.stack([LL, LH, HL, HH], dim=1) # tensors shape: [B, 4, C_in, H//2, W//2]
        # print(low.unsqueeze(2).shape)
        # print(high[0])
        result = torch.concat((LL.unsqueeze(2), high[0]), dim=2) * score.unsqueeze(1)
        result = result.sum(dim=1)
        return result

class WGUp(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.idwt = DWTInverse(mode="zero", wave='haar')
        self.wavelet_gating = WaveletGating(channels)

    def dwt_to_img(img):
        b, c, h, w = img.size()
        low = img[:, :3, :, :]
        high = img[:, 3:, :, :].view(b, 3, 3, h, w)
        return op((low, [high]))
    def forward(self, x):
        """
        Args:
            x : [B, C_in, H, W]
        Returns:
            [B, C_in, 2*H, 2*W]
        """
        score = self.wavelet_gating(x) # [B, 4, 1, 1]
        # x = x.chunk(4, dim=1)
        b, c, h, w = x.size()
        x_reshaped = x.view(b, 4, c//4, h, w)
        # score_reshaped = score.expand(b, 4, c//4, h, w)
        freq = x_reshaped * score.unsqueeze(2)
        
        result = self.idwt((freq[:,0,...], [freq[:, 1:, ...].flatten(start_dim=1, end_dim=2)]))
        return result

In [255]:
wg = WGUp(100)

In [248]:
torch.randn(16, 4, 6, 32, 32)[:, :1, ...]

tensor[16, 1, 6, 32, 32] n=98304 x∈[-5.159, 4.541] μ=-0.002 σ=0.999

In [249]:
torch.randn(16, 4, 6, 32, 32)[:, 1:, ...]

tensor[16, 3, 6, 32, 32] n=294912 x∈[-5.079, 4.457] μ=-0.000 σ=1.000

In [257]:
wg(torch.randn(4, 100, 32, 32))

ValueError: too many values to unpack (expected 3)

In [222]:
torch.chunk(torch.randn(1, 4, 32, 32), 4, 1)

(tensor[1, 1, 32, 32] n=1024 x∈[-2.855, 3.688] μ=0.017 σ=0.984,
 tensor[1, 1, 32, 32] n=1024 x∈[-3.620, 3.164] μ=0.001 σ=0.958,
 tensor[1, 1, 32, 32] n=1024 x∈[-3.202, 3.188] μ=0.005 σ=0.994,
 tensor[1, 1, 32, 32] n=1024 x∈[-2.847, 3.910] μ=0.037 σ=0.974)

In [221]:
torch.randn(1, 4, 32, 32).chunk(4, dim=1)

(tensor[1, 1, 32, 32] n=1024 x∈[-3.024, 2.899] μ=-0.035 σ=0.984,
 tensor[1, 1, 32, 32] n=1024 x∈[-3.229, 2.876] μ=-0.038 σ=0.970,
 tensor[1, 1, 32, 32] n=1024 x∈[-3.727, 4.430] μ=-0.019 σ=1.021,
 tensor[1, 1, 32, 32] n=1024 x∈[-3.659, 2.933] μ=0.032 σ=1.029)

In [219]:
wgd = WGDown(4)

In [220]:
wgd(torch.randn(32, 4, 16, 16)).shape

torch.Size([1, 13, 1, 128, 128])
tensor[32, 4, 3, 8, 8] n=24576 x∈[-4.129, 3.749] μ=0.005 σ=0.998


torch.Size([32, 4, 8, 8])

In [202]:
torch.randn(30, 4, 1, 1)[:,0,...]

tensor[30, 1, 1] x∈[-2.503, 1.897] μ=0.061 σ=1.051

In [176]:
wg = WaveletGating(10)

In [177]:
wg(torch.randn(1, 10, 256, 256))

tensor[1, 4, 1, 1] x∈[0.500, 0.500] μ=0.500 σ=0.000 grad SigmoidBackward0 [[[[0.500]], [[0.500]], [[0.500]], [[0.500]]]]

In [7]:
blk = RepBlock(3, 10, 3, padding=1, stride=2)

In [8]:
blk(torch.randn(1, 3, 32, 32)).shape

torch.Size([1, 10, 16, 16])

In [9]:
class UNetModel2(nn.Module):
    def __init__(self, in_channels, out_channels, 
                 model_channels, num_res_blocks=2, channel_mult=(1,2,4,8),
                 ):
        super().__init__()
        self.num_resolutions = len(channel_mult)

        

In [40]:
class UNetModel(nn.Module):
    def __init__(self, in_channels, out_channels, 
                 model_channels, num_res_blocks=2, channel_mult=(1,2,4,8), attention_resolutions=[4,2,1], large_kerenl=None, base_kernel=3, freq_domain=False
                 ):
        super().__init__()
        self.channel_mult = channel_mult
        self.num_res_blocks= num_res_blocks

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks] # [2,2,2,2]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks

        self.base_kernel = base_kernel
        self.large_kernel = 7 if large_kerenl is None else base_kernel
        self.large_kernel_level = int(len(channel_mult)/2)
        
        ch = int(channel_mult[0] * model_channels)
        input_block_chans = [ch]
        stem = [nn.Conv2d(in_channels, ch, kernel_size=3, padding=1)]
        if freq_domain:
            down = WGDown
            up = WGUp
        else:
            down = Downsample
            up = Upsample
        # stem.append(WGDown(use_conv=False))
        self.input_blocks = nn.ModuleList([nn.Conv2d(in_channels, ch, kernel_size=3, padding=1),])
                                           # down(ch)])

        # self.input_blocks.append(stem)
        self.output_blocks = nn.ModuleList([])

        # channel_mult = [1,2,4,4]
        # 0, 1
        # nr = 2
        # 0, 1
        # 
        ds = 1
        for level, mult in enumerate(channel_mult):
            if level <= self.large_kernel_level:
                kernel_size = self.large_kernel
                padding = 3
            else:
                kernel_size = self.base_kernel
                padding = 1
            for nr in range(self.num_res_blocks[level]):
                print(nr)
                layers = [RepBlock(ch, int(mult * model_channels), kernel_size=kernel_size, padding=padding)]
                ch = mult * model_channels
                
                if ds in attention_resolutions:
                    layers.append(RepBlock(ch, ch, kernel_size=kernel_size, padding=padding))
                
                self.input_blocks.append(nn.Sequential(*layers))
                input_block_chans.append(ch)
            # 0 vs 3
            if level != len(channel_mult) -1:
                self.input_blocks.append(Downsample(ch))#, use_conv=True))
                input_block_chans.append(ch)
                ds *= 2
            # self.input_blocks.append(nn.Sequential(*layers))
        
        self.middle_block = nn.Sequential(
            RepBlock(ch, ch, kernel_size=3, padding=1),
            RepBlock(ch, ch, kernel_size=3, padding=1)
        )

        # [(3, 4), (2, 4), (1, 2), (0, 1)]
        for level, mult in list(enumerate(channel_mult))[::-1]:
            if level >= self.large_kernel_level:
                kernel_size = self.large_kernel
                padding = 3
            else:
                kernel_size = self.base_kernel
                padding = 1
            # 0, 1, 2
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [RepBlock(ch + ich, model_channels*mult, kernel_size=kernel_size, padding=padding)]
                ch = model_channels * mult

                if ds in attention_resolutions:
                    layers.append(RepBlock(ch, ch, kernel_size=kernel_size, padding=padding))
                if level and i == self.num_res_blocks[level]:
                    layers.append(Upsample(ch))#, use_conv=True))
                    ds //= 2
                self.output_blocks.append(nn.Sequential(*layers))
        
        self.out = nn.Sequential(RepBlock(model_channels, out_channels, 3, padding=1, zero_params=True))

    def forward(self, x):
        """
        args
        x: [N x C x H x W]
        returns
        an [N x C x H x W]
        """
        hs = []
        h = x
        for module in self.input_blocks:
            h = module(h)
            hs.append(h)
        h = self.middle_block(h)
        # hs = 12
        # torch.Size([1, 10, 256, 256])
        # torch.Size([1, 10, 256, 256])
        # torch.Size([1, 10, 256, 256])
        # torch.Size([1, 10, 128, 128])
        # torch.Size([1, 20, 128, 128])
        # torch.Size([1, 20, 128, 128])
        # torch.Size([1, 20, 64, 64])
        # torch.Size([1, 40, 64, 64])
        # torch.Size([1, 40, 64, 64])
        #  torch.Size([1, 40, 32, 32])
        # torch.Size([1, 40, 32, 32])
        # torch.Size([1, 40, 32, 32])
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1) # [1, 40, 128, 128] (upsample 때문에) 과 [1, 40, 64, 64]를 concat하려고 함
            h = module(h)
        return self.out(h)

In [41]:
model = UNetModel(3, 3, 10, channel_mult=(1,2,4,4))

0
1
0
1
0
1
0
1


In [42]:
# model.input_blocks

In [43]:
# model.output_blocks

In [44]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [45]:
count_parameters(model)

1699162

In [46]:
model(torch.randn(1, 3, 256, 256)).shape

torch.Size([1, 3, 256, 256])

In [None]:
def get_fused_bn_to_conv_state_dict(conv, bn):
    bn_mean, bn_var, bn_gamma, bn_beta = (bn.running_mean, bn.running_var, 
                                          bn.weight, bn.bias)
    

In [None]:
forward(x, style):
    modulatio = sefl.get_modulatin(style)

def get_mdoulati0n(sefl,x style):
    style = self.modulation(style).view(style.sise(0), -1, 1,1)
    modulation = self.scale * style
    return modulation

def get_demodulation(self, style):
    w = self.weight.unsqueeze(0)
    

In [None]:
def conv_bn(in_channels, out_channels, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv3x3', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                                        kernel_size=3, stride=stride, padding=padding, groups=groups,bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    result.add_module('conv1x1', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                                        kernel_size=1, stride=stride, padding=padding, groups=groups,bias=False))

    return result


class RepVGGBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super(self).__init__()
        self.identity = nn.Identity()
        self.block1 = conv_bn()
        self.block2 = conv_bn()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        res = x
        x = self.block1(x)
        x += self.block2(res)
        x += self.identity(res)
        x = re
class DWConv(nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(channel_in, 1, kernel_size, kernel_size))
        self.weight_permute = nn.Parameter(torch.randn(channel_out, channel_in, 1, 1))

        