In [25]:
import torchvision.models as models
from torchsummary import summary
import torch
import torch.nn as nn 

import antialiased_cnns

In [15]:
# Dummy per-sample input whose only use is its `.shape` in the torchsummary cells.
# Shape is (C, H, W) -- NOT BCHW: torchsummary prepends the batch dimension itself.
# Zero-filled (deterministic) and kept on the CPU, since its values/device are never used.
I = torch.zeros((3, 224, 224))

In [30]:
# Anti-aliased ResNet-18 (pool_only=True: per the summary output, the maxpool is
# followed by a BlurPool).
resnet = antialiased_cnns.resnet18(pool_only=True)

# Swap the stem for a 3x3, stride-1, padding-1 conv so conv1 itself does not downsample.
# NOTE(review): the saved summary output shows a 9,408-param conv1 with 112x112 output,
# i.e. it predates this swap -- re-run the cell to refresh it.
resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

# Keep the trunk up to and including layer3; every later stage is dropped.
layer_list = ["conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3"]
layers = [getattr(resnet, name) for name in layer_list]
resnet = nn.Sequential(*layers)

resnet.cuda()
# torchsummary expects a per-sample (C, H, W) size. Take only the last three dims so the
# cell still works after `I` is rebound to a BCHW batch further down the notebook.
summary(resnet, I.shape[-3:])


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4         [-1, 64, 111, 111]               0
   ReflectionPad2d-5         [-1, 64, 114, 114]               0
          BlurPool-6           [-1, 64, 56, 56]               0
            Conv2d-7           [-1, 64, 56, 56]          36,864
       BatchNorm2d-8           [-1, 64, 56, 56]             128
              ReLU-9           [-1, 64, 56, 56]               0
           Conv2d-10           [-1, 64, 56, 56]          36,864
      BatchNorm2d-11           [-1, 64, 56, 56]             128
             ReLU-12           [-1, 64, 56, 56]               0
       BasicBlock-13           [-1, 64, 56, 56]               0
           Conv2d-14           [-1, 64,

In [18]:
import ransacflow.model.model_orig as model
import torch.nn.functional as F
import torch

In [21]:
# Layer-by-layer summary of the coarse feature extractor from ransacflow.model.model_orig.
feature = model.FeatureExtractor().cuda()
# Pass only the trailing (C, H, W) dims: `I` is rebound to a BCHW batch later in the
# notebook, and torchsummary adds its own batch dimension.
summary(feature, I.shape[-3:])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,728
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 223, 223]               0
   ReflectionPad2d-5         [-1, 64, 225, 225]               0
        Downsample-6         [-1, 64, 112, 112]               0
            Conv2d-7         [-1, 64, 112, 112]          36,864
       BatchNorm2d-8         [-1, 64, 112, 112]             128
              ReLU-9         [-1, 64, 112, 112]               0
           Conv2d-10         [-1, 64, 112, 112]          36,864
      BatchNorm2d-11         [-1, 64, 112, 112]             128
             ReLU-12         [-1, 64, 112, 112]               0
       BasicBlock-13         [-1, 64, 112, 112]               0
           Conv2d-14         [-1, 64, 1

# --- Experiment configuration ---------------------------------------------
batchSize = 4    # images per half; the doubled batch I2 holds 2 * batchSize images
imgSize = 224    # square input resolution (H == W)
kernelSize = 7   # CorrNeigh window; the corr output has kernelSize**2 = 49 channels
margin = 88      # border width (px) zeroed by maskMargin: keeps the central 224 - 2*88 = 48 px

# Dummy input batch, laid out BCHW (this rebinds the CHW `I` used for torchsummary).
I = torch.zeros(batchSize, 3, imgSize, imgSize)


In [48]:
# Double the batch along dim 0: both halves hold the same images, on the GPU.
I2 = torch.cat([I] * 2, dim=0).cuda()

print(I.shape, I2.shape, sep="\n")

torch.Size([4, 3, 224, 224])
torch.Size([8, 3, 224, 224])


In [43]:
# Pairing permutation for the doubled batch: sample k is paired with indexRoll[k],
# so the first half points at the second half and vice versa.
index = torch.arange(batchSize * 2)
indexRoll = index.roll(batchSize)

print(index.shape, indexRoll.shape, index, indexRoll, sep="\n")

torch.Size([8])
torch.Size([8])
tensor([0, 1, 2, 3, 4, 5, 6, 7])
tensor([4, 5, 6, 7, 0, 1, 2, 3])


In [51]:
# Coarse-stage modules from ransacflow.model.model_orig, keyed by role.
# Construction order is unchanged (keyword arguments are evaluated left to right).
network = dict(
    netFeatCoarse=model.FeatureExtractor(),
    netCorr=model.CorrNeigh(kernelSize),
    netFlowCoarse=model.NetFlow(kernelSize, "netFlowCoarse"),
    netMatch=model.NetFlow(kernelSize, "netMatch"),
)

# Move every module to the GPU (Module.cuda() works in place).
for module in network.values():
    module.cuda()


In [52]:
def predFlowCoarse(corrKernel, NetFlowCoarse, grid, up8X=True):
    """Predict a coarse flow and convert it into an absolute, clamped sampling grid.

    NOTE(review): the forward-pass cell calls ``model.predFlowCoarse``, not this
    local definition.

    Args:
        corrKernel: correlation volume, passed straight through to ``NetFlowCoarse``.
        NetFlowCoarse: callable ``(corrKernel, up8X) -> flow`` returning a 4-D tensor
            (B, 2, H, W). After the permute below, dims 2/3 line up with dims 1/2 of
            ``grid``, which the grid cell builds as rows/columns. (The old inline
            comment said "B, 2, W, H"; that only mattered for naming.)
        grid: identity grid, broadcast-added to the (B, H, W, 2) flow.
        up8X: forwarded unchanged to ``NetFlowCoarse``.

    Returns:
        (flowGrad, flowCoarse):
            flowGrad -- (B, 1, H-1, W-1): per-pixel L2 norm, over the 2 flow channels,
                of flow[i+1, j+1] - flow[i, j] (difference to the lower-right neighbour).
            flowCoarse -- (B, H, W, 2): flow + grid, clamped to [-1, 1].
    """
    flow = NetFlowCoarse(corrKernel, up8X)
    _, _, height, width = flow.size()

    # Diagonal finite difference: lower-right neighbour minus current pixel.
    lowerRight = flow[:, :, 1:, 1:]
    upperLeft = flow[:, :, :height - 1, :width - 1]
    flowGrad = (lowerRight - upperLeft).norm(dim=1, keepdim=True)

    # Channels-last so the flow can be added to the grid, then keep it in [-1, 1].
    flowCoarse = (flow.permute(0, 2, 3, 1) + grid).clamp(min=-1, max=1)
    return flowGrad, flowCoarse


def predMatchability(corrKernel21, NetMatchability, up8X=True):
    """Run the matchability network on a correlation volume and return its output as-is.

    NOTE(review): the forward-pass cell calls ``model.predMatchability``, not this
    local definition.

    Args:
        corrKernel21: correlation volume forwarded to ``NetMatchability``.
        NetMatchability: callable ``(corrKernel21, up8X) -> matchability``.
        up8X: forwarded unchanged to ``NetMatchability``.
    """
    return NetMatchability(corrKernel21, up8X)

In [53]:
# Identity sampling grid in normalised [-1, 1] coordinates, shape (1, imgSize, imgSize, 2).
# The last dim is (x, y): x varies along dim 2 (columns), y along dim 1 (rows).
# linspace puts -1 / +1 exactly on the first / last pixel.
coords = torch.linspace(-1, 1, steps=imgSize)
gridX = coords.view(1, 1, -1, 1).expand(1, imgSize, imgSize, 1)
gridY = coords.view(1, -1, 1, 1).expand(1, imgSize, imgSize, 1)
grid = torch.cat((gridX, gridY), dim=3).cuda()

In [63]:
# Border mask for the doubled batch: 1 in the centre, 0 within `margin` px of every edge.
inner = imgSize - 2 * margin
maskMargin = torch.ones(batchSize * 2, 1, inner, inner)
maskMargin = F.pad(maskMargin, (margin,) * 4, mode="constant", value=0).cuda()

In [66]:
# Coarse forward pass on the doubled batch; sample k is correlated against indexRoll[k].
f = F.normalize(network["netFeatCoarse"](I2), p=2, dim=1)  # unit L2 norm along channels
corr = network['netCorr'](f[indexRoll], f)

finalGrad, final = model.predFlowCoarse(corr, network['netFlowCoarse'], grid)
print(finalGrad.shape)
print(final.shape)

# Matchability with the image border (maskMargin) zeroed out.
match = model.predMatchability(corr, network['netMatch']) * maskMargin
print(match.shape)

# Warp the partner's matchability with the predicted grid and multiply with our own.
# `grid` is built with linspace(-1, 1), so -1/+1 sit on the first/last pixel centres,
# which is the align_corners=True convention. Since PyTorch 1.3, grid_sample defaults to
# align_corners=False (and warns when it is omitted). Pass it explicitly so sampling
# matches the grid instead of being shifted by half a pixel.
matchCycle = F.grid_sample(match[indexRoll], final, align_corners=True) * match
print(matchCycle.shape)

torch.Size([8, 1, 223, 223])
torch.Size([8, 224, 224, 2])
torch.Size([8, 1, 224, 224])
torch.Size([8, 1, 224, 224])




In [32]:
# Shape check of CorrNeigh applied to the raw 3-channel images (not features):
# it yields kernelSize**2 channels. Bound to its own name so re-running this cell does
# not overwrite `corr` produced by the forward-pass cell.
corrImg = network["netCorr"](I2, I2)
print(corrImg.shape)

torch.Size([8, 49, 224, 224])


In [21]:
# Scratch arithmetic: imgSize (224) == 8 * 28. Presumably checking divisibility by 8
# for the up8X upsampling -- TODO confirm, or delete this cell.
8 * 28

224