In [1]:
import torch 
import torch.nn as nn

In [11]:
class Layer_C(nn.Module):
    def __init__(self, k, stride=2, padding=0, norm=True):
        super().__init__()
        self.conv = nn.LazyConv2d(k,(4,4),stride=stride,padding=padding)
        if norm:
            self.instance_norm = nn.InstanceNorm2d(k)
        else:
            self.instance_norm = nn.Identity()
        self.act = nn.LeakyReLU(0.2)
    def forward(self, x):
        x = self.conv(x)
        x = self.instance_norm(x)
        x = self.act(x)
        return x

In [13]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            Layer_C(64,2,1,norm=False),
            Layer_C(128,2,1),
            Layer_C(256,2,1),
            Layer_C(512,1,1),
            Layer_C(1,1,1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = self.layers(x)
        return x

In [17]:
a = torch.randn(64,3,256,256)

In [18]:
D = Discriminator()

In [20]:
out = D(a)

In [29]:
import einops

In [33]:
out2 = einops.rearrange(out, 'B C H W -> B (C H W)')

In [37]:
torch.mean(out2, dim=1, keepdim=True)

tensor([[0.5683],
        [0.5666],
        [0.5672],
        [0.5682],
        [0.5676],
        [0.5671],
        [0.5665],
        [0.5678],
        [0.5663],
        [0.5679],
        [0.5666],
        [0.5671],
        [0.5658],
        [0.5676],
        [0.5675],
        [0.5672],
        [0.5655],
        [0.5670],
        [0.5677],
        [0.5672],
        [0.5677],
        [0.5675],
        [0.5674],
        [0.5673],
        [0.5679],
        [0.5665],
        [0.5680],
        [0.5679],
        [0.5663],
        [0.5659],
        [0.5684],
        [0.5674],
        [0.5672],
        [0.5674],
        [0.5675],
        [0.5664],
        [0.5673],
        [0.5698],
        [0.5673],
        [0.5667],
        [0.5672],
        [0.5683],
        [0.5670],
        [0.5668],
        [0.5679],
        [0.5663],
        [0.5676],
        [0.5659],
        [0.5678],
        [0.5669],
        [0.5653],
        [0.5678],
        [0.5680],
        [0.5679],
        [0.5659],
        [0