In [265]:
import torch
from torch import nn
import torch.nn.functional as F

In [338]:
def subsample2(x):
    """Downsample by 2 by keeping one of the four stride-2 polyphase components.

    The input is split into the four sub-grids that start at spatial offsets
    (row, col) = (0, 0), (1, 0), (0, 1) and (1, 1).  For every sample, the
    sub-grid that contains that sample's largest activation (taken over all
    channels and positions) is kept.

    The choice is made independently for each sample, so a sample's output
    never depends on which other samples share its batch.  Ties are resolved
    in favour of the earliest candidate in the order listed above.

    Args:
        x: tensor indexed as (batch, channel, height, width).  Height and
            width must be even so that the four sub-grids have equal shapes.

    Returns:
        Tensor of shape (batch, channel, height // 2, width // 2).

    Raises:
        ValueError: if height or width is odd.
    """
    height, width = x.shape[-2:]
    if height % 2 or width % 2:
        raise ValueError(
            f"subsample2 needs even spatial dimensions, got {height}x{width}"
        )

    # Candidate order matters: argmax returns the first maximal index, so on
    # ties the lowest-numbered component wins.
    offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
    phases = torch.stack(
        [x[:, :, dy::2, dx::2] for dy, dx in offsets], dim=1
    )  # (batch, 4, channel, height // 2, width // 2)

    # Largest activation of every component, separately for each sample.
    peaks = phases.flatten(start_dim=2).max(dim=2).values  # (batch, 4)
    chosen = peaks.argmax(dim=1)  # (batch,)

    batch_index = torch.arange(x.shape[0], device=x.device)
    return phases[batch_index, chosen]

In [339]:
class BNConvBlock(nn.Module):
    """3x3 circularly padded convolution, optional 2x subsampling, then ReLU.

    The convolution always runs at stride 1; a ``stride`` of 2 is realised
    afterwards by ``subsample2``.  Strides greater than 2 are rejected when
    the block is called.  Despite the name, no batch-normalisation layer is
    created here.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        h: accepted for signature compatibility; not used by this block.
        w: accepted for signature compatibility; not used by this block.
        stride: 1 keeps the resolution, 2 applies ``subsample2``.
    """

    def __init__(self, in_ch, out_ch, h, w, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='circular',
        )
        self.stride = stride
        self.activation = F.relu

        # He-normal weights matched to the ReLU that follows; bias starts at 0.
        nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')
        self.conv.bias.data.zero_()

    def forward(self, x):
        """Apply convolution, the stride-dependent subsampling, and ReLU."""
        features = self.conv(x)

        if self.stride == 2:
            features = subsample2(features)
        elif self.stride > 2:
            raise NotImplementedError(f"Equivariance for stride {self.stride} not implemented yet")

        return self.activation(features)

In [340]:
class D_Conv(nn.Module):
    """Stack of nine ``BNConvBlock`` layers followed by a 1x1 classifier conv.

    Five of the nine blocks use stride 2, so the spatial resolution shrinks by
    a factor of 32 overall; with the default 32x32 input the feature map
    reaching the classifier is 1x1.

    Args:
        in_ch: number of input channels.
        num_classes: number of output channels of the final 1x1 convolution.
        alpha: width multiplier applied to every block's channel count.
        h: nominal input height; divided by the accumulated stride and passed
            on to each block.
        w: nominal input width; handled like ``h``.
        verbose: when true (the default), ``forward`` prints the mean and
            standard deviation of the activations before and after every
            layer.  Pass ``False`` to run silently.
    """

    # (output-channel multiple of alpha, stride) for each block, in order.
    _BLOCK_SPECS = (
        (1, 1),
        (2, 2),
        (2, 1),
        (4, 2),
        (4, 1),
        (8, 2),
        (8, 1),
        (16, 2),
        (64, 2),
    )

    def __init__(self, in_ch=3, num_classes=10, alpha=1, h=32, w=32, verbose=True):
        super().__init__()
        self.verbose = verbose

        modules = []
        channels = in_ch  # input channels of the next block
        scale = 1         # product of the strides applied so far
        for multiple, stride in self._BLOCK_SPECS:
            out_channels = multiple * alpha
            modules.append(
                BNConvBlock(channels, out_channels,
                            h=int(h / scale), w=int(w / scale), stride=stride)
            )
            channels = out_channels
            scale *= stride
        self.conv_net = nn.Sequential(*modules)

        self.activation = F.relu
        self.final = nn.Conv2d(channels, num_classes, 1, stride=1, padding=0, padding_mode='circular')
        self.final.bias.data.zero_()

    def _log_stats(self, tag, tensor):
        """Print ``tag`` with the tensor's mean and std when ``verbose`` is set."""
        if self.verbose:
            print(tag, tensor.mean(), tensor.std())

    def forward(self, x):
        """Return per-class outputs of shape (batch, num_classes).

        Raises:
            RuntimeError: if the feature map after the last layer is not 1x1,
                because the output is produced by dropping the two spatial
                axes.
        """
        out = x

        # Tag 1 = statistics entering a layer, tag 2 = statistics leaving it.
        for layer in self.conv_net:
            self._log_stats(1, out)
            out = layer(out)
            self._log_stats(2, out)

        self._log_stats(1, out)
        out = self.final(out)
        self._log_stats(2, out)

        if tuple(out.shape[2:]) != (1, 1):
            raise RuntimeError(
                "D_Conv expects a 1x1 feature map after the final layer, "
                f"got shape {tuple(out.shape)}"
            )
        out = out.view(*out.shape[:2])
        # Tag 3 = statistics of the flattened (batch, num_classes) output.
        self._log_stats(3, out)
        return out

In [341]:
import matplotlib.pyplot as plt

In [342]:
# Instantiate the network for 3-channel input; every other constructor
# argument keeps its default value.
model = D_Conv(in_ch=3)

In [343]:
# Fix the global RNG so the random patch below (and anything drawn from the
# global generator afterwards) is repeatable across runs.
SEED = 0
torch.manual_seed(SEED)

B, C, H, W = 16, 3, 32, 32
PATCH = 16  # side length of the random patch placed on the zero canvas

# Row offset of the patch in X (u) and in X2 (v); both use column offset u,
# so X2 is X with the patch moved down by v - u rows.
u, v = 0, 16
assert max(u, v) + PATCH <= min(H, W), "patch must fit inside the canvas"

# Unit-variance noise patch; the trailing factor scales its standard deviation.
img = torch.randn((B, C, PATCH, PATCH)) * 1

X = torch.zeros((B, C, H, W))
X[:, :, u:u+PATCH, u:u+PATCH] = img

X2 = torch.zeros((B, C, H, W))
X2[:, :, v:v+PATCH, u:u+PATCH] = img

# Both canvases hold the same values, so their statistics should agree.
print(X.mean(), X.std())
print(X2.mean(), X2.std())

tensor(0.0014) tensor(0.5051)
tensor(0.0014) tensor(0.5051)


In [344]:
# Build a fresh network for C input channels, then run it on both batches
# with the same weights: X first, X2 second.
model = D_Conv(in_ch=C)
Y, Y2 = model(X), model(X2)

1 tensor(0.0014) tensor(0.5051)
2 tensor(0.1752, grad_fn=<MeanBackward0>) tensor(0.5403, grad_fn=<StdBackward0>)
1 tensor(0.1752, grad_fn=<MeanBackward0>) tensor(0.5403, grad_fn=<StdBackward0>)
2 tensor(0.1023, grad_fn=<MeanBackward0>) tensor(0.3889, grad_fn=<StdBackward0>)
1 tensor(0.1023, grad_fn=<MeanBackward0>) tensor(0.3889, grad_fn=<StdBackward0>)
2 tensor(0.1036, grad_fn=<MeanBackward0>) tensor(0.2750, grad_fn=<StdBackward0>)
1 tensor(0.1036, grad_fn=<MeanBackward0>) tensor(0.2750, grad_fn=<StdBackward0>)
2 tensor(0.2012, grad_fn=<MeanBackward0>) tensor(0.3802, grad_fn=<StdBackward0>)
1 tensor(0.2012, grad_fn=<MeanBackward0>) tensor(0.3802, grad_fn=<StdBackward0>)
2 tensor(0.0816, grad_fn=<MeanBackward0>) tensor(0.1921, grad_fn=<StdBackward0>)
1 tensor(0.0816, grad_fn=<MeanBackward0>) tensor(0.1921, grad_fn=<StdBackward0>)
2 tensor(0.0941, grad_fn=<MeanBackward0>) tensor(0.1744, grad_fn=<StdBackward0>)
1 tensor(0.0941, grad_fn=<MeanBackward0>) tensor(0.1744, grad_fn=<StdBackward