
UpSample-nearest cuda kernel update #21694

Closed
wants to merge 5 commits

Conversation

jjsjann123 (Collaborator)

updating upsampling kernel:

  1. avoids atomicAdd for better fp16 performance (a gather-style sketch of this idea follows below).
  2. better launch configuration for 2D input.
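For illustration only: a minimal, self-contained CUDA sketch of the gather-style backward idea behind item 1, not the PR's actual kernel. It assumes a contiguous NCHW layout, an integer scale factor, and float data (the real kernels are templated over dtype, including fp16); the kernel and variable names are made up for this example. Each thread owns one grad_input element and sums the scale x scale block of grad_output that was copied from it, so every input-gradient element is written exactly once and no atomicAdd is needed.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void upsample_nearest2d_backward_gather(
    const float* __restrict__ grad_out,  // [N, C, H*scale, W*scale]
    float* __restrict__ grad_in,         // [N, C, H, W]
    int nc, int h_in, int w_in, int scale) {
  const int w_out = w_in * scale;
  const int total = nc * h_in * w_in;
  // Grid-stride loop over grad_input elements; one uncontended write each.
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total;
       idx += gridDim.x * blockDim.x) {
    const int w = idx % w_in;
    const int h = (idx / w_in) % h_in;
    const int plane = idx / (w_in * h_in);  // flattened N*C index
    float acc = 0.f;
    // Gather the scale x scale block of grad_output produced from this pixel.
    for (int dy = 0; dy < scale; ++dy) {
      for (int dx = 0; dx < scale; ++dx) {
        const int oh = h * scale + dy;
        const int ow = w * scale + dx;
        acc += grad_out[(plane * h_in * scale + oh) * w_out + ow];
      }
    }
    grad_in[idx] = acc;  // no atomics: this thread is the only writer
  }
}

int main() {
  // Tiny smoke test: one N*C plane, 2x2 input upsampled by 2 to 4x4.
  const int nc = 1, h = 2, w = 2, scale = 2;
  const int in_elems = nc * h * w, out_elems = in_elems * scale * scale;
  float *g_out, *g_in;
  cudaMallocManaged(&g_out, out_elems * sizeof(float));
  cudaMallocManaged(&g_in, in_elems * sizeof(float));
  for (int i = 0; i < out_elems; ++i) g_out[i] = 1.f;
  upsample_nearest2d_backward_gather<<<1, 128>>>(g_out, g_in, nc, h, w, scale);
  cudaDeviceSynchronize();
  for (int i = 0; i < in_elems; ++i) printf("%.0f ", g_in[i]);  // expect: 4 4 4 4
  printf("\n");
  cudaFree(g_out);
  cudaFree(g_in);
  return 0;
}

The scatter alternative (each grad_output thread atomicAdd-ing into grad_in) needs atomics because many output elements map to the same input element, and half-precision atomicAdd is particularly slow on many GPUs, which is what item 1 avoids.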

@pytorchbot added the module: cuda and module: operators labels Jun 12, 2019
@jjsjann123 (Collaborator, Author)

Perf numbers/scripts will be posted shortly.
I'll have another PR for bilinear upsampling coming soon as well.

cc @ngimel

@jjsjann123 (Collaborator, Author)

[image: benchmark results]

The fp16 forward perf numbers have been observed to be all over the place, especially for tiny inputs. :/
The thing to notice here is the speedup on the fp16 backward path.

Here's the script for the benchmark

import time
from itertools import product

import torch

nrep = 300              # timed iterations per configuration
sample_mode = "nearest"


def bench(size, fn, factor, half=False, b=2, c=32, dim=2):
    # Time the forward pass of fn (interpolate) on a contiguous ones tensor.
    x = torch.ones([b * c * (size ** dim)], device='cuda', dtype=torch.float)
    if half:
        x = x.half()
    if dim == 1:
        x = x.view(b, c, size).requires_grad_()
    elif dim == 2:
        x = x.view(b, c, size, size).requires_grad_()
    elif dim == 3:
        x = x.view(b, c, size, size, size).requires_grad_()
    torch.cuda.synchronize()
    start = time.time()
    for i in range(nrep):
        out = fn(x, scale_factor=factor, mode=sample_mode)
    torch.cuda.synchronize()
    end = time.time()
    inp_size = x.size()
    out_size = out.size()
    del x, out
    return ((end - start) / nrep, inp_size, out_size)


def bench_back(size, fn, factor, half=False, b=2, c=32, dim=2):
    # Time the backward pass: run one forward, then call backward nrep times.
    x = torch.ones([b * c * (size ** dim)], device='cuda', dtype=torch.float)
    if half:
        x = x.half()
    if dim == 1:
        x = x.view(b, c, size).requires_grad_()
    elif dim == 2:
        x = x.view(b, c, size, size).requires_grad_()
    elif dim == 3:
        x = x.view(b, c, size, size, size).requires_grad_()
    torch.cuda.synchronize()
    out = fn(x, scale_factor=factor, mode=sample_mode)
    grad = torch.randn_like(out)
    start = time.time()
    for i in range(nrep):
        # retain_graph keeps the autograd graph alive across repeated backward calls
        out.backward(grad, retain_graph=True)
    torch.cuda.synchronize()
    end = time.time()
    inp_size = x.size()
    out_size = out.size()
    del x, out, grad
    return ((end - start) / nrep, inp_size, out_size)


spatial_size = [2 ** i for i in range(5, 12)]
batch = [8]
channel = [32]
dim = [1, 2, 3]
scale_factor = [2]
bool_flag = [False, True]   # fp32, then fp16
cap = 2 ** 31               # skip configs whose output tensor would reach 2 GiB

for d, b, c, s, f, half_flag in product(dim, batch, channel, spatial_size, scale_factor, bool_flag):
    if ((s * f) ** d) * b * c * (2 if half_flag else 4) < cap:
        (fw_time, inp_size, out_size) = bench(s, torch.nn.functional.interpolate, f, half_flag, b, c, d)
        (bw_time, inp_size, out_size) = bench_back(s, torch.nn.functional.interpolate, f, half_flag, b, c, d)
        # columns: input size, scale factor, fp16 flag, forward and backward throughput (iters/sec)
        print(inp_size, f, half_flag, 1. / fw_time, 1. / bw_time)


@jjsjann123 (Collaborator, Author)

Removed the specialized 2d kernel, as the speedup was only sporadic. Caching already seems to handle the memory access pattern well.

I don't think I can justify having a dedicated kernel there 😢

@ezyang ezyang requested a review from ngimel June 14, 2019 17:43
@ezyang (Contributor) commented Jun 14, 2019

@ngimel happy to merge this if you give it the OK

@zhangguanheng66 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jun 14, 2019
@ezyang ezyang self-requested a review June 14, 2019 18:40
@ngimel (Collaborator) left a comment

Make sure there are checks against empty tensors, and make sure you are not excessively zeroing outputs. Those checks might already be somewhere in the code and I might be missing them, in which case it is good to go.
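A hedged, standalone sketch of the two points above, not the PR's code: guard against empty tensors before launching, and skip a separate zero-fill when the kernel itself writes every output element. The wrapper name and raw-pointer interface are invented for the example; the real op works on at::Tensor with the usual ATen checks.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void write_all_kernel(float* out, int n, float value) {
  // Writes every output element exactly once, so no prior zeroing is needed.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    out[i] = value;
  }
}

void launch_write_all(float* out, int n, float value) {
  if (n == 0) {
    // Empty tensor: nothing to compute, and a zero-sized grid launch would
    // fail with cudaErrorInvalidConfiguration, so return early.
    return;
  }
  // No cudaMemset / output.zero_() pass here: the kernel covers every element.
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  write_all_kernel<<<blocks, threads>>>(out, n, value);
}

int main() {
  launch_write_all(nullptr, 0, 0.f);  // empty case: no launch, no error
  float* d = nullptr;
  cudaMalloc(&d, 8 * sizeof(float));
  launch_write_all(d, 8, 3.f);
  cudaDeviceSynchronize();
  float h[8];
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  printf("%.0f %.0f\n", h[0], h[7]);  // expect: 3 3
  cudaFree(d);
  return 0;
}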

Inline review threads (outdated, resolved) on:
aten/src/ATen/native/cuda/UpSample.cuh
aten/src/ATen/native/cuda/UpSampleNearest1d.cu (two threads)
aten/src/ATen/native/cuda/UpSampleNearest2d.cu (two threads)
@jjsjann123 (Collaborator, Author) commented Jun 17, 2019

Addressed review comments. Should be good to go once tests pass.

@ngimel (Collaborator) left a comment

Review comments are addressed, great job, Jie!

@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 18, 2019
Summary:
updating upsampling kernel:
1. avoids atomicAdd for better fp16 performance.
2. better launch configuration for 2D input.
Pull Request resolved: pytorch/pytorch#21694

Differential Revision: D15875791

Pulled By: ezyang

fbshipit-source-id: 426fc5d5f0c0cdf58bfa1a2b564f17a9ea286fa4
@facebook-github-bot (Contributor)

@ezyang merged this pull request in c471a63.

Labels
Merged, module: cuda, open source, triaged

7 participants