[BC-Breaking] Avoid resampling kernel device and dtype moves #1514
Conversation
Looks good!
@@ -1360,7 +1362,8 @@ def _get_sinc_resample_kernel(
     # they will have a lot of almost zero values to the left or to the right...
     # There is probably a way to evaluate those filters more efficiently, but this is kept for
     # future work.
-    idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
+    idx_dtype = dtype if dtype is not None else torch.float64
+    idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)
What if someone passes a low precision type like uint8? I think it might be better to pick whatever dtype is most efficient for this operation.
Following offline discussion: we can keep the higher-precision float64 type, because the kernel computation is a one-time cost whose dimensions are limited to roughly (orig_freq // gcd) x (new_freq // gcd). Normal resampling frequencies generally have a large gcd, in which case the dtype makes only a minor difference in computation time.
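To make the size claim concrete, here is a quick back-of-the-envelope check (the 44.1 kHz -> 16 kHz frequencies are illustrative, not taken from the PR):

import math

# A common resample: 44.1 kHz -> 16 kHz. After reducing by the gcd, the
# cached kernel has only about (new_freq // gcd) x (orig_freq // gcd)
# entries, so a one-time float64 computation is cheap.
orig_freq, new_freq = 44100, 16000
g = math.gcd(orig_freq, new_freq)     # 100
print(new_freq // g, orig_freq // g)  # 160 441 -> roughly 70k kernel entries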
-    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
+    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
+    if dtype is None:
+        kernels = kernels.to(dtype=torch.float32)
It might be better to just return the kernel and do the dtype and device cast at the call site, since you're not using dtype outside of arange.
Following offline discussion: it is fine to convert to this generally "default" type prior to returning the kernel to the caller in transforms.
Maybe we can override
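A minimal sketch of the caching pattern settled on in this thread (the class name and simplified sinc kernel are illustrative, not torchaudio's exact code):

import math
import torch
import torch.nn.functional as F

class CachedKernelResample(torch.nn.Module):
    # Illustrative sketch only: the kernel math is simplified and does not
    # reproduce torchaudio's polyphase sinc resampler.
    def __init__(self, orig_freq: int, new_freq: int, width: int = 16):
        super().__init__()
        g = math.gcd(orig_freq, new_freq)
        orig_freq, new_freq = orig_freq // g, new_freq // g
        # Compute the kernel once in float64 for precision...
        idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
        kernel = torch.sinc(idx / orig_freq * new_freq).view(1, 1, -1)
        # ...then cache it as the generally "default" float32. register_buffer
        # lets module.to(device=..., dtype=...) move the cached kernel later.
        self.register_buffer("kernel", kernel.to(torch.float32))
        self.stride = orig_freq

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # No implicit move here: the cached kernel must already match the
        # waveform's device and dtype, or conv1d raises an error (the
        # BC-breaking behavior this PR introduces).
        return F.conv1d(waveform.unsqueeze(1), self.kernel, stride=self.stride).squeeze(1)

Registering the kernel as a buffer (rather than re-deriving it per call) is what makes a one-time, up-front module.to(...) move possible.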
Initially, the kernel used for resampling was computed only after a waveform was fed in, and it was initialized to the same device and dtype as the input waveform being resampled. In a later change, we cached the resampling kernel in transforms.Resample, since kernel computation is significant and redundant for a given set of resampling parameters.

Maintaining the original behavior would require moving the cached kernel to the device and dtype of the input waveform on every call to transforms.Resample, which could have unintended side effects: running a CPU transform on CUDA inputs would incur the overhead of constantly moving the cached CPU kernel to CUDA. The expectation is instead that users manually move the transform to the correct device and dtype themselves (e.g. resample = transforms.Resample(...); resample = resample.to(device=torch.device('cuda'), dtype=torch.float16)). This PR removes the per-call move of the cached kernel, so an error is now raised if a user calls the transform on a CUDA waveform without first moving the transform to CUDA.

This PR additionally produces slightly different results because of the precision of the kernel. In the previous/functional implementation, the kernel computation is done in the waveform's dtype from the start, but in the new transforms implementation, the kernel is computed in float64 and cached as float32 in __init__, prior to the user moving it to the desired dtype themselves. This gives higher-precision resampling with transforms for waveforms of lower precision than float32, and slightly lower-precision resampling for waveforms of higher precision than float32, since the kernel was intermediately cached as float32.

cc #1487
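For reference, the usage contract described above, as a sketch using torchaudio's transforms.Resample (the frequencies and waveform shape are illustrative):

import torch
import torchaudio.transforms as T

resample = T.Resample(orig_freq=44100, new_freq=16000)  # kernel cached in __init__

# New contract: move the transform (and its cached kernel) once, up front,
# instead of the transform implicitly moving the kernel on every call.
resample = resample.to(device=torch.device("cuda"), dtype=torch.float16)

waveform = torch.randn(1, 44100, device="cuda", dtype=torch.float16)
resampled = resample(waveform)  # kernel and waveform now match; no error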