Skip to content

Commit

Permalink
Add flexible bilinear upsampling aspect ratio redux (#1317)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgiessel authored and soumith committed May 3, 2017
1 parent e9953c4 commit 2e7635b
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 11 deletions.
12 changes: 12 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2906,6 +2906,18 @@ def add_test(test):
input_size=(1, 2, 4, 4),
desc='scale'
),
dict(
module_name='UpsamplingBilinear2d',
constructor_args=(None, (2, 2)),
input_size=(1, 2, 4, 4),
desc='scale_tuple_shared'
),
dict(
module_name='UpsamplingBilinear2d',
constructor_args=(None, (2, 1)),
input_size=(1, 2, 4, 4),
desc='scale_tuple_skewed'
),
dict(
module_name='AdaptiveMaxPool1d',
constructor_args=(3,),
Expand Down
33 changes: 26 additions & 7 deletions torch/nn/_functions/thnn/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from torch._thnn import type2backend

from . import _all_functions
from ...modules.utils import _pair
from ...functional import _check_bilinear_2d_scale_factor


class _UpsamplingBase(Function):
Expand All @@ -12,16 +14,22 @@ def __init__(self, size=None, scale_factor=None):
super(_UpsamplingBase, self).__init__()
if size is None and scale_factor is None:
raise ValueError('either size or scale_factor should be defined')
if scale_factor is not None and not isinstance(scale_factor, Integral):
raise ValueError('scale_factor must be of integer type')
if size is not None and not isinstance(size, tuple):
size = (size, size)
if scale_factor is not None and not isinstance(scale_factor, (Integral, tuple)):
raise ValueError('scale_factor must be of integer type or tuple of integer types')
self.size = size
self.scale_factor = scale_factor


class UpsamplingNearest2d(_UpsamplingBase):

def __init__(self, size=None, scale_factor=None):
super(UpsamplingNearest2d, self).__init__(size, scale_factor)

if self.scale_factor is not None and not isinstance(scale_factor, Integral):
raise ValueError('scale_factor must be of integer type for nearest neighbor sampling')

self.size = _pair(self.size) if self.size is not None else None

def forward(self, input):
assert input.dim() == 4

Expand Down Expand Up @@ -64,13 +72,21 @@ def backward(self, grad_output):

class UpsamplingBilinear2d(_UpsamplingBase):

def __init__(self, size=None, scale_factor=None):
super(UpsamplingBilinear2d, self).__init__(size, scale_factor)

if self.scale_factor is not None:
self.scale_factor = _check_bilinear_2d_scale_factor(self.scale_factor)

self.size = _pair(self.size) if self.size is not None else None

def forward(self, input):
assert input.dim() == 4

if self.scale_factor:
if self.scale_factor is not None:
self.output_size = (
input.size(2) * self.scale_factor,
input.size(3) * self.scale_factor,
input.size(2) * self.scale_factor[0],
input.size(3) * self.scale_factor[1],
)
else:
self.output_size = self.size
Expand Down Expand Up @@ -106,6 +122,9 @@ def backward(self, grad_output):
)
return grad_input

def __setstate__(self, state):
self.__dict__.update(state)
self.scale_factor = _tuple(self.scale_factor)

_all_functions.append(UpsamplingNearest2d)
_all_functions.append(UpsamplingBilinear2d)
16 changes: 15 additions & 1 deletion torch/nn/functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Functional interface"""

from numbers import Integral

import torch
from . import _functions
from .modules import utils
Expand Down Expand Up @@ -610,11 +612,23 @@ def upsample_bilinear(input, size=None, scale_factor=None):
Args:
input (Variable): input
size (int or Tuple[int, int]): output spatial size.
scale_factor (int): multiplier for spatial size. Has to be an integer.
scale_factor (int or Tuple[int, int]): multiplier for spatial size
"""
return _functions.thnn.UpsamplingBilinear2d(size, scale_factor)(input)


def _check_bilinear_2d_scale_factor(scale_factor):
scale_factor = _pair(scale_factor)
try:
assert len(scale_factor) == 2
assert all(isinstance(s, Integral) and s >= 1 for s in scale_factor)
except AssertionError as e:
raise ValueError('scale_factor must be a non-negative integer, '
'or a tuple of non-negative integers for bilinear upsamplings, but got: '
'{}'.format(scale_factor))
return scale_factor


def pad(input, pad, mode='constant', value=0):
"""Pads tensor.
Expand Down
19 changes: 16 additions & 3 deletions torch/nn/modules/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def __init__(self, size=None, scale_factor=None):
super(_UpsamplingBase, self).__init__()
if size is None and scale_factor is None:
raise ValueError('either size or scale_factor should be defined')
if scale_factor is not None and not isinstance(scale_factor, Integral):
raise ValueError('scale_factor must be of integer type')
self.size = _pair(size)
if scale_factor is not None and not isinstance(scale_factor, (Integral, tuple)):
raise ValueError('scale_factor must be of integer type or tuple of integer types')
self.size = size
self.scale_factor = scale_factor

def __repr__(self):
Expand Down Expand Up @@ -65,6 +65,12 @@ class UpsamplingNearest2d(_UpsamplingBase):
"""

def __init__(self, size=None, scale_factor=None):
super(UpsamplingNearest2d, self).__init__(size, scale_factor)
if self.scale_factor is not None and not isinstance(scale_factor, Integral):
raise ValueError('scale_factor must be of integer type for neighest neighbor sampling')
self.size = _pair(self.size) if self.size is not None else None

def forward(self, input):
return F.upsample_nearest(input, self.size, self.scale_factor)

Expand Down Expand Up @@ -110,5 +116,12 @@ class UpsamplingBilinear2d(_UpsamplingBase):
"""

def __init__(self, size=None, scale_factor=None):
super(UpsamplingBilinear2d, self).__init__(size, scale_factor)

if self.scale_factor is not None:
self.scale_factor = F._check_bilinear_2d_scale_factor(self.scale_factor)
self.size = _pair(self.size) if self.size is not None else None

def forward(self, input):
return F.upsample_bilinear(input, self.size, self.scale_factor)

0 comments on commit 2e7635b

Please sign in to comment.