
Batch Convolutional Layers - Similar to torch.bmm but for convolutional operators #17983

@AntreasAntoniou

Description

🚀 Feature

Enable PyTorch to compute convolutional features given:

  • A collection of batches of images of shape (b_i, b_j, c, h, w)
  • A batch of kernel weights of shape (b_i, out_features, in_features, kernel_height, kernel_width)

Each batch in the collection is convolved only with its respective parameters. This can currently be achieved with existing PyTorch operations (see the sketch below), but it is rather slow compared to what a dedicated CUDA implementation could do.
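For reference, one way to express this with existing operations is to fold the collection dimension into the channel dimension and use grouped convolution. This is a minimal sketch, not part of the proposal; the function name batch_conv2d and all shapes are illustrative:

import torch
import torch.nn.functional as F

def batch_conv2d(x, weights, stride=1, padding=0, dilation=1):
    # x: (b_i, b_j, c_in, h, w) -- b_i batches of b_j images each.
    # weights: (b_i, c_out, c_in, kh, kw) -- one kernel set per batch.
    b_i, b_j, c_in, h, w = x.shape
    c_out = weights.shape[1]
    # Fold the collection dimension into channels and set groups=b_i,
    # so batch k's images are only convolved with batch k's kernels.
    x = x.transpose(0, 1).reshape(b_j, b_i * c_in, h, w)
    w = weights.reshape(b_i * c_out, c_in, *weights.shape[3:])
    out = F.conv2d(x, w, stride=stride, padding=padding,
                   dilation=dilation, groups=b_i)
    # (b_j, b_i * c_out, h', w') -> (b_i, b_j, c_out, h', w')
    return out.view(b_j, b_i, c_out, *out.shape[2:]).transpose(0, 1)

Whether this is faster than a sequential loop depends on how well the backend handles convolutions with many groups, which is part of why a dedicated kernel seems worthwhile.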

Motivation

Meta-learning is steadily becoming a very popular framework for tackling problems. In almost every instance of meta-learning there is an inner-loop optimization process in which a network is updated multiple times given a task. Since there is a batch of tasks, one currently has to process each task sequentially: the weights of the network change within the inner-loop optimization, so running the tasks in parallel would require a convolution operator that can receive a batch of parameter kernels to be convolved with a batch of tasks.

Furthermore, genetic algorithms and RL often require similar operations, where each individual in a population changes its inner state with each update, hence requiring a similar mechanism.

Finally, in cases where HyperNetworks are used, if one chooses to have a sample-conditional weight matrix for a particular operation, then a batch convolutional operator becomes necessary.

I think it's in PyTorch's best interests to have such a feature available, to get ahead of the curve.

Pitch

I have already implemented a variant of what I am proposing. With a large collection size, my method outperforms sequential convolutions by a factor of 4x. However, when the collection size is close to 1 (effectively a normal convolutional layer), my method is 20x slower than the default CUDA-based convolutional implementation. I need help from someone who can either write such a CUDA-based implementation or point me in the right direction so I can do it myself.

My existing implementation is attached below:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchConvLayer(nn.Module):
    def __init__(self, input_shape, kernel_size, stride, padding, dilation, groups):
        # input_shape: (b_i, b_j, c, h, w) -- a collection of b_i batches,
        # each holding b_j images with c channels and spatial size h x w.
        super(BatchConvLayer, self).__init__()
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.build_block()

    def image_to_patch_coordinates(self, input_image, patch_size, stride, dilation):
        # Precompute, for every output location, the flattened pixel indexes
        # of the (possibly dilated) receptive field it reads from.
        patches = []

        for i in range(0, input_image.shape[2] - (patch_size[0] - 1), stride):
            for j in range(0, input_image.shape[3] - (patch_size[1] - 1), stride):
                x_range = patch_size[0] + ((patch_size[0] - 1) * (dilation - 1))
                y_range = patch_size[1] + ((patch_size[1] - 1) * (dilation - 1))
                if i + x_range <= input_image.shape[2] and j + y_range <= input_image.shape[3]:
                    # Row-major flattened index of pixel (x, y) is x * width + y.
                    coord = torch.Tensor([x * input_image.shape[3] + y
                                          for x in range(i, i + x_range, dilation)
                                          for y in range(j, j + y_range, dilation)])
                    patches.append(coord)

        indexes_of_patches = torch.stack(patches, dim=0).long()
        return indexes_of_patches

    def image_to_vectors(self, input_image, indexes_of_patches):
        # Gather all patch pixels in a single indexing pass, yielding a tensor
        # of shape (batch, channels, num_patches, pixels_per_patch).
        p, num_pixels = indexes_of_patches.shape
        indexes_of_patches = indexes_of_patches.view(p * num_pixels)

        out = input_image.view(input_image.shape[0], input_image.shape[1], -1)
        out = torch.index_select(out, dim=2,
                                 index=indexes_of_patches.to(input_image.device))
        out = out.view(out.shape[0], out.shape[1], p, num_pixels)

        return out

    def build_block(self):
        # Run one dummy pass at construction time so the patch indexes can be
        # cached and reused on every forward() call.
        x = torch.zeros(self.input_shape)

        if type(self.kernel_size) is int:
            self.kernel_size = (self.kernel_size, self.kernel_size)

        # Collapse the collection and batch dimensions into one.
        out = x.view(-1, x.shape[2], x.shape[3], x.shape[4])

        if self.padding > 0:
            out = F.pad(out, pad=[self.padding, self.padding, self.padding, self.padding],
                        mode='constant', value=0)

        self.indexes_of_patches = self.image_to_patch_coordinates(
            input_image=out, patch_size=self.kernel_size,
            stride=self.stride, dilation=self.dilation).to(x.device)

        x_vectors = self.image_to_vectors(out, indexes_of_patches=self.indexes_of_patches)
        self.spatial_shape = None
        print('block built', x_vectors.shape)

    def forward(self, x, weights, biases=None):
        # x: (b_out, b_in, c, h, w); weights: (b_out, out_features, in_features, kh, kw)
        # assert x.shape[0] == weights.shape[0], \
        #     "The batch size of the input images needs to be equal to the batch size of the params"
        # if biases is not None:
        #     assert weights.shape[0] == biases.shape[0], \
        #         "The batch size of the weight parameters needs to be equal to the batch size of the bias parameters"

        # Flatten each kernel set into (b_out, in_features * kh * kw, out_features).
        weights = weights.permute([0, 2, 3, 4, 1]).view(weights.shape[0], -1, weights.shape[1])

        b_out, b_in = x.shape[:2]
        out = x.view(-1, x.shape[2], x.shape[3], x.shape[4])

        if self.padding > 0:
            out = F.pad(out, pad=[self.padding, self.padding, self.padding, self.padding],
                        mode='constant', value=0)

        # Extract patches with the cached indexes: (b_out * b_in, c, p, kh * kw).
        out = self.image_to_vectors(out, indexes_of_patches=self.indexes_of_patches)

        # Concatenate channels so each row holds one full receptive field.
        out = torch.cat(out.unbind(1), 2)
        out = out.view(b_out, b_in * out.shape[1], out.shape[2])

        # A single batched matmul applies each collection's kernels to its own patches.
        out = torch.bmm(out, weights)

        if biases is not None:
            out = out + biases.unsqueeze(1)

        out = out.view(b_out, b_in, -1, out.shape[-1])

        if self.spatial_shape is None:
            self.spatial_shape = int(np.floor(np.sqrt(out.shape[2])))

        out = out.view(b_out, b_in, self.spatial_shape, self.spatial_shape, out.shape[-1])
        out = out.permute([0, 1, 4, 2, 3])

        return out
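
For completeness, here is a hypothetical usage of the layer above. All shapes are illustrative: 8 collections of 16 RGB 32x32 images, with per-collection 3x3 kernels producing 64 output channels.

b_i, b_j, c, h, w = 8, 16, 3, 32, 32
layer = BatchConvLayer(input_shape=(b_i, b_j, c, h, w), kernel_size=3,
                       stride=1, padding=1, dilation=1, groups=1)
x = torch.randn(b_i, b_j, c, h, w)
weights = torch.randn(b_i, 64, c, 3, 3)  # (b_i, out_features, in_features, kh, kw)
out = layer(x, weights)
print(out.shape)  # expected: torch.Size([8, 16, 64, 32, 32])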
