🚀 Feature
Enable PyTorch to compute convolutional features given:
- A collection of batches of images of shape (b_i, b_j, c, h, w)
- A batch of kernel weights of shape (b_i, out_features, in_features, kernel_height, kernel_width)
Each batch in the collection is convolved only with its own parameters. This can currently be achieved with existing PyTorch operations (see the sketch below), but it is rather slow compared to what a dedicated CUDA implementation could do.
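For reference, here is a minimal sketch of that sequential approach (the function name and arguments are purely illustrative, not an existing API):

import torch
import torch.nn.functional as F

def batch_conv2d_sequential(images, weights, padding=0):
    # images:  (b_i, b_j, c, h, w)   -- a collection of b_i batches of b_j images
    # weights: (b_i, out_features, in_features, kernel_height, kernel_width)
    # Each batch in the collection is convolved only with its own kernel set,
    # one F.conv2d call per batch, which is the slow part this issue wants to avoid.
    outputs = []
    for i in range(images.shape[0]):
        outputs.append(F.conv2d(images[i], weights[i], padding=padding))
    return torch.stack(outputs, dim=0)  # (b_i, b_j, out_features, h', w')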
Motivation
Meta-learning is slowly becoming a very popular framework for tackling problems. Almost every instance of meta-learning involves an inner-loop optimization process in which a network is updated multiple times on a given task. Since there is a batch of tasks, one currently has to process each task sequentially: the weights of the network change within the inner-loop optimization, so parallelizing across tasks would require a convolution operator that can take a batch of parameter kernels and convolve each with its corresponding batch of task data.
Furthermore, genetic algorithms and RL often require similar operations: each individual in a population changes its inner state with each update, and would therefore benefit from the same mechanism.
Finally, in cases where Hyper Networks are used, if one chooses to generate a sample-conditional weight matrix for a particular operation, then a batch convolution operator would be necessary.
I think it is in PyTorch's best interest to have such a feature available, to get ahead of the curve.
Pitch
I have already implemented a variant of what I am proposing. With a large collection size, my method outperforms sequential convolutions by a factor of 4x. However, when the collection size is close to 1 (effectively a normal convolutional layer), my method is 20x slower than the default CUDA-based convolution implementation. I need help from someone who can either write such a CUDA-based implementation, or point me in the right direction so I can do it myself.
My existing implementation is attached below.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchConvLayer(nn.Module):
    def __init__(self, input_shape, kernel_size, stride, padding, dilation, groups):
        super(BatchConvLayer, self).__init__()
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups  # note: stored but not used in this prototype
        self.build_block()

    def image_to_patch_coordinates(self, input_image, patch_size, stride, dilation):
        # Precompute the flattened pixel indices of every (dilated) patch,
        # analogous to the index computation behind im2col/unfold.
        patches = []
        for i in range(0, input_image.shape[2] - (patch_size[0] - 1), stride):
            for j in range(0, input_image.shape[3] - (patch_size[1] - 1), stride):
                x_range = patch_size[0] + (patch_size[0] - 1) * (dilation - 1)
                y_range = patch_size[1] + (patch_size[1] - 1) * (dilation - 1)
                if i + x_range <= input_image.shape[2] and j + y_range <= input_image.shape[3]:
                    coord = torch.Tensor([x * input_image.shape[2] + y
                                          for x in range(i, i + x_range, dilation)
                                          for y in range(j, j + y_range, dilation)])
                    patches.append(coord)
        indexes_of_patches = torch.stack(patches, dim=0).long()
        return indexes_of_patches

    def image_to_vectors(self, input_image, indexes_of_patches):
        # Gather the pixels of every patch so that each patch becomes a flat vector.
        p, num_pixels = indexes_of_patches.shape
        indexes_of_patches = indexes_of_patches.view(p * num_pixels)
        out = input_image.view(input_image.shape[0], input_image.shape[1], -1)
        out = torch.index_select(out, dim=2,
                                 index=indexes_of_patches.to(input_image.device))
        out = out.view(out.shape[0], out.shape[1], p, num_pixels)
        return out

    def build_block(self):
        # Run a dummy pass on zeros of the expected input shape to precompute
        # the patch indices once, so forward() only has to gather and multiply.
        x = torch.zeros(self.input_shape)
        if type(self.kernel_size) is int:
            self.kernel_size = (self.kernel_size, self.kernel_size)
        out = x.view(-1, x.shape[2], x.shape[3], x.shape[4])
        if self.padding > 0:
            out = F.pad(out, pad=[self.padding, self.padding, self.padding, self.padding],
                        mode='constant', value=0)
        self.indexes_of_patches = self.image_to_patch_coordinates(
            input_image=out, patch_size=self.kernel_size,
            stride=self.stride, dilation=self.dilation).to(x.device)
        x_vectors = self.image_to_vectors(out, indexes_of_patches=self.indexes_of_patches)
        self.spatial_shape = None
        print('block built', x_vectors.shape)

    def forward(self, x, weights, biases=None):
        # x:       (b_i, b_j, c, h, w)   collection of batches of images
        # weights: (b_i, out_features, in_features, kernel_height, kernel_width)
        # biases:  (b_i, out_features) or None
        # assert x.shape[0] == weights.shape[0], \
        #     "The batch size of the input images needs to be equal to the batch size of the params"
        # if biases is not None:
        #     assert weights.shape[0] == biases.shape[0], \
        #         "The batch size of the weight parameters needs to be equal to the batch size of the bias parameters"
        weights = weights.permute([0, 2, 3, 4, 1]).view(weights.shape[0], -1, weights.shape[1])
        b_out, b_in = x.shape[:2]
        out = x.view(-1, x.shape[2], x.shape[3], x.shape[4])
        if self.padding > 0:
            out = F.pad(out, pad=[self.padding, self.padding, self.padding, self.padding],
                        mode='constant', value=0)
        out = self.image_to_vectors(out, indexes_of_patches=self.indexes_of_patches)
        out = torch.cat(out.unbind(1), 2)
        out = out.view(b_out, b_in * out.shape[1], out.shape[2])
        # Batched matrix multiply: every batch in the collection uses its own kernels.
        out = torch.bmm(out, weights)
        if biases is not None:
            out = out + biases.unsqueeze(1)
        out = out.view(b_out, b_in, -1, out.shape[-1])
        if self.spatial_shape is None:
            self.spatial_shape = int(np.floor(np.sqrt(out.shape[2])))
        out = out.view(b_out, b_in, self.spatial_shape, self.spatial_shape, out.shape[-1])
        out = out.permute([0, 1, 4, 2, 3])
        return out
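For clarity, a hypothetical usage sketch (all shapes are chosen purely for illustration):

collection_size, batch_size, in_channels, out_channels, k = 8, 4, 3, 16, 3
x = torch.randn(collection_size, batch_size, in_channels, 28, 28)
weights = torch.randn(collection_size, out_channels, in_channels, k, k)
biases = torch.randn(collection_size, out_channels)

layer = BatchConvLayer(input_shape=x.shape, kernel_size=k, stride=1,
                       padding=1, dilation=1, groups=1)
out = layer(x, weights, biases)
print(out.shape)  # expected: (8, 4, 16, 28, 28), each batch convolved with its own kernels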