Deformable convolution

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np
from tqdm import trange
import torchvision

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(mnist_train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test_dataset, batch_size=64, shuffle=False)

print(mnist_train_dataset.data.shape)

torch.Size([60000, 28, 28])


In [3]:
class reg_conv_model(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, dilation=1, bias=True):
        super(reg_conv_model, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = padding , dilation = dilation, bias = bias)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(196, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [4]:
# training loop
model = reg_conv_model(1, 1, 3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'epoch {epoch}, loss {loss.item()}')

epoch 0, loss 0.16892126202583313
epoch 1, loss 0.10327652841806412
epoch 2, loss 0.15031404793262482
epoch 3, loss 0.49475234746932983
epoch 4, loss 0.48565202951431274


In [6]:
# test
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())
    print(f'accuracy {correct / total}')

accuracy 0.9217


In [3]:
class expand_layer(nn.Module):
    def __init__(self, p0, pos, conv_kernel_size):
        super(expand_layer, self).__init__()
        self.P0x, self.P0y = p0
        self.i, self.j = pos

        self.Pnx = self.i - (conv_kernel_size - 1)/2
        self.Pny = (conv_kernel_size - 1)/2 - self.j

    def forward(self, x, offset):
        offset = offset[self.i, self.j]
        
        Px = self.P0x + self.Pnx + offset[self.i, self.j, 0]
        Py = self.P0y + self.Pny + offset[self.i, self.j, 1]

        Px0 = int(Px)
        Py0 = int(Py)
        Px1 = Px0 + 1
        Py1 = Py0 + 1

        xp = 0
        if Px0 >= 0 and Py0 >= 0 and Px0 < x.shape[1] and Py0 < x.shape[2]:
            xp += (1-torch.abs(Px-Px0))*(1-torch.abs(Py-Py0))*x[:, Px0, Py0]
        if Px0 >= 0 and Py1 >= 0 and Px0 < x.shape[1] and Py1 < x.shape[2]:
            xp += (1-torch.abs(Px-Px0))*(1-torch.abs(Py-Py1))*x[:, Px0, Py1]
        if Px1 >= 0 and Py0 >= 0 and Px1 < x.shape[1] and Py0 < x.shape[2]:
            xp += (1-torch.abs(Px-Px1))*(1-torch.abs(Py-Py0))*x[:, Px1, Py0]
        if Px1 >= 0 and Py1 >= 0 and Px1 < x.shape[1] and Py1 < x.shape[2]:
            xp += (1-torch.abs(Px-Px1))*(1-torch.abs(Py-Py1))*x[:, Px1, Py1]
    
        return xp


class def_conv_layer(nn.Module):
    def __init__(self, in_channels, out_channels, spatial_size, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super(def_conv_layer, self).__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding

        self.offset_conv = nn.Conv2d(in_channels = in_channels, out_channels = 2 * kernel_size * kernel_size, kernel_size = 3, stride = 1, padding = 1, dilation = 1, bias = bias)

        self.expand_layer_list = nn.ModuleList()
        for px in range(spatial_size[0]):
            for i in range(kernel_size):
                for py in range(spatial_size[1]):
                    for j in range(kernel_size):
                        self.expand_layer_list.append(expand_layer((px, py), (i, j), kernel_size))

        self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = padding , dilation = dilation, bias = bias)

    def forward(self, x):
        offset = self.offset_conv(x)
        offset = torch.reshape(offset, (x.shape[0], x.shape[2], x.shape[3], self.kernel_size, self.kernel_size, 2))
        expnaded_x = torch.zeros((x.shape[0], self.in_channels, x.shape[2]*self.kernel_size, x.shape[3]*self.kernel_size))
        for batch in range(x.shape[0]):
            for i in range(x.shape[2]*self.kernel_size):
                for j in range(x.shape[3]*self.kernel_size):
                    expnaded_x[batch, :, i, j] = self.expand_layer_list[i*x.shape[3]*self.kernel_size+j](x[batch,:,:,:], offset[batch,:,:,:,:])
        
        return self.conv(expnaded_x)

class def_conv_model(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super(def_conv_model, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.conv = def_conv_layer(in_channels, 1, (28, 28), kernel_size, stride, padding, dilation, bias)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(196, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
        

In [4]:
# training loop
model = def_conv_model(1, 1, 3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in trange(5):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'epoch {epoch}, loss {loss.item()}')

  0%|          | 0/10 [09:30<?, ?it/s]


TypeError: conv2d() received an invalid combination of arguments - got (Tensor, Parameter, Parameter, tuple, tuple, tuple, bool), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !bool!)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !bool!)


In [None]:
# test
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())
    print(f'accuracy {correct / total}')