In [None]:

# A-pmerge based classifier
class ApmergeClassifier(nn.Module):
    """Toy classifier built around a single A-pmerge (adaptive patch merging) layer.

    Pipeline: image ``(B, C, H, W)`` -> token sequence ``(B, H*W, C)`` ->
    ``AdaptivePatchMerging`` -> feature map -> global average pooling ->
    linear head producing ``num_classes`` logits.

    Args:
        stride: Down-sampling factor of the patch-merging layer. ``forward``
            divides the spatial size by ``stride`` and multiplies the channel
            count by it, so the head takes ``dim * stride`` features.
        input_resolution: Spatial size ``(H, W)`` of the input, forwarded to
            ``AdaptivePatchMerging``.
        dim: Number of input channels ``C``.
        num_classes: Number of output logits.
        conv_padding_mode: Padding mode forwarded to ``AdaptivePatchMerging``.
    """

    def __init__(
        self, stride, input_resolution,
        dim, num_classes=4, conv_padding_mode='circular',
    ):
        super().__init__()
        self.stride = stride
        # Pooling layer for A-pmerge: polyphase down-sampling whose component
        # is picked by max_p_norm, with no anti-aliasing filter.
        pool_layer = partial(
            PolyphaseInvariantDown2D,
            component_selection=max_p_norm,
            antialias_layer=None,
        )
        # A-pmerge.
        # No adaptive window selection, for illustration purposes.
        self.apmerge = AdaptivePatchMerging(
            input_resolution=input_resolution,
            dim=dim,
            pool_layer=pool_layer,
            conv_padding_mode=conv_padding_mode,
            stride=stride,
            window_selection=None,
            window_size=None,
        )
        # Global pooling and classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(dim * stride, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Tensor of shape ``(B, C, H, W)``; ``C`` must equal ``dim``
                because the head expects ``C * stride`` features.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C): token layout fed to the merging layer.
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        # Adaptive patch merge.
        x = self.apmerge(x)
        # Tokens back to an image layout. This reshape requires apmerge to
        # return (B, (H//stride)*(W//stride), C*stride); assumes tokens are in
        # row-major spatial order -- TODO confirm against AdaptivePatchMerging.
        x = x.reshape(
            B, H // self.stride, W // self.stride, C * self.stride,
        ).permute(0, 3, 1, 2)
        # Global average pooling -> (B, C*stride).
        x = torch.flatten(self.avgpool(x), 1)
        # Classification head.
        return self.fc(x)
# Input tokens
B, C, H, W = 1, 3, 8, 8
stride = 2
# Run on the GPU when available; fall back to the CPU so the demo also
# works on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Sample on the CPU (same RNG stream as before), then move and cast to double
# precision so the equality check below is tight.
x = torch.randn(B, C, H, W).to(device=device, dtype=torch.float64)
# Random circular shift along H and W; torch.randint's upper bound is
# exclusive, so each shift lies in [-3, 2].
shift = torch.randint(-3, 3, (2,))
x_shift = torch.roll(
    input=x,
    dims=(2, 3),
    shifts=(int(shift[0]), int(shift[1])),
)
# A-pmerge classifier
model = ApmergeClassifier(
    stride=stride,
    input_resolution=(H, W),
    dim=C,
).to(device=device, dtype=torch.float64).eval()
# Predict (inference only, so autograd bookkeeping is unnecessary)
with torch.no_grad():
    y = model(x)
    y_shift = model(x_shift)
err = torch.norm(y - y_shift)
# Report before checking so the values are visible even if the check fails.
print('y: {}'.format(y))
print('y_shift: {}'.format(y_shift))
print("error: {}".format(err))
# Check circular shift invariance: logits must match up to float tolerance.
assert torch.allclose(y, y_shift), (
    f'logits changed under circular shift {shift.tolist()}: error={err.item()}'
)