[BC-Breaking] Avoid resampling kernel device and dtype moves #1514

Merged
merged 1 commit into pytorch:master on May 19, 2021

Conversation

@carolineechen (Contributor) commented May 18, 2021

Initially, the resampling kernel was computed only once a waveform was passed in, and it was created with the same device and dtype as that input waveform. In a later change, we cached the resampling kernel in transforms.Resample, since computing the kernel is expensive and redundant for a fixed set of resampling parameters.

Maintaining the original behavior would require moving the cached kernel to the device and dtype of the input waveform every time transforms.Resample is called, which could have unintended side effects: running a CPU transform on CUDA inputs would incur the overhead of repeatedly moving the cached CPU kernel to CUDA. The expectation is instead that the user manually moves the transform to the correct device and dtype (e.g. resample = transforms.Resample(); resample = resample.to(device=torch.device('cuda'), dtype=torch.float16)). This PR removes the per-call move of the cached kernel, so an error is now raised if a user calls the transform on a CUDA waveform without first moving the transform to CUDA.
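
A minimal usage sketch of the expected workflow after this change (the frequencies and the availability of a CUDA device are assumptions for illustration):

```python
import torch
import torchaudio

# Build the transform; the kernel is computed once and cached here.
resample = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000)

# Move the cached kernel along with the rest of the module; after this PR the
# kernel is NOT moved automatically per call, so this step is the user's job.
resample = resample.to(device=torch.device("cuda"), dtype=torch.float16)

waveform = torch.randn(1, 48000, device="cuda", dtype=torch.float16)
resampled = resample(waveform)  # kernel and input now share device and dtype
```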

This PR also produces slightly different results because of the precision of the kernel. In the previous/functional implementation, the kernel is computed in the waveform's dtype from the start; in the new transforms implementation, the kernel is computed in float64 and cached as float32 in __init__, before the user moves it to the desired dtype. As a result, transforms-based resampling has higher precision for waveforms with dtypes of lower precision than float32, and slightly lower precision for waveforms with dtypes of higher precision than float32, since the kernel was intermediately cached as float32.
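
A small sketch of the dtype chain described above, using a stand-in tensor rather than the real sinc kernel, to show why float64 inputs lose a little precision through the float32 cache:

```python
import torch

# Stand-in "kernel" values; the real kernel comes from _get_sinc_resample_kernel.
built_f64 = torch.linspace(-3.0, 3.0, steps=101, dtype=torch.float64).sin()

# transforms path after this PR: compute in float64, cache as float32 in
# __init__, then cast to whatever dtype the user moves the transform to.
cached_f32 = built_f64.to(torch.float32)
back_to_f64 = cached_f32.to(torch.float64)

# previous/functional path: computed directly in the waveform's dtype.
direct_f64 = built_f64

# The round trip through float32 is the source of the small difference.
print((back_to_f64 - direct_f64).abs().max())  # tiny but nonzero
```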

cc #1487

@mthrok (Collaborator) left a comment

Looks good!

@@ -1360,7 +1362,8 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
-    idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
+    idx_dtype = dtype if dtype is not None else torch.float64
+    idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)
Contributor

What if someone passes a low precision type like uint8? I think it might be better to pick whatever dtype is most efficient for this operation.

Contributor Author

Following offline discussion, we can keep the higher-precision float64 type because the kernel computation is a one-time cost, and the kernel's dimensions are limited to roughly (orig_freq // gcd) x (new_freq // gcd). Typical resampling frequencies generally have a large gcd, in which case the dtype difference has only a minor computational impact.
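
For a rough sense of the sizes involved (the example frequencies below are illustrative, not from the PR):

```python
import math

# Common resampling pair: 48 kHz -> 16 kHz has gcd 16000, so the reduced
# frequencies are tiny and the one-time float64 kernel computation is cheap.
orig_freq, new_freq = 48000, 16000
g = math.gcd(orig_freq, new_freq)
print(orig_freq // g, new_freq // g)  # 3 1

# Even a less friendly pair such as 44.1 kHz -> 48 kHz stays modest.
g = math.gcd(44100, 48000)
print(44100 // g, 48000 // g)  # 147 160
```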

-    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
+    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
+    if dtype is None:
+        kernels = kernels.to(dtype=torch.float32)
Contributor

It might be better to just return the kernel and do the dtype and device cast at the call site, since you're not using dtype outside of arange.

Contributor Author

Following offline discussion, it is fine to convert to this generally "default" type before returning the kernel to the caller in transforms.

@carolineechen carolineechen merged commit 079b3f5 into pytorch:master May 19, 2021
@mthrok (Collaborator) commented May 22, 2021

Maybe we can override the to method, so that when the target dtype is 64-bit, we can regenerate the kernel.
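
A rough sketch of that idea (not torchaudio's implementation; _make_kernel and ResampleSketch are hypothetical stand-ins for _get_sinc_resample_kernel and transforms.Resample):

```python
import torch

class ResampleSketch(torch.nn.Module):
    def __init__(self, orig_freq: int = 48000, new_freq: int = 16000):
        super().__init__()
        self.orig_freq, self.new_freq = orig_freq, new_freq
        # Cache a float32 kernel by default, as in the current PR.
        self.register_buffer("kernel", self._make_kernel(torch.float32))

    def _make_kernel(self, dtype: torch.dtype) -> torch.Tensor:
        # Placeholder kernel: computed in float64, then cast to the target dtype.
        idx = torch.linspace(-4.0, 4.0, steps=65, dtype=torch.float64)
        return torch.sinc(idx).to(dtype)

    def to(self, *args, **kwargs):
        module = super().to(*args, **kwargs)
        # If the module was moved to float64, rebuild the kernel from scratch
        # instead of upcasting the float32 cache, which cannot recover precision.
        if module.kernel.dtype == torch.float64:
            module.kernel = module._make_kernel(torch.float64).to(module.kernel.device)
        return module
```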

@mthrok mthrok mentioned this pull request May 25, 2021