In [2]:
import torch
import torch.nn as nn

In [3]:
import os
import sys

# Root of the LS4GAN checkout that contains the `toygan_hybrid_vit_cnn`
# package imported below. Set the LS4GAN_ROOT environment variable to point
# at a different checkout; otherwise the original location is used.
LS4GAN_ROOT = os.environ.get('LS4GAN_ROOT', '/home/yhuang2/PROJs/LS4GAN')

# Guard so that re-running this cell does not keep appending the same entry.
if LS4GAN_ROOT not in sys.path:
    sys.path.append(LS4GAN_ROOT)

In [4]:
from toygan_hybrid_vit_cnn.toygan.torch.select import ( 
    get_activ_layer,
    extract_name_kwargs
)
from toygan_hybrid_vit_cnn.toygan.torch.layers.mixturebatch import MixtureBatch

In [5]:
# Sanity check of MixtureBatch. This cell checks that the coef_norm=1,
# coef_iden=1 layer produces the average of the pure-normalization
# (coef_iden=0) and pure-identity (coef_norm=0) layers on the same input.
torch.manual_seed(0)  # make the random input reproducible

num_channels = 2
mixture_batch_1 = MixtureBatch(num_channels, coef_norm=1, coef_iden=0, norm_kwargs={'momentum': .2})
mixture_batch_2 = MixtureBatch(num_channels, coef_norm=0, coef_iden=1, norm_kwargs={'momentum': .2})
mixture_batch_3 = MixtureBatch(num_channels, coef_norm=1, coef_iden=1, norm_kwargs={'momentum': .2})

data = torch.randint(5, (2, num_channels, 2, 2)).type(torch.float32)
a1 = mixture_batch_1(data)
a2 = mixture_batch_2(data)
a3 = mixture_batch_3(data)

# The two sides are float tensors computed along different arithmetic paths,
# so compare with a tolerance rather than exact `==`. Assert so that a
# mismatch fails the cell instead of only being printed.
expected = (a1 + a2) / 2
assert torch.allclose(expected, a3), 'MixtureBatch(1, 1) is not the mean of its two components'
torch.isclose(expected, a3)

tensor([[[[True, True],
          [True, True]],

         [[True, True],
          [True, True]]],


        [[[True, True],
          [True, True]],

         [[True, True],
          [True, True]]]])


In [6]:
class ConvEmbeddingBlock(nn.Module):
    """
    Convolutional stem: Conv2d -> normalization -> activation.

    The convolution uses kernel_size=4, stride=2, padding=1. An input of
    spatial size (H, W) therefore maps to (floor(H / 2), floor(W / 2)).
    Even sizes (e.g. powers of 2) are halved exactly, so prefer them.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels produced by the convolution; also the number of
        features given to the normalization layer.
    stem_norm : str or dict
        Normalization spec, split into (name, kwargs) by
        ``extract_name_kwargs``. Supported names: ``'mixturebatch'``
        (``MixtureBatch``), ``'batch'`` (``nn.BatchNorm2d``), ``'instance'``
        (``nn.InstanceNorm2d``), and ``None`` (``nn.Identity``, no
        normalization). The kwargs are forwarded to the layer constructor.
    stem_activ : str or dict
        Activation spec passed to ``get_activ_layer``.
    **kwargs
        Forwarded to ``nn.Module.__init__``.

    Raises
    ------
    ValueError
        If the normalization name is not one of the supported names.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        stem_norm,
        stem_activ,
        **kwargs
    ):
        super().__init__(**kwargs)

        # convolution (halves the spatial size, see class docstring)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            padding=1,
            stride=2
        )

        # normalization
        self.norm = self._make_norm_layer(stem_norm, out_channels)

        # activation
        self.activ = get_activ_layer(stem_activ)

    @staticmethod
    def _make_norm_layer(stem_norm, features):
        """Build the normalization layer described by ``stem_norm``."""
        # Use a distinct name so the constructor's **kwargs is not shadowed.
        name, norm_kwargs = extract_name_kwargs(stem_norm)

        if name == 'mixturebatch':
            return MixtureBatch(features, **norm_kwargs)
        if name is None:
            return nn.Identity()
        # Previously this fell through to an undefined `get_norm_layer_fn`
        # (NameError on a fresh kernel); standard norms are built directly.
        if name == 'batch':
            return nn.BatchNorm2d(features, **norm_kwargs)
        if name == 'instance':
            return nn.InstanceNorm2d(features, **norm_kwargs)

        raise ValueError(f"Unknown stem normalization: '{name}'")

    def forward(self, x):
        """Apply convolution, normalization and activation to ``x``."""
        return self.activ(self.norm(self.conv(x)))

In [7]:
# Stem normalization spec: a MixtureBatch that mixes the normalized and
# identity paths with equal coefficients.
stem_norm = dict(
    name='mixturebatch',
    coef_norm=1,
    coef_iden=1,
    norm_kwargs=dict(momentum=.2),
)

# Single-channel input -> 2 feature maps, ReLU activation.
layer = ConvEmbeddingBlock(
    in_channels=1,
    out_channels=2,
    stem_norm=stem_norm,
    stem_activ='relu',
)
layer

ConvEmbeddingBlock(
  (conv): Conv2d(1, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (norm): MixtureBatch(
    (norm_layer): BatchNorm2d(2, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
  )
  (activ): ReLU()
)

In [21]:
# Round-trip check: drop the two trailing singleton dims, restore them with
# unsqueeze, and confirm the original (3, 1, 1) tensor is recovered.
torch.manual_seed(0)  # make the random input reproducible

data = torch.randint(5, (3, 1, 1)).type(torch.float32)
print(data)

# Squeeze only the trailing dims. A bare .squeeze() would also remove a
# leading size-1 dim (e.g. shape (1, 1, 1) -> scalar) and break the round trip.
data_s = data.squeeze(-1).squeeze(-1)

data_r = data_s.unsqueeze(-1).unsqueeze(-1)

# `==` broadcasts, so it can report all-True even when shapes differ;
# torch.equal additionally requires identical shapes.
assert torch.equal(data_r, data)
data_r == data

tensor([[[2.]],

        [[1.]],

        [[3.]]])
tensor([[[True]],

        [[True]],

        [[True]]])
