In [1]:
import sys
# Extend the module search path with the sibling project directory. The path is
# relative, so it resolves against the kernel's current working directory.
# NOTE(review): presumably this is what makes the `layers` package imported in
# the next cell resolvable — confirm against the repository layout.
sys.path.append("../learn_poly_sampling")

In [2]:
import numpy as np 
import torch 
import torch.nn as nn
from PIL import Image
from functools import partial
from layers import get_logits_model, PolyphaseInvariantDown2D, LPS
from layers.polydown import set_pool

In [3]:
# Define Model
class SimpleClassifier(nn.Module):
    """Small image classifier built around a learnable polyphase down-sampling layer.

    Pipeline: 3x3 convolution -> polyphase down-sampling -> global average
    pooling -> linear classifier.

    Args:
        num_classes: Number of output logits produced by the final linear layer.
        padding_mode: Padding mode forwarded to the convolution
            (defaults to ``'circular'``).
    """

    def __init__(self, num_classes=3, padding_mode='circular'):
        super().__init__()

        # Conv. layer: 3 input channels -> 32 feature maps; padding=1 keeps the
        # spatial size unchanged for a 3x3 kernel.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            padding=1,
            padding_mode=padding_mode,
        )

        # Learnable polyphase down-sampling layer. The layer class is
        # pre-configured with `partial` and then handed to `set_pool` together
        # with the channel counts (both equal to conv1's 32 output channels).
        lps_factory = partial(
            PolyphaseInvariantDown2D,
            component_selection=LPS,
            get_logits=get_logits_model('LPSLogitLayers'),
            pass_extras=False,
        )
        self.lpd = set_pool(lps_factory, p_ch=32, h_ch=32)

        # Global pooling + classifier: reduce every feature map to a single
        # value, then map the 32 features to `num_classes` logits.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of 3-channel images.

        Args:
            x: Batched input tensor with 3 channels, i.e. ``(N, 3, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding the logits.
        """
        features = self.conv1(x)
        # Used like any other down-sampling layer.
        downsampled = self.lpd(features)
        pooled = self.avgpool(downsampled)
        # Drop the trailing 1x1 spatial dims before the linear layer.
        flat = torch.flatten(pooled, 1)
        return self.fc(flat)

In [4]:
# Construct Model
# Seed before instantiation so the randomly initialised weights are reproducible.
torch.manual_seed(0)
# Prefer the GPU when one is available, otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# eval(): evaluation mode; double(): float64 parameters for a tight numerical
# comparison in the invariance check.
model = SimpleClassifier().to(device).eval().double()
# Load Image
# Force 3-channel RGB so the array is always H x W x 3 (an RGBA, palette or
# greyscale PNG would otherwise break the permute or the 3-channel conv).
# The context manager closes the file handle once the pixels are copied.
with Image.open('butterfly.png') as pil_img:
    img_hwc = np.array(pil_img.convert('RGB'))
img = torch.from_numpy(img_hwc).permute(2,0,1)  # H x W x C -> C x H x W
img = img.unsqueeze(0).to(device).double()      # add batch dim, match model device/dtype

In [5]:
# Check that the model output is invariant to circular shifts of the input
with torch.no_grad():  # inference only, so no autograd graph is needed
    y_orig = model(img).cpu()
    # Circularly shift the image by one position along each of the last two axes.
    img_roll = torch.roll(img,shifts=(1, 1), dims=(-1, -2))
    y_roll = model(img_roll).cpu()
print("y_orig : %s" % y_orig)
print("y_roll : %s" % y_roll)
# Report the discrepancy before asserting so its magnitude is visible even
# when the check fails.
print("Norm(y_orig-y_roll): %e" % torch.norm(y_orig-y_roll))
assert torch.allclose(y_orig,y_roll), "model output changed under a circular shift"  # Check shift invariant

y_orig : tensor([[-22.0681, -36.2678,  20.5928]], dtype=torch.float64)
y_roll : tensor([[-22.0681, -36.2678,  20.5928]], dtype=torch.float64)
Norm(y_orig-y_roll): 0.000000e+00
