Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch.nn.functional as F
import torch.nn.parallel as dp
from torch.nn.utils import clip_grad_norm
from torch.autograd import Variable
from torch.autograd import Variable, gradcheck
from torch.nn import Parameter
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
Expand Down Expand Up @@ -625,6 +625,17 @@ def test_Dropout3d(self):
input = torch.Tensor(num_features, b, d, w, h)
self._test_dropout(nn.Dropout3d, input)

def test_pad(self):
inputs = Variable(torch.randn(1, 3, 4, 4), requires_grad=True)
gradcheck(lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,))
gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,))
gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,))
gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate'), (inputs,))
gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='reflect'), (inputs,))

inputs = Variable(torch.randn(1, 2, 3, 4, 4), requires_grad=True)
gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,))

This comment was marked as off-topic.


def _test_maxpool_indices(self, num_dim, type=torch.FloatTensor):
def expected_indices(dim):
if dim == 1:
Expand Down
72 changes: 72 additions & 0 deletions torch/nn/_functions/padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from torch.autograd import Function


class ConstantPad2d(Function):

def __init__(self, pad, value=0):
super(ConstantPad2d, self).__init__()
self.pad = pad
self.value = value

def forward(self, input):
assert input.dim() == 4, 'only 4D supported for padding'
pad_l, pad_r, pad_t, pad_b = self.pad
h = input.size(2) + pad_t + pad_b
w = input.size(3) + pad_l + pad_r
assert w > 0 and h > 0, 'input is too small'

self.input_size = input.size()

# crop input if necessary
output = input.new(input.size(0), input.size(1), h, w).fill_(self.value)
c_input = input
if pad_t < 0:
c_input = c_input.narrow(2, -pad_t, c_input.size(2) + pad_t)
if pad_b < 0:
c_input = c_input.narrow(2, 0, c_input.size(2) + pad_b)
if pad_l < 0:
c_input = c_input.narrow(3, -pad_l, c_input.size(3) + pad_l)
if pad_r < 0:
c_input = c_input.narrow(3, 0, c_input.size(3) + pad_r)

# crop output if necessary
c_output = output
if pad_t > 0:
c_output = c_output.narrow(2, pad_t, c_output.size(2) - pad_t)
if pad_b > 0:
c_output = c_output.narrow(2, 0, c_output.size(2) - pad_b)
if pad_l > 0:
c_output = c_output.narrow(3, pad_l, c_output.size(3) - pad_l)
if pad_r > 0:
c_output = c_output.narrow(3, 0, c_output.size(3) - pad_r)
c_output.copy_(c_input)
return output

def backward(self, grad_output):
pad_l, pad_r, pad_t, pad_b = self.pad

grad_input = grad_output.new(self.input_size).zero_()

# crop grad_input if necessary
cg_input = grad_input
if pad_t < 0:
cg_input = cg_input.narrow(2, -pad_t, cg_input.size(2) + pad_t)
if pad_b < 0:
cg_input = cg_input.narrow(2, 0, cg_input.size(2) + pad_b)
if pad_l < 0:
cg_input = cg_input.narrow(3, -pad_l, cg_input.size(3) + pad_l)
if pad_r < 0:
cg_input = cg_input.narrow(3, 0, cg_input.size(3) + pad_r)

# crop grad_output if necessary
cg_output = grad_output
if pad_t > 0:
cg_output = cg_output.narrow(2, pad_t, cg_output.size(2) - pad_t)
if pad_b > 0:
cg_output = cg_output.narrow(2, 0, cg_output.size(2) - pad_b)
if pad_l > 0:
cg_output = cg_output.narrow(3, pad_l, cg_output.size(3) - pad_l)
if pad_r > 0:
cg_output = cg_output.narrow(3, 0, cg_output.size(3) - pad_r)
cg_input.copy_(cg_output)
return grad_input
34 changes: 34 additions & 0 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import _functions
from .modules import utils
from torch.nn._functions.conv import ConvNd
from ._functions.padding import ConstantPad2d
from .modules.utils import _single, _pair, _triple
# Convolutions

Expand Down Expand Up @@ -531,3 +532,36 @@ def upsample_bilinear(input, size=None, scale_factor=None):
scale_factor (int): multiplier for spatial size. Has to be an integer.
"""
return _functions.thnn.UpsamplingBilinear2d(size, scale_factor)(input)


def pad(input, pad, mode='constant', value=0):
"""Pads tensor.

Currently only 2D and 3D padding supported.
In case of 4D input tensor pad should be in form (pad_l, pad_r, pad_t, pad_b )
In case of 5D pad should be (pleft, pright, ptop, pbottom, pfront, pback)

Args
input (Variable): 4D or 5D tensor
pad (tuple): 4-elem or 6-elem tuple
mode: 'constant', 'reflect' or 'replicate'
value: fill value for 'constant' padding
"""
if input.dim() == 4:
assert len(pad) == 4, '4D tensors expect 4 values for padding'
if mode == 'constant':
return ConstantPad2d(pad, value)(input)
elif mode == 'reflect':
return _functions.thnn.ReflectionPad2d(*pad)(input)
elif mode == 'replicate':
return _functions.thnn.ReplicationPad2d(*pad)(input)
elif input.dim() == 5:
assert len(pad) == 6, '5D tensors expect 6 values for padding'
if mode == 'constant':
raise NotImplementedError
elif mode == 'reflect':
raise NotImplementedError
elif mode == 'replicate':
return _functions.thnn.ReplicationPad3d(*pad)(input)
else:
raise NotImplementedError("Only 4D and 5D padding is supported for now")