[Feature Request] Unstructured Interpolation for PyTorch Tensor #1552
I recently discovered the PyTorch Tensor and am very excited about its GPU support. However, it does not seem to support any interpolation algorithms the way MATLAB/NumPy do. My application requires a pre-processing step that linearly interpolates the input data. For now this pre-processing is done on the CPU using scipy.interpolate, but it would be nicer to have something inside the PyTorch Tensor that supports doing that operation on the GPU, since the data will be loaded onto the GPU eventually.

Since I am pretty new to PyTorch, I may be wrong and interpolation may already be a supported operator (I did check the docs and found nothing about it). If not, do you think it is a good operation to do on the GPU? Is it something you would like to include in PyTorch? If so, I can open a PR and start working on it.

cc @ezyang @gchanan @zou3519 @mruberry @rgommers @heitorschueroff

Comments
I believe you can work with UpSampling for GPU-accelerated interpolation of your images.
@kvrd18 Thanks a lot for answering my question. I didn't make it clear in my original post: I actually only need to repeatedly shift the tensor to some off-grid point. For example, given a 1-D signal sampled at integer points (0, 1, 2, etc.), we want the interpolated values at non-integer points (0.2, 1.4, 2.9, etc.). This is an interpolation problem in general. Upsampling could solve it, though at possibly unnecessary cost: one can upsample the tensor finely enough and pick the sample points as needed. However, it would be nice if those interpolators (nearest, linear, bilinear, cubic spline, etc.) were available directly in the PyTorch API, just like what's offered in NumPy/SciPy and MATLAB. Maybe these kinds of functions don't make too much sense for neural networks, but they help when you use PyTorch as a general NumPy replacement with GPU support. I have some ideas about implementing this and would like to ask the PyTorch community whether it would be a good idea to include these features. @apaszke and @soumith, forgive me for tagging you directly if you are not interested in this, but I want to hear your opinions. Do you think PyTorch should support more NumPy features and become more of a general-purpose GPU computing module, or is it mostly intended to be a neural-network tool that competes with TensorFlow through its dynamic computational graph? Thanks a lot.
It seems that there is not much interest in this feature in the PyTorch community. I will start a new project instead to enable it outside PyTorch.
In the long term I'd definitely like to see interpolation algorithms in PyTorch. I'll keep this issue open.
@soumith Thanks for confirming that. I will work on it anyway. It would be nice to have it inside PyTorch so I don't need to install a different module.
+1. (Just spent a couple of days working around the lack of interpolation in PyTorch.) There are several forms of interpolation that might be implemented, and they are all useful in different situations. Here is a list of what's in NumPy/SciPy; the first three have been useful to me, and it would be super nice to have GPU versions of them. numpy.interp - 1-D interpolation - nice for representing custom functions (e.g., probability densities) built from data.
We have a 2D grid interpolator now, via http://pytorch.org/docs/0.3.0/nn.html#torch.nn.functional.grid_sample. It works on 4D images (NxCxHxW) and operates over the HxW dimensions.
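For anyone confused by the grid argument (as in the question below), here is a minimal usage sketch against a current PyTorch version; the shapes follow the documented API, and the concrete numbers are purely illustrative:

```python
import torch
import torch.nn.functional as F

# Input: one 1-channel 4x4 image with values 0..15 (N x C x H x W).
img = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# grid has shape N x H_out x W_out x 2 and holds (x, y) sampling locations
# normalized to [-1, 1]: (-1, -1) is the top-left corner, (1, 1) the bottom-right.
grid = torch.tensor([[[[-1.0, -1.0],     # top-left pixel
                       [ 0.0,  0.0],     # image center
                       [ 1.0,  1.0]]]])  # bottom-right pixel

out = F.grid_sample(img, grid, mode='bilinear', align_corners=True)
print(out)  # tensor([[[[ 0.0000,  7.5000, 15.0000]]]]), shape N x C x 1 x 3
```

Each grid entry is an (x, y) location in the normalized [-1, 1] coordinate system, so the same mechanism covers shifting, warping, and resizing.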
@soumith Thanks for making the grid_sample function! I am confused about the usage of grid in this function. Could you give me an example of how to use it? Is it available in PyTorch 0.3 or 0.4? I have a torch.autograd.Variable of size (batch_size, H_in, W_in). I want to resize this Variable to get a new Variable of size (batch_size, H_out, W_out). Do you know how to do that? PS: I found that torchvision.transforms has a Resize() function, but it only works for PIL images or numpy arrays. I need a similar resize function that can be applied directly to the autograd.Variable. Could you confirm that grid_sample works for my situation? Thanks in advance!
Hi all -- for a project I'm working on, I made a simple PyTorch bilinear interpolation function, benchmarked it vs. a comparable numpy implementation, and also wrapped grid_sample. In case it's useful to anybody in the PyTorch community, I've documented it here: https://gist.github.com/peteflorence/a1da2c759ca1ac2b74af9a83f69ce20e @JunjieHu this includes an example of how to use grid_sample.
Summary: This PR addresses #5823.

* Fix docstring: upsample doesn't support LongTensor.
* Enable float-scale up- and down-sampling for linear/bilinear/trilinear modes (following SsnL's commit).
* Enable float-scale up- and down-sampling for nearest mode. Note that our implementation is slightly different from TF in that there is actually no "align_corners" concept in this mode.
* Add a new interpolate function API to replace upsample, and add a deprecation warning for upsample.
* Add an "area" mode, which is essentially adaptive average pooling, to resize_image.
* Add test cases for interpolate in test_nn.py.
* Add a few comments to help understand the *linear interpolation code.
* Only the "*cubic" mode, which is pretty useful in practice, is missing from the resize_images API; it is labeled as hackamonth here: #1552. I discussed with SsnL that we probably want to implement all new ops in ATen instead of THNN/THCUNN. Depending on the priority, I could either put it in my queue or leave it for a HAMer.
* After the change, the files named *Upsampling*.c work for both up- and down-sampling. I could rename the files if needed.

Differential Revision: D8729635
Pulled By: ailzhang
fbshipit-source-id: a98dc5e1f587fce17606b5764db695366a6bb56b
We now have nn.functional.interpolate, which does all of these interpolations except for the bicubic mode. https://pytorch.org/docs/master/nn.html?highlight=interpolate#torch.nn.functional.interpolate
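For concreteness, a small sketch of the new API (shapes are illustrative; align_corners only applies to the linear-family modes):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)  # a batch of images, N x C x H x W

# Resize to an explicit output size with bilinear interpolation.
y = F.interpolate(x, size=(12, 12), mode='bilinear', align_corners=False)
print(y.shape)  # torch.Size([2, 3, 12, 12])

# Non-integer scale factors work too, for both up- and down-sampling.
z = F.interpolate(x, scale_factor=0.5, mode='nearest')
print(z.shape)  # torch.Size([2, 3, 4, 4])
```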
I feel like this is not what I would call a general interpolation method. In the 1-D case, for me there is still a need for a function that evaluates a signal at arbitrary off-grid query points, given its values on a sampling grid. This is similar to scipy's interp1d, and it doesn't seem to exist in PyTorch. I could code it myself, but then I run into another problem: there is no pytorch searchsorted. So I'm not really sure that these developments actually addressed the issue of the lack of general-purpose interpolation functions in PyTorch. Thanks a lot.
@aliutkus you make a good point. I'm reopening this for us to increase the scope of the feature and get the implementations.
OK, so I have spent some time these last days implementing a PyTorch extension for doing 1-D interpolation. Depending on the desired dimensionality of the problem, the speedup gained by using the GPU can be quite cool. I hope this helps someone! Cheers
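For readers who want something similar without installing an extension: now that torch.searchsorted exists (see the maintainer comment further down), a minimal np.interp-style linear interpolator can be sketched as follows. interp1d_linear is a hypothetical helper name, and x is assumed to be 1-D and sorted ascending:

```python
import torch

def interp1d_linear(x, y, xq):
    """Piecewise-linear interpolation of samples (x, y) at query points xq.

    Assumes x is 1-D and sorted ascending; queries outside [x[0], x[-1]]
    are clamped to the end values (like np.interp's default behavior).
    """
    # Index of each query's right neighbor in the sample grid.
    idx = torch.searchsorted(x, xq).clamp(1, len(x) - 1)
    x0, x1 = x[idx - 1], x[idx]
    y0, y1 = y[idx - 1], y[idx]
    t = ((xq - x0) / (x1 - x0)).clamp(0.0, 1.0)  # clamping handles out-of-range queries
    return y0 + t * (y1 - y0)

x = torch.tensor([0.0, 1.0, 2.0, 3.0])
y = x ** 2
print(interp1d_linear(x, y, torch.tensor([0.2, 1.4, 2.9])))
# tensor([0.2000, 2.2000, 8.5000])
```

Everything here is vectorized torch ops, so it runs on the GPU by moving the tensors there, and it is differentiable with respect to y.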
Hi all. Just like aliutkus mentioned, the current interpolation in PyTorch always assumes a regular sampling grid. Besides the 1-D case, I think it is also necessary to extend unstructured interpolation to 2-D, which would be similar to scipy.interpolate.griddata, because sometimes we need forward warping, not just backward warping. I want to ask whether this feature has been implemented in PyTorch, or whether it would be possible in the future?
Well actually, you can also use np.interp() to process torch data, but the return value is a numpy array, so you then have to convert it back to a torch tensor.
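Concretely, that work-around looks like this (illustrative values); note it is CPU-only, breaks autograd, and comes back as float64:

```python
import numpy as np
import torch

x = torch.linspace(0.0, 3.0, 4)       # sample locations [0, 1, 2, 3]
y = x ** 2                            # sample values
xq = torch.tensor([0.2, 1.4, 2.9])    # off-grid query points

# Round-trip through NumPy, then back to a torch tensor.
yq = torch.from_numpy(np.interp(xq.numpy(), x.numpy(), y.numpy()))
print(yq)  # tensor([0.2000, 2.2000, 8.5000], dtype=torch.float64)
```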
Hello, I am also interested in interpolation schemes. In my current work there is a need for 1-D interpolation, and we currently use scipy's interp1d. Are there updates on this, or any way I could help? I am looking for interpolation where x is N x D-dimensional and y is N x 1-dimensional.
If it helps anyone who got here by Googling (like I did): I've ended up writing my own approach for performing natural cubic spline interpolation, with support for several additional features.
I implemented this using torch functions, so it's automatically GPU-compatible and fully differentiable. I could probably port it to C++ and make it a bit faster: https://github.com/sbarratt/torch_interpolations. Timings for interpolating 4.5 million points on a 300 x 300 grid (on a 1080 Ti and i7-8700K) are listed in the repository.
The linked code may be helpful for N-D irregular-grid interpolation.
Thank you everyone who participated in this issue. Since it's an older issue and covers many requests, I've filed more modern issues tracking each request; it's helpful when issues map to a specific request. First, the good news: PyTorch now has torch.searchsorted! I believe the new issues capture all the requests made in this thread? Please let me know if I'm mistaken, and I encourage everyone to continue filing issues requesting NumPy or SciPy (or MATLAB or...) functions that would be helpful when writing PyTorch programs. I also created a tracking issue for interpolation in PyTorch, #50341, so that we can continue to have one place to review and discuss these related issues and ideas. Closing this issue in favor of the new specific issues and the tracking issue.
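For readers who haven't used it yet, a quick sketch of what torch.searchsorted computes (values made up):

```python
import torch

xs = torch.tensor([0.0, 1.0, 2.0, 3.0])   # sorted sample locations
xq = torch.tensor([0.2, 1.4, 2.9])        # off-grid query points

# For each query, the insertion index that keeps xs sorted -- i.e. the index
# of each query's right neighbor, the first step of 1-D interpolation.
print(torch.searchsorted(xs, xq))  # tensor([1, 2, 3])
```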
Here is an implementation of N-D Lagrange interpolation: https://github.com/ZichaoLong/aTEAM/blob/master/nn/modules/Interpolation.py
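For reference, a tiny self-contained 1-D sketch of the same idea (my own illustration, not the aTEAM code; lagrange_interp is a hypothetical name):

```python
import torch

def lagrange_interp(x, y, xq):
    # Evaluate the degree-(n-1) Lagrange polynomial through (x_i, y_i) at xq.
    # O(n^2) work per batch of queries; only sensible for small n.
    n = len(x)
    diffs = xq.unsqueeze(-1) - x                 # (Q, n): xq_j - x_k
    basis = []
    for i in range(n):
        mask = torch.arange(n) != i
        # l_i(xq) = prod_{k != i} (xq - x_k) / (x_i - x_k)
        basis.append(torch.prod(diffs[:, mask], dim=-1) / torch.prod(x[i] - x[mask]))
    return torch.stack(basis, dim=-1) @ y        # sum_i l_i(xq) * y_i

x = torch.tensor([0.0, 1.0, 2.0])
y = x ** 2                                       # interpolant recovers t^2 exactly
print(lagrange_interp(x, y, torch.tensor([0.5, 1.5])))  # tensor([0.2500, 2.2500])
```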
Edit: @perrette implemented a correct functional form.
@TobiasJacob thanks for sharing; however, it seems it does not work well in edge cases (like exactly on the sample points).
The above function is definitely wrong.
…ch#1552) * Fix ComputeAtRootDomainMap with broadcast in view root domains. Fixes pytorch#1549
To whomsoever this helps: here is my implementation of scipy's griddata with linear mode. Tessellation itself is non-differentiable, but once the vertices of the enclosing triangle are found, this function computes the barycentric coordinates of the query point and returns the interpolated value.
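A minimal sketch of the barycentric step described above, assuming the enclosing triangle's vertices are already known (barycentric_interp is an illustrative helper name, not the commenter's original code):

```python
import torch

def barycentric_interp(tri_xy, tri_vals, q):
    # tri_xy:   (3, 2) vertex coordinates of the triangle enclosing q
    # tri_vals: (3,)   values at those vertices
    # q:        (2,)   query point assumed to lie inside the triangle
    a, b, c = tri_xy
    v0, v1, v2 = b - a, c - a, q - a
    d00, d01, d11 = v0 @ v0, v0 @ v1, v1 @ v1
    d20, d21 = v2 @ v0, v2 @ v1
    denom = d00 * d11 - d01 * d01
    w1 = (d11 * d20 - d01 * d21) / denom   # weight of vertex b
    w2 = (d00 * d21 - d01 * d20) / denom   # weight of vertex c
    w0 = 1.0 - w1 - w2                     # weight of vertex a
    return w0 * tri_vals[0] + w1 * tri_vals[1] + w2 * tri_vals[2]

# Quick test: unit right triangle with values 0, 1, 2 at its vertices.
tri = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
vals = torch.tensor([0.0, 1.0, 2.0])
print(barycentric_interp(tri, vals, torch.tensor([0.25, 0.25])))  # tensor(0.7500)
```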
This might be useful for some cases.
To build off of what was shared above.