In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file under /kaggle/input so the mounted inputs are
# visible. Note: `dirname`, `filenames` and `filename` remain defined in the
# notebook's global namespace after this loop.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
#  Loading the Pretrained NCNet from Google Drive

!pip install gdown
import gdown

# File ID from your link
file_id = "10GZ0x3CmObKzbAg1GKQhrkPeSLRpD4Rp"
url = f"https://drive.google.com/uc?id={file_id}"

# Save location
output = "ncnet_checkpoint.pth"
gdown.download(url, output, quiet=False)




Downloading...
From: https://drive.google.com/uc?id=10GZ0x3CmObKzbAg1GKQhrkPeSLRpD4Rp
To: /kaggle/working/ncnet_checkpoint.pth
100%|██████████| 6.00k/6.00k [00:00<00:00, 9.97MB/s]


'ncnet_checkpoint.pth'

In [33]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Conv2d
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _quadruple
import torchvision.models as models
from collections import OrderedDict
import torch.utils.model_zoo as model_zoo

In [34]:
"""
This script is an adapted version of 
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 
The goal is to keep ResNet* as only feature extractor, 
so the code can be used independent of the types of specific tasks,
i.e., classification or regression. 
"""



def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size of the input is preserved.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18/34.

    The block outputs ``planes * expansion`` channels (``expansion`` is 1).
    When ``downsample`` is given it is applied to the block input so the
    shortcut matches the shape of the convolution branch.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Submodule names and creation order define the state_dict keys that
        # are matched against pretrained checkpoints, so they stay unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)

class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block used by ResNet-50/101/152.

    The block outputs ``planes * expansion`` channels (``expansion`` is 4) and
    applies ``stride`` in the 3x3 convolution. ``downsample``, when given,
    projects the block input for the shortcut connection.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion
        # Attribute names and creation order form the state_dict keys used to
        # load pretrained checkpoints; keep them as they are.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        h = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = self.relu(bn(conv(h)))
        h = self.bn3(self.conv3(h))

        h += identity
        return self.relu(h)
    
class ResNet(nn.Module):
    """ResNet backbone without the classification head (no avgpool / fc).

    Subclasses pick the block type and depth by calling ``_build_model`` and
    set ``self.pretrained_url``, which ``load_pretrained_`` downloads from.
    """
    PRETRAINED_URLs = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    }

    def __init__(self):
        super().__init__()

    def _build_model(self, block, layers):
        """Create the stem (conv1/bn1/relu/maxpool) and stages layer1..layer4.

        Args:
            block: residual block class exposing an ``expansion`` attribute.
            layers: number of blocks in each of the four stages.
        """
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may stride or project.

        A 1x1 conv + BatchNorm shortcut projection is added to the first block
        when the stride or the channel count changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x, early_feat=False):
        """Return layer4 features, or layer3 features when ``early_feat`` is True."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if early_feat:
            return x
        x = self.layer4(x)
        return x

    def forward_all(self, x, feat_list=None, early_feat=True):
        """Run the backbone and collect intermediate feature maps.

        Appends, in order: the input, the post-ReLU stem output (before
        max-pooling), and the outputs of layer1, layer2 and layer3, plus
        layer4 when ``early_feat`` is False.

        Args:
            x: input image batch.
            feat_list: list to append to. When omitted, a fresh list is used
                for every call (a shared ``[]`` default would keep growing
                across calls).
            early_feat: stop after layer3 when True.

        Returns:
            The list the feature maps were appended to.
        """
        if feat_list is None:
            feat_list = []
        feat_list.append(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        feat_list.append(x)

        x = self.maxpool(x)
        x = self.layer1(x)
        feat_list.append(x)

        x = self.layer2(x)
        feat_list.append(x)

        x = self.layer3(x)
        feat_list.append(x)

        if not early_feat:
            x = self.layer4(x)
            feat_list.append(x)
        return feat_list

    def load_pretrained_(self, ignore='fc'):
        """Load torchvision ImageNet weights from ``self.pretrained_url``.

        Entries whose key contains ``ignore`` (the classifier by default) are
        dropped; the rest must match this model exactly (strict loading).
        """
        print('Initialize ResNet using pretrained model from {}'.format(self.pretrained_url))
        state_dict = model_zoo.load_url(self.pretrained_url)
        new_state_dict = OrderedDict((k, v) for k, v in state_dict.items() if ignore not in k)
        self.load_state_dict(new_state_dict)

    def change_stride(self, target='layer3'):
        """Set stride 1 in the first block of stage ``target``.

        This removes that stage's spatial downsampling. Stages whose first block
        has an identity shortcut (no ``downsample`` projection) are handled too.
        """
        first_block = getattr(self, target)[0]
        first_block.conv1.stride = (1, 1)
        first_block.conv2.stride = (1, 1)
        if first_block.downsample is not None:
            first_block.downsample[0].stride = (1, 1)

class ResNet34(ResNet):
    """ResNet-34 feature extractor: BasicBlock stages of 3, 4, 6 and 3 blocks."""

    def __init__(self):
        super(ResNet34, self).__init__()
        self._build_model(BasicBlock, [3, 4, 6, 3])
        # Weights location consumed by ResNet.load_pretrained_().
        self.pretrained_url = self.PRETRAINED_URLs['resnet34']

class ResNet50(ResNet):
    """ResNet-50 feature extractor: Bottleneck stages of 3, 4, 6 and 3 blocks."""

    def __init__(self):
        super(ResNet50, self).__init__()
        self._build_model(Bottleneck, [3, 4, 6, 3])
        # Weights location consumed by ResNet.load_pretrained_().
        self.pretrained_url = self.PRETRAINED_URLs['resnet50']
        
class ResNet101(ResNet):
    """ResNet-101 feature extractor: Bottleneck stages of 3, 4, 23 and 3 blocks."""

    def __init__(self):
        super(ResNet101, self).__init__()
        self._build_model(Bottleneck, [3, 4, 23, 3])
        # Weights location consumed by ResNet.load_pretrained_().
        self.pretrained_url = self.PRETRAINED_URLs['resnet101']


In [24]:
def Softmax1D(x, dim):
    """Numerically stable softmax of ``x`` along dimension ``dim``.

    The per-slice maximum is subtracted before exponentiation to avoid
    overflow. ``x`` itself is left untouched: the subtraction creates a new
    tensor instead of modifying the caller's tensor in place. An in-place
    update would also fail on leaf tensors that require grad.

    Args:
        x: input tensor.
        dim: dimension along which the values sum to one.

    Returns:
        Tensor of the same shape as ``x``.
    """
    x_max = torch.max(x, dim, keepdim=True)[0]
    exp_x = torch.exp(x - x_max)
    return exp_x / torch.sum(exp_x, dim, keepdim=True)

def featureL2Norm(feature):
    """Divide ``feature`` by its L2 norm computed over dimension 1.

    ``1e-6`` is added to the sum of squares before the square root, so
    all-zero vectors produce zeros instead of a division by zero.
    """
    epsilon = 1e-6
    squared_sum = torch.sum(torch.pow(feature, 2), 1)
    norm = torch.pow(squared_sum + epsilon, 0.5)
    return torch.div(feature, norm.unsqueeze(1).expand_as(feature))

class FeatureExtraction(torch.nn.Module):
    """CNN backbone that maps an image batch to a dense feature map.

    Args:
        train_fe: if False, every backbone parameter gets ``requires_grad=False``.
        feature_extraction_cnn: one of 'vgg', 'resnet101', 'resnet101fpn',
            'densenet201'.
        feature_extraction_model_file: optional weights file, only used by
            'resnet101fpn'. An empty string or None means "not given".
        normalization: L2-normalise the output over dim 1. This is skipped for
            'resnet101fpn', which passes the flag to ``fpn_body`` instead.
        last_layer: name of the last layer kept. '' selects 'pool4' for VGG
            and 'layer3' for ResNet-101.
        use_cuda: move the backbone to the GPU.

    Raises:
        ValueError: if ``feature_extraction_cnn`` is not supported.
    """
    def __init__(self, train_fe=False, feature_extraction_cnn='resnet101', feature_extraction_model_file='', normalization=True, last_layer='', use_cuda=True):
        super(FeatureExtraction, self).__init__()
        supported = ('vgg', 'resnet101', 'resnet101fpn', 'densenet201')
        if feature_extraction_cnn not in supported:
            # Fail here rather than with an opaque AttributeError on `self.model`.
            raise ValueError('Unsupported feature_extraction_cnn {!r}; expected one of {}'.format(
                feature_extraction_cnn, supported))
        self.normalization = normalization
        self.feature_extraction_cnn = feature_extraction_cnn
        resnet_feature_layers = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4']

        if feature_extraction_cnn == 'vgg':
            vgg = models.vgg16(pretrained=True)
            vgg_feature_layers = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
                                  'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
                                  'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1',
                                  'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
                                  'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'pool5']
            if last_layer == '':
                last_layer = 'pool4'
            # keep the feature extraction network up to the indicated layer
            last_layer_idx = vgg_feature_layers.index(last_layer)
            self.model = nn.Sequential(*list(vgg.features.children())[:last_layer_idx + 1])

        elif feature_extraction_cnn == 'resnet101':
            resnet = models.resnet101(pretrained=True)
            if last_layer == '':
                last_layer = 'layer3'
            last_layer_idx = resnet_feature_layers.index(last_layer)
            kept = resnet_feature_layers[:last_layer_idx + 1]
            self.model = nn.Sequential(*[getattr(resnet, name) for name in kept])

        elif feature_extraction_cnn == 'resnet101fpn':
            resnet = models.resnet101(pretrained=True)
            if feature_extraction_model_file:
                # Swap the (2,2) and (1,1) strides of conv1/conv2 in the first
                # block of layer2-4: PyTorch's ResNet differs slightly from the
                # caffe2 one, and this is required for caffe2 model compatibility.
                for stage in ('layer2', 'layer3', 'layer4'):
                    first_block = getattr(resnet, stage)[0]
                    first_block.conv1.stride = (2, 2)
                    first_block.conv2.stride = (1, 1)
            conv_body = nn.Sequential(*[getattr(resnet, name) for name in resnet_feature_layers])
            # NOTE(review): `fpn_body` is not defined in this notebook; this
            # branch raises NameError unless it is provided elsewhere.
            self.model = fpn_body(conv_body,
                                  resnet_feature_layers,
                                  fpn_layers=['layer1', 'layer2', 'layer3'],
                                  normalize=normalization,
                                  hypercols=True)
            if feature_extraction_model_file:
                self.model.load_pretrained_weights(feature_extraction_model_file)

        else:  # 'densenet201'
            densenet = models.densenet201(pretrained=True)
            # keep the feature extraction network up to transitionlayer2
            # (slice with [:-3] instead to keep it up to denseblock3)
            self.model = nn.Sequential(*list(densenet.features.children())[:-4])

        if not train_fe:
            # freeze parameters
            for param in self.model.parameters():
                param.requires_grad = False
        if use_cuda:
            self.model = self.model.cuda()

    def forward(self, image_batch):
        """Extract features, L2-normalised over dim 1 unless disabled or using the FPN body."""
        features = self.model(image_batch)
        if self.normalization and self.feature_extraction_cnn != 'resnet101fpn':
            features = featureL2Norm(features)
        return features
    
class FeatureCorrelation(torch.nn.Module):
    """Dense dot-product correlation between two feature maps.

    Args:
        shape: '3D' returns [b, hA*wA, hB, wB], indexed by
            idx_A = row_A + h*col_A (A and B must share the same h, w).
            '4D' returns [b, 1, hA, wA, hB, wB].
        normalization: apply ReLU, then L2-normalise over dim 1.
    """
    def __init__(self, shape='3D', normalization=True):
        super(FeatureCorrelation, self).__init__()
        self.normalization = normalization
        self.shape = shape
        self.ReLU = nn.ReLU()

    def forward(self, feature_A, feature_B):
        """Correlate ``feature_A`` and ``feature_B`` (both [b, c, h, w]).

        Raises:
            ValueError: if ``self.shape`` is neither '3D' nor '4D'.
        """
        if self.shape == '3D':
            b, c, h, w = feature_A.size()
            # reshape features for matrix multiplication (A flattened column-major)
            feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)
            feature_B = feature_B.view(b, c, h * w).transpose(1, 2)
            feature_mul = torch.bmm(feature_B, feature_A)
            # indexed [batch, idx_A=row_A+h*col_A, row_B, col_B]
            correlation_tensor = feature_mul.view(b, h, w, h * w).transpose(2, 3).transpose(1, 2)
        elif self.shape == '4D':
            b, c, hA, wA = feature_A.size()
            b, c, hB, wB = feature_B.size()
            feature_A = feature_A.view(b, c, hA * wA).transpose(1, 2)  # [b, hA*wA, c]
            feature_B = feature_B.view(b, c, hB * wB)  # [b, c, hB*wB]
            feature_mul = torch.bmm(feature_A, feature_B)
            # indexed [batch, 1, row_A, col_A, row_B, col_B]
            correlation_tensor = feature_mul.view(b, hA, wA, hB, wB).unsqueeze(1)
        else:
            raise ValueError("shape must be '3D' or '4D', got {!r}".format(self.shape))

        if self.normalization:
            correlation_tensor = featureL2Norm(self.ReLU(correlation_tensor))

        return correlation_tensor

class NeighConsensus(torch.nn.Module):
    """Stack of Conv4d + ReLU layers that filters a 4D correlation tensor.

    Layer ``idx`` maps ``channels[idx-1]`` channels (1 for the first layer)
    to ``channels[idx]`` channels with a ``kernel_sizes[idx]`` 4D kernel. In
    symmetric mode the stack runs on both the A-B and the B-A ordering of the
    tensor, and the two results are summed.
    """
    def __init__(self, use_cuda=True, kernel_sizes=[3,3,3], channels=[10,10,1], symmetric_mode=True):
        super(NeighConsensus, self).__init__()
        self.symmetric_mode = symmetric_mode
        self.kernel_sizes = kernel_sizes
        self.channels = channels
        layers = []
        for idx, k_size in enumerate(kernel_sizes):
            in_ch = channels[idx - 1] if idx > 0 else 1
            layers.append(Conv4d(in_channels=in_ch, out_channels=channels[idx], kernel_size=k_size, bias=True))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)
        if use_cuda:
            self.conv.cuda()

    def forward(self, x):
        if not self.symmetric_mode:
            return self.conv(x)
        # Swap the (row_A, col_A) and (row_B, col_B) dimension pairs. The
        # permutation is its own inverse, so it also maps the result back.
        swap_ab = (0, 1, 4, 5, 2, 3)
        direct = self.conv(x)
        # The ReLUs between layers make this differ from a single pass with
        # filters + filters^T, which is why the stack is applied twice.
        return direct + self.conv(x.permute(*swap_ab)).permute(*swap_ab)

def MutualMatching(corr4d):
    """Soft mutual nearest-neighbour filtering of a correlation tensor.

    Each score is scaled by two ratios:
    - its ratio to the best score over all A positions for the same B position;
    - its ratio to the best score over all B positions for the same A position.
    eps = 1e-5 is added to each max. The input is reshaped as if it had a
    single channel; the output has shape [b, 1, hA, wA, hB, wB].
    """
    eps = 1e-5
    n, _, ha, wa, hb, wb = corr4d.size()
    full_shape = (n, 1, ha, wa, hb, wb)

    # r_b: scores normalised by the best A candidate of every B position
    over_a = corr4d.view(n, ha * wa, hb, wb)  # [batch_idx, k_A, i_B, j_B]
    r_b = over_a / (torch.max(over_a, dim=1, keepdim=True)[0] + eps)

    # r_a: scores normalised by the best B candidate of every A position
    over_b = corr4d.view(n, ha, wa, hb * wb)
    r_a = over_b / (torch.max(over_b, dim=3, keepdim=True)[0] + eps)

    # the parentheses keep the output symmetric in A and B
    return corr4d * (r_a.view(full_shape) * r_b.view(full_shape))

def MutualNorm(corr4d):
    """Return the mutual-matching weights of ``corr4d`` without applying them.

    The weight of an entry is the product of two ratios:
    - its ratio to the best score over all B positions (same A position);
    - its ratio to the best score over all A positions (same B position).
    eps = 1e-5 is added to each max. The input is reshaped as if it had a
    single channel; the output has shape [b, 1, hA, wA, hB, wB].
    """
    eps = 1e-5
    n, _, ha, wa, hb, wb = corr4d.size()
    full_shape = (n, 1, ha, wa, hb, wb)

    over_a = corr4d.view(n, ha * wa, hb, wb)  # [batch_idx, k_A, i_B, j_B]
    r_b = over_a / (torch.max(over_a, dim=1, keepdim=True)[0] + eps)

    over_b = corr4d.view(n, ha, wa, hb * wb)
    r_a = over_b / (torch.max(over_b, dim=3, keepdim=True)[0] + eps)

    return r_a.view(full_shape) * r_b.view(full_shape)

def maxpool4d(corr4d_hres, k_size=4):
    """4D max-pooling (window = stride = ``k_size``) of a correlation tensor.

    Args:
        corr4d_hres: tensor [b, c, hA, wA, hB, wB]. Only channel 0 is used, and
            every spatial size must be divisible by ``k_size`` (otherwise the
            slices have unequal sizes and ``torch.cat`` fails).
        k_size: pooling window size along each of the four spatial dims.

    Returns:
        (corr4d, max_i, max_j, max_k, max_l). ``corr4d`` is the pooled tensor
        [b, 1, hA/k, wA/k, hB/k, wB/k]. The four offsets are integer (long)
        tensors of the same shape giving the *relative* position of the maximum
        inside each k^4 window.
    """
    slices = []
    for i in range(k_size):
        for j in range(k_size):
            for k in range(k_size):
                for l in range(k_size):
                    # unsqueeze(1) keeps the batch dimension first, so the max
                    # below is taken per sample and not across the whole batch.
                    sl = corr4d_hres[:, 0, i::k_size, j::k_size, k::k_size, l::k_size].unsqueeze(1)
                    slices.append(sl)

    # [b, k^4, ...]; slice index = ((i*k + j)*k + k')*k + l
    slices = torch.cat(slices, dim=1)
    corr4d, max_idx = torch.max(slices, dim=1, keepdim=True)
    # Decode the flat slice index with integer arithmetic: Tensor.div on
    # integer tensors performs true division and would yield float offsets.
    max_l = max_idx % k_size
    max_k = (max_idx // k_size) % k_size
    max_j = (max_idx // (k_size ** 2)) % k_size
    max_i = max_idx // (k_size ** 3)
    return (corr4d, max_i, max_j, max_k, max_l)

In [25]:
def corr_to_matches(corr4d, delta4d=None, ksize=1, do_softmax=True, scale='positive',
                    invert_matching_direction=False, return_indices=True):
    """Turn a 4D correlation tensor into one best match per position.

    By default, every position of image B gets its best-scoring position in
    image A. With ``invert_matching_direction=True``, every position of A gets
    its best position in B.

    Args:
        corr4d: tensor [b, c, hA, wA, hB, wB], reshaped as if c == 1.
        delta4d: optional (delta_iA, delta_jA, delta_iB, delta_jB) offsets as
            returned by ``maxpool4d``. The indices are then mapped, in place, to
            the ``ksize``-times finer grid.
        ksize: relocalisation factor between pooled and full resolution.
        do_softmax: apply softmax over the candidate dimension before the max.
        scale: 'positive' gives coordinates in [0, 1]; 'centered' gives [-1, 1].
        invert_matching_direction: see above.
        return_indices: return integer grid indices instead of coordinates.

    Returns:
        (jA, iA, jB, iB, score) if ``return_indices`` is True, else
        (xA, yA, xB, yB, score). Each tensor has shape [b, N], where N is the
        number of matched positions.

    Raises:
        ValueError: if ``scale`` is neither 'centered' nor 'positive'.
    """
    if scale == 'centered':
        grid_lo = -1
    elif scale == 'positive':
        grid_lo = 0
    else:
        raise ValueError("scale must be 'centered' or 'positive', got {!r}".format(scale))

    device = corr4d.device
    batch_size, ch, fs1, fs2, fs3, fs4 = corr4d.size()  # b, c, h, w, h, w

    # Integer index grids at the current (pooled) resolution, flattened to [1, h*w]
    JA, IA = np.meshgrid(range(fs2), range(fs1))
    JB, IB = np.meshgrid(range(fs4), range(fs3))
    JA = torch.as_tensor(JA, dtype=torch.long, device=device).view(1, -1)
    IA = torch.as_tensor(IA, dtype=torch.long, device=device).view(1, -1)
    JB = torch.as_tensor(JB, dtype=torch.long, device=device).view(1, -1)
    IB = torch.as_tensor(IB, dtype=torch.long, device=device).view(1, -1)

    if invert_matching_direction:
        # best B position for every A position
        nc_A_Bvec = corr4d.view(batch_size, fs1, fs2, fs3 * fs4)
        if do_softmax:
            nc_A_Bvec = torch.nn.functional.softmax(nc_A_Bvec, dim=3)

        match_A_vals, idx_A_Bvec = torch.max(nc_A_Bvec, dim=3)
        score = match_A_vals.view(batch_size, -1)

        iB = IB.view(-1)[idx_A_Bvec.view(-1)].view(batch_size, -1).contiguous()  # b, h1*w1
        jB = JB.view(-1)[idx_A_Bvec.view(-1)].view(batch_size, -1).contiguous()
        iA = IA.expand_as(iB).contiguous()
        jA = JA.expand_as(jB).contiguous()
    else:
        # best A position for every B position
        nc_B_Avec = corr4d.view(batch_size, fs1 * fs2, fs3, fs4)  # [batch_idx,k_A,i_B,j_B]
        if do_softmax:
            nc_B_Avec = torch.nn.functional.softmax(nc_B_Avec, dim=1)

        match_B_vals, idx_B_Avec = torch.max(nc_B_Avec, dim=1)
        score = match_B_vals.view(batch_size, -1)

        iA = IA.view(-1)[idx_B_Avec.view(-1)].view(batch_size, -1).contiguous()  # b, h2*w2
        jA = JA.view(-1)[idx_B_Avec.view(-1)].view(batch_size, -1).contiguous()
        iB = IB.expand_as(iA).contiguous()
        jB = JB.expand_as(jA).contiguous()

    if delta4d is not None:  # relocalization (the ksize > 1 case)
        # offsets inside the pooling window, relative to its (0,0,0,0) corner
        delta_iA, delta_jA, delta_iB, delta_jB = delta4d  # each b, 1, h1, w1, h2, w2
        for ibx in range(batch_size):
            # gather all four offsets before updating any index
            diA = delta_iA[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]]  # h*w
            djA = delta_jA[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]]
            diB = delta_iB[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]]
            djB = delta_jB[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]]

            # *ksize moves to the window corner on the upsampled grid
            iA[ibx] = iA[ibx] * ksize + diA
            jA[ibx] = jA[ibx] * ksize + djA
            iB[ibx] = iB[ibx] * ksize + diB
            jB[ibx] = jB[ibx] * ksize + djB

    if return_indices:
        return (jA, iA, jB, iB, score)

    # Normalised coordinates on the ksize-times upsampled grid
    XA, YA = np.meshgrid(np.linspace(grid_lo, 1, fs2 * ksize), np.linspace(grid_lo, 1, fs1 * ksize))
    XB, YB = np.meshgrid(np.linspace(grid_lo, 1, fs4 * ksize), np.linspace(grid_lo, 1, fs3 * ksize))
    XA = torch.as_tensor(XA, dtype=torch.float32, device=device)
    YA = torch.as_tensor(YA, dtype=torch.float32, device=device)
    XB = torch.as_tensor(XB, dtype=torch.float32, device=device)
    YB = torch.as_tensor(YB, dtype=torch.float32, device=device)

    xA = XA[iA.view(-1), jA.view(-1)].view(batch_size, -1)
    yA = YA[iA.view(-1), jA.view(-1)].view(batch_size, -1)
    xB = XB[iB.view(-1), jB.view(-1)].view(batch_size, -1)
    yB = YB[iB.view(-1), jB.view(-1)].view(batch_size, -1)
    return (xA, yA, xB, yB, score)
    
def corr_to_matches_topk(corr4d, delta4d=None, topk=1, ksize=1, do_softmax=True,
                         invert_matching_direction=False):
    """Return the ``topk`` best matches per position as integer grid indices.

    By default, every B position gets its ``topk`` best A positions. With
    ``invert_matching_direction``, every A position gets its ``topk`` best B
    positions. When ``delta4d`` (offsets as produced by ``maxpool4d``) is
    given, the indices are mapped in place to the ``ksize``-times finer grid.

    Returns:
        (jA, iA, jB, iB, score), each of shape [b, topk * N], where N is the
        number of positions being matched. Entries are ordered topk-major for
        the default direction and position-major for the inverted one.
    """
    dev = corr4d.device
    n_batch, _, h_a, w_a, h_b, w_b = corr4d.size()  # b, c, h, w, h, w

    def _flat_index_grid(n_rows, n_cols):
        # column / row index of every cell of an n_rows x n_cols grid, as [1, n_rows*n_cols]
        col_grid, row_grid = np.meshgrid(range(n_cols), range(n_rows))
        return (torch.LongTensor(col_grid).view(1, -1).to(dev),
                torch.LongTensor(row_grid).view(1, -1).to(dev))

    JA, IA = _flat_index_grid(h_a, w_a)
    JB, IB = _flat_index_grid(h_b, w_b)

    if invert_matching_direction:
        a_to_b = corr4d.view(n_batch, h_a, w_a, h_b * w_b)
        if do_softmax:
            a_to_b = torch.nn.functional.softmax(a_to_b, dim=3)
        best_vals, best_idx = torch.topk(a_to_b, topk, dim=3, largest=True, sorted=True)
        flat_idx = best_idx.view(-1)
        iB = IB.view(-1)[flat_idx].view(n_batch, -1, topk).contiguous()
        jB = JB.view(-1)[flat_idx].view(n_batch, -1, topk).contiguous()
        iA = IA.unsqueeze(-1).expand_as(iB).contiguous()
        jA = JA.unsqueeze(-1).expand_as(jB).contiguous()
    else:
        b_to_a = corr4d.view(n_batch, h_a * w_a, h_b, w_b)  # [batch_idx,k_A,i_B,j_B]
        if do_softmax:
            b_to_a = torch.nn.functional.softmax(b_to_a, dim=1)
        best_vals, best_idx = torch.topk(b_to_a, topk, dim=1, largest=True, sorted=True)
        flat_idx = best_idx.view(-1)
        iA = IA.view(-1)[flat_idx].view(n_batch, topk, -1).contiguous()
        jA = JA.view(-1)[flat_idx].view(n_batch, topk, -1).contiguous()
        iB = IB.unsqueeze(1).expand_as(iA).contiguous()
        jB = JB.unsqueeze(1).expand_as(jA).contiguous()
    score = best_vals.view(n_batch, -1)

    iA, jA, iB, jB = (t.view(n_batch, -1) for t in (iA, jA, iB, jB))

    if delta4d is not None:  # relocalization (the ksize > 1 case)
        # offsets inside the pooling window, relative to its (0,0,0,0) corner
        delta_iA, delta_jA, delta_iB, delta_jB = delta4d
        for n in range(n_batch):
            where = (iA[n], jA[n], iB[n], jB[n])
            # all four offsets are gathered before any index is updated
            d_iA, d_jA, d_iB, d_jB = (d[n][0][where] for d in (delta_iA, delta_jA, delta_iB, delta_jB))
            iA[n] = iA[n] * ksize + d_iA
            jA[n] = jA[n] * ksize + d_jA
            iB[n] = iB[n] * ksize + d_iB
            jB[n] = jB[n] * ksize + d_jB

    return (jA, iA, jB, iB, score)


In [26]:
def conv4d(data, filters, bias=None, permute_filters=True, use_half=False):
    """4D convolution (stride 1, zero padding that keeps the size) built from 3D convolutions.

    The first spatial dimension is handled by an explicit loop. Each output
    slice is the sum of ``F.conv3d`` over the kernel's taps along that dimension.

    Args:
        data: tensor [b, c, h, w, d, t].
        filters: [c_out, c, kh, kw, kd, kt], or already permuted to
            [kh, c_out, c, kw, kd, kt] when ``permute_filters`` is False.
            ``kh // 2`` is used as the padding for every dimension.
            NOTE(review): this presumably assumes odd, equal kernel sizes; confirm.
        bias: optional [c_out] bias, added once per output position.
        permute_filters: permute ``filters`` to the layout above first.
        use_half: allocate the output and padding buffers in float16.

    Returns:
        tensor [b, c_out, h, w, d, t].
    """
    b, c, h, w, d, t = data.size()

    # Make the looped dimension (h) the first one, so slicing it inside the
    # loop needs no contiguous() copy.
    data = data.permute(2, 0, 1, 3, 4, 5).contiguous()
    # Same permutation for the filters, unless they are already provided permuted.
    if permute_filters:
        filters = filters.permute(2, 0, 1, 3, 4, 5).contiguous()

    c_out = filters.size(1)
    padding = filters.size(0) // 2
    dtype = torch.half if use_half else data.dtype
    # Plain buffers on the input's device. They must not be created with
    # requires_grad: the slice assignments below would then be in-place ops on
    # a leaf that requires grad, which autograd rejects (on CPU, where no
    # .cuda() copy turns the buffer into a non-leaf). Autograd still tracks the
    # assigned values when they require grad.
    output = torch.zeros(h, b, c_out, w, d, t, dtype=dtype, device=data.device)
    Z = torch.zeros(padding, b, c, w, d, t, dtype=dtype, device=data.device)
    data_padded = torch.cat((Z, data, Z), 0)

    for i in range(h):  # loop on the first spatial dimension
        # center tap of the filter (position `padding`) carries the bias
        acc = F.conv3d(
            data_padded[i + padding],
            filters[padding],
            bias=bias,
            stride=1,
            padding=padding,
        )
        # lower / upper taps (positions [:padding] and [padding+1:])
        for p in range(1, padding + 1):
            acc = acc + F.conv3d(
                data_padded[i + padding - p],
                filters[padding - p],
                bias=None,
                stride=1,
                padding=padding,
            )
            acc = acc + F.conv3d(
                data_padded[i + padding + p],
                filters[padding + p],
                bias=None,
                stride=1,
                padding=padding,
            )
        output[i] = acc

    return output.permute(1, 2, 0, 3, 4, 5).contiguous()


class Conv4d(_ConvNd):
    """4D convolution layer (stride 1, zero padding preserving the spatial size).

    Parameters are registered through ``_ConvNd``; the computation itself is
    delegated to ``conv4d``. With ``pre_permuted_filters`` (the default), the
    weight is stored as [k0, out_channels, in_channels, k1, k2, k3]. The
    convolution loop can then slice its first dimension without copying.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        bias=True,
        pre_permuted_filters=True,
    ):
        # Only stride 1, dilation 1 and groups 1 are used (other values untested).
        # Padding is 0 here because conv4d pads internally to keep the size.
        kernel_size, stride, padding, dilation = (
            _quadruple(value) for value in (kernel_size, 1, 0, 1)
        )

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,          # transposed
            _quadruple(0),  # output_padding
            1,              # groups
            bias,
            padding_mode="zeros",
        )

        self.pre_permuted_filters = pre_permuted_filters
        if pre_permuted_filters:
            # Put the looped kernel dimension first, once, at construction time.
            self.weight.data = self.weight.data.permute(2, 0, 1, 3, 4, 5).contiguous()
        self.use_half = False

    def forward(self, input):
        # The weights only need permuting if that was not done in the constructor.
        needs_permute = not self.pre_permuted_filters
        return conv4d(
            input,
            self.weight,
            bias=self.bias,
            permute_filters=needs_permute,
            use_half=self.use_half,
        )


In [27]:
# -----------------------------
# ImMatchNet (fixed)
# -----------------------------
class ImMatchNet(nn.Module):
    """NCNet-style matcher: backbone features -> 4D correlation -> neighbourhood consensus.

    Args:
        feature_extraction_cnn, feature_extraction_last_layer,
        feature_extraction_model_file, normalize_features, train_fe:
            forwarded to ``FeatureExtraction``. A None model file is passed on
            as '' ("no file").
        return_correlation: stored on the instance (not used by the methods here).
        ncons_kernel_sizes, ncons_channels: ``NeighConsensus`` architecture.
            When a checkpoint is given, the values stored in it override these.
        use_cuda: build the sub-networks on the GPU.
        relocalization_k_size: if > 1, the correlation is 4D max-pooled with
            this window before filtering, and the forward methods also return
            the pooling offsets.
        half_precision: run the correlation / consensus stage in float16.
        checkpoint: optional path to a checkpoint containing 'state_dict' and
            'args' entries.
    """
    def __init__(self,
                 feature_extraction_cnn='resnet101',
                 feature_extraction_last_layer='',
                 feature_extraction_model_file=None,
                 return_correlation=False,
                 ncons_kernel_sizes=[3,3,3],
                 ncons_channels=[10,10,1],
                 normalize_features=True,
                 train_fe=False,
                 use_cuda=True,
                 relocalization_k_size=0,
                 half_precision=False,
                 checkpoint=None):

        super(ImMatchNet, self).__init__()
        ckpt = None
        if checkpoint:  # None and '' both mean "no checkpoint"
            print('Loading checkpoint...')
            ckpt = self._load_checkpoint(checkpoint)
            ckpt['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in ckpt['state_dict'].items()])
            # override relevant parameters
            print('Using checkpoint parameters:')
            ncons_channels = ckpt['args'].ncons_channels
            print('  ncons_channels: '+str(ncons_channels))
            ncons_kernel_sizes = ckpt['args'].ncons_kernel_sizes
            print('  ncons_kernel_sizes: '+str(ncons_kernel_sizes))

        self.use_cuda = use_cuda
        self.normalize_features = normalize_features
        self.return_correlation = return_correlation
        self.relocalization_k_size = relocalization_k_size
        self.half_precision = half_precision

        self.FeatureExtraction = FeatureExtraction(train_fe=train_fe,
                                                   feature_extraction_cnn=feature_extraction_cnn,
                                                   # FeatureExtraction uses '' for "no file"
                                                   feature_extraction_model_file=feature_extraction_model_file or '',
                                                   last_layer=feature_extraction_last_layer,
                                                   normalization=normalize_features,
                                                   use_cuda=self.use_cuda)

        self.FeatureCorrelation = FeatureCorrelation(shape='4D', normalization=False)

        self.NeighConsensus = NeighConsensus(use_cuda=self.use_cuda,
                                             kernel_sizes=ncons_kernel_sizes,
                                             channels=ncons_channels)

        if ckpt is not None:
            print('Copying weights...')
            saved = ckpt['state_dict']
            # state_dict() tensors share storage with the module, so copy_ writes
            # into the live parameters/buffers; build each state_dict only once.
            for name, tensor in self.FeatureExtraction.state_dict().items():
                if 'num_batches_tracked' not in name:
                    tensor.copy_(saved['FeatureExtraction.' + name])
            for name, tensor in self.NeighConsensus.state_dict().items():
                tensor.copy_(saved['NeighConsensus.' + name])
            print('Done!')

        self.FeatureExtraction.eval()

        if self.half_precision:
            for p in self.NeighConsensus.parameters():
                p.data = p.data.half()
            for l in self.NeighConsensus.conv:
                if isinstance(l, Conv4d):
                    l.use_half = True

    @staticmethod
    def _load_checkpoint(path):
        """Load the checkpoint at ``path`` onto the CPU.

        The checkpoint holds an ``args`` object next to the tensors (read via
        attribute access in ``__init__``). Since PyTorch 2.6, ``torch.load``
        defaults to ``weights_only=True``, which rejects such objects, so full
        unpickling is requested explicitly when the argument exists.
        SECURITY: this unpickles arbitrary objects; only load trusted checkpoints.
        """
        import inspect
        load_kwargs = {'map_location': lambda storage, loc: storage}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(path, **load_kwargs)

    def _match(self, feature_A, feature_B):
        """Shared pipeline for both forward methods.

        Runs correlation -> optional 4D max-pool -> mutual matching ->
        neighbourhood consensus -> mutual matching.

        Returns ``corr4d``, or ``(corr4d, (max_i, max_j, max_k, max_l))`` when
        ``relocalization_k_size > 1``.
        """
        if self.half_precision:
            feature_A = feature_A.half()
            feature_B = feature_B.half()

        corr4d = self.FeatureCorrelation(feature_A, feature_B)

        delta4d = None
        if self.relocalization_k_size > 1:
            corr4d, max_i, max_j, max_k, max_l = maxpool4d(corr4d, k_size=self.relocalization_k_size)
            delta4d = (max_i, max_j, max_k, max_l)

        corr4d = MutualMatching(corr4d)
        corr4d = self.NeighConsensus(corr4d)
        corr4d = MutualMatching(corr4d)

        if delta4d is not None:
            return (corr4d, delta4d)
        return corr4d

    def forward(self, tnf_batch):
        """Match ``tnf_batch['source_image']`` against ``tnf_batch['target_image']``."""
        feature_A = self.FeatureExtraction(tnf_batch['source_image'])
        feature_B = self.FeatureExtraction(tnf_batch['target_image'])
        return self._match(feature_A, feature_B)

    def forward_feat(self, featA, featB, normalize=True):
        """Match precomputed feature maps, L2-normalising them first when ``normalize``."""
        if normalize:
            featA = featureL2Norm(featA)
            featB = featureL2Norm(featB)
        return self._match(featA, featB)