Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Exposed recompute_scale_factor into nn.Upsample #66419

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 15 additions & 12 deletions test/test_nn.py
Expand Up @@ -10505,19 +10505,22 @@ def test_upsamplingNearest1d(self):

def test_upsamplingLinear1d(self):
for align_corners in [True, False]:
kwargs = dict(mode='linear', align_corners=align_corners)

# test float scale factor up & downsampling
for scale_factor in [0.5, 1.5, 2]:
m = nn.Upsample(scale_factor=scale_factor, **kwargs)
in_t = torch.ones(1, 1, 2)
out_size = int(math.floor(in_t.shape[-1] * scale_factor))
with warnings.catch_warnings(record=True) as w:
out_t = m(in_t)
self.assertEqual(torch.ones(1, 1, out_size), out_t.data)
for recompute_scale_factor in [True, False]:
kwargs = dict(
mode='linear', align_corners=align_corners, recompute_scale_factor=recompute_scale_factor
)
# test float scale factor up & downsampling
for scale_factor in [0.5, 1.5, 2]:
m = nn.Upsample(scale_factor=scale_factor, **kwargs)
in_t = torch.ones(1, 1, 2)
out_size = int(math.floor(in_t.shape[-1] * scale_factor))
with warnings.catch_warnings(record=True) as w:
out_t = m(in_t)
self.assertEqual(torch.ones(1, 1, out_size), out_t.data)

input = torch.randn(1, 1, 2, requires_grad=True)
gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), (input,))
input = torch.randn(1, 1, 2, requires_grad=True)
if not recompute_scale_factor:
gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), (input,))
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

def test_upsamplingLinear1d_spatial_invariance(self):
m = nn.Upsample(scale_factor=3, mode='linear', align_corners=False)
Expand Down
16 changes: 14 additions & 2 deletions torch/nn/modules/upsampling.py
Expand Up @@ -32,6 +32,14 @@ class Upsample(Module):
and output tensors are aligned, and thus preserving the values at
those pixels. This only has effect when :attr:`mode` is
``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False``
recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
interpolation calculation. If `recompute_scale_factor` is ``True``, then
`scale_factor` must be passed in and `scale_factor` is used to compute the
output `size`. The computed output `size` will be used to infer new scales for
the interpolation. Note that when `scale_factor` is floating-point, it may differ
from the recomputed `scale_factor` due to rounding and precision issues.
If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
be used directly for interpolation.

Shape:
- Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
Expand Down Expand Up @@ -124,9 +132,11 @@ class Upsample(Module):
scale_factor: Optional[_ratio_any_t]
mode: str
align_corners: Optional[bool]
recompute_scale_factor: Optional[bool]
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None,
mode: str = 'nearest', align_corners: Optional[bool] = None) -> None:
mode: str = 'nearest', align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None) -> None:
super(Upsample, self).__init__()
self.name = type(self).__name__
self.size = size
Expand All @@ -136,9 +146,11 @@ def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_r
self.scale_factor = float(scale_factor) if scale_factor else None
self.mode = mode
self.align_corners = align_corners
self.recompute_scale_factor = recompute_scale_factor

def forward(self, input: Tensor) -> Tensor:
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners,
recompute_scale_factor=self.recompute_scale_factor)

def extra_repr(self) -> str:
if self.scale_factor is not None:
Expand Down