
[Feature Request] Unstructured Interpolation for PyTorch Tensor #1552

Closed
lijunzh opened this issue May 14, 2017 · 47 comments
Labels
function request - A request for a new function or the addition of new arguments/modes to an existing function.
hackamonth
high priority
module: interpolation
module: numpy - Related to numpy support, and also numpy compatibility of our operators.
quansight-nack - High-prio issues that have been reviewed by Quansight and are judged to be not actionable.
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.

Comments

@lijunzh

lijunzh commented May 14, 2017

I recently discovered the PyTorch Tensor and am very excited about its GPU support. However, it does not seem to support the interpolation algorithms that MATLAB / NumPy do. My application requires a pre-processing step that linearly interpolates the input data. For now it is OK that this pre-processing is done on the CPU using scipy.interpolate, but it would be nicer if there were something inside PyTorch Tensor that supported doing that operation on the GPU, since the data will be loaded onto the GPU eventually.

Since I am pretty new to PyTorch, I may be wrong and interpolation may already be a supported operator (I did check the docs and found nothing about it). If not, do you think it is a good operation to run on the GPU, and something you would like to include in PyTorch? If so, I can open a PR and start working on it.

cc @ezyang @gchanan @zou3519 @mruberry @rgommers @heitorschueroff

@kiranvaidhya

I believe you can work with UpSampling for GPU-accelerated interpolation of your images.

@lijunzh
Copy link
Author

lijunzh commented May 16, 2017

@kvrd18 Thanks a lot for answering my question. I didn't make it clear in my original post. I actually only need to repeatedly shift the Tensor to some off-grid point. For example, given a 1-D signal sampled at integer points (0, 1, 2, etc.), we want the interpolated values at non-integer points (0.2, 1.4, 2.9, etc.). This is an interpolation problem in general.

UpSampling should solve the problem, though perhaps at unnecessary cost: one can upsample the tensor to a high enough resolution and pick the nearest sample points as needed (a sketch of this workaround is below). However, it would be nice if those interpolators (nearest, linear, bilinear, cubic spline, etc.) were available directly in the PyTorch API, just like what's offered in NumPy/SciPy and MATLAB. Maybe this kind of function does not make too much sense for neural networks, but it would help when using PyTorch as a general GPU-accelerated replacement for NumPy.
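
A minimal sketch of that workaround, assuming the later torch.nn.functional.interpolate API (nn.Upsample played this role at the time); the oversampling factor and the query points are made-up examples:

import torch
import torch.nn.functional as F

y = torch.tensor([0.0, 1.0, 4.0, 9.0])   # samples at integer points x = 0, 1, 2, 3
queries = torch.tensor([0.2, 1.4, 2.9])  # off-grid query points
factor = 10                              # arbitrary oversampling factor

# interpolate expects (N, C, L) input for 1-D 'linear' mode
fine = F.interpolate(y[None, None, :], size=(y.numel() - 1) * factor + 1,
                     mode='linear', align_corners=True)[0, 0]
# with align_corners=True, fine[j] sits at x = j / factor
approx = fine[(queries * factor).round().long()]  # tensor([0.2000, 2.2000, 8.5000])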

I do have some ideas about implementing this and would like to ask the PyTorch community whether it is a good idea to include these features. @apaszke and @soumith Forgive me for tagging you directly if you are not interested in this, but I want to hear your opinions. Do you think PyTorch should support more NumPy features and become more of a general-purpose GPU computing module, or is it mostly intended to be a neural network tool that competes with TensorFlow via its dynamic computational graph? Thanks a lot.

@lijunzh
Author

lijunzh commented May 24, 2017

It seems that there is not much interest in this feature in the PyTorch community. I will start a new project instead to enable it outside PyTorch.

@lijunzh lijunzh closed this as completed May 24, 2017
@soumith soumith reopened this May 24, 2017
@soumith
Member

soumith commented May 24, 2017

in the long term i'd def like to see interpolation algorithms in pytorch. I'll keep this issue open.

@soumith soumith added the enhancement and todo (not as important as medium or high priority tasks, but we will work on these) labels May 24, 2017
@lijunzh
Author

lijunzh commented May 24, 2017

@soumith Thanks for confirming that. I will work on it anyway. It would be nice to have it inside PyTorch so I don't need to install a separate module.

@soumith soumith added this to Low Priority in Issue Status Aug 23, 2017
@soumith soumith added this to nn / autograd / torch in Issue Categories Aug 31, 2017
@davidbau

davidbau commented Jan 9, 2018

+1. (Just spent a couple of days working around the lack of interpolation in pytorch.) There are several forms of interpolation that might be implemented, and they are all useful in different situations. Here is a list of what's in numpy/scipy. The first three have been useful to me; it would be super-nice to have GPU versions of them.

  • numpy.interp - 1-D interpolation; nice for representing custom functions (e.g., probability densities) built from data.
  • scipy.ndimage.zoom - similar to torch.nn.UpSampling, except that it supports non-integer zoom ratios. Handy for the common case of scaling an image.
  • scipy.interpolate.RectBivariateSpline / RegularGridInterpolator - allow generalized and irregular grid points, instead of assuming uniform edge-to-edge zooming. Useful when you are being careful but the data is still at grid points.
  • scipy.interpolate.griddata etc. - totally generic N-D interpolation by tessellating the space with simplices. I haven't had an occasion to need this.

@soumith
Member

soumith commented Jan 9, 2018

we have a 2D grid interpolator now, via: http://pytorch.org/docs/0.3.0/nn.html#torch.nn.functional.grid_sample

it works on 4D images (NxCxHxW), sampling over the H and W dimensions
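
For what it's worth, a tiny illustration of the call (shown with the later align_corners argument; the output values are just worked by hand):

import torch
import torch.nn.functional as F

img = torch.arange(16.0).view(1, 1, 4, 4)   # N x C x H x W
# grid holds (x, y) locations normalized to [-1, 1], shape N x H_out x W_out x 2
grid = torch.tensor([[[[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]]])
out = F.grid_sample(img, grid, align_corners=True)
# out is 1 x 1 x 1 x 3: top-left pixel 0.0, bilinear centre 7.5, bottom-right pixel 15.0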

@JunjieHu

JunjieHu commented Jan 22, 2018

@soumith Thanks for making the grid_sample function!

I am confused about the usage of grid in this function. Could you give me an example of how to use it? Is this function available in PyTorch 0.3 or 0.4?

I have a torch.autograd.Variable with size (batch_size, H_in, W_in). I want to resize this Variable and get a new Variable with size (batch_size, H_out, W_out). Do you know how to do that?

PS: I found torchvision.transforms has the Resize() function, but it only works on PIL images or numpy arrays. I need a similar resize function that can be applied directly to an autograd.Variable. Could you confirm that grid_sample can work for my situation? Thanks in advance!!

@peteflorence

Hi all -- for a project I'm working on, I made a simple PyTorch bilinear interpolation function, benchmarked it vs. a comparable numpy implementation, and also wrapped the nn.functional.grid_sample() function to support my same interface.

In case it's useful to anybody in the PyTorch community I've documented it here: https://gist.github.com/peteflorence/a1da2c759ca1ac2b74af9a83f69ce20e

@JunjieHu this includes an example of how to use nn.functional.grid_sample() (see the last section)
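
For @JunjieHu's resize question specifically, a hedged sketch: add a channel dimension so the data is N x C x H x W, resample, then drop it again (shown with the later F.interpolate; the sizes are illustrative):

import torch
import torch.nn.functional as F

x = torch.rand(8, 32, 32)   # batch_size x H_in x W_in
out = F.interpolate(x.unsqueeze(1), size=(64, 64),
                    mode='bilinear', align_corners=False).squeeze(1)
# out has shape (8, 64, 64) and stays on the autograd graph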

facebook-github-bot pushed a commit that referenced this issue Jul 6, 2018
Summary:
This PR addresses #5823.

* fix docstring: upsample doesn't support LongTensor

* Enable float scale up & down sampling for linear/bilinear/trilinear modes (following SsnL's commit).

* Enable float scale up & down sampling for nearest mode. Note that our implementation is slightly different from TF in that there's actually no "align_corners" concept in this mode.

* Add a new interpolate function API to replace upsample. Add a deprecation warning for upsample.

* Add an area mode, which is essentially adaptive average pooling, to resize_image.

* Add test cases for interpolate in test_nn.py.

* Add a few comments to help understand the *linear interpolation code.

* Only the "*cubic" mode is missing from the resize_images API, and it's pretty useful in practice; it's labeled as hackamonth here: #1552. I discussed with SsnL that we probably want to implement all new ops in ATen instead of THNN/THCUNN. Depending on the priority, I could either put it in my queue or leave it for a HAMer.

* After the change, the files named *Upsampling*.c work for both up/down sampling. I could rename the files if needed.

Differential Revision: D8729635

Pulled By: ailzhang

fbshipit-source-id: a98dc5e1f587fce17606b5764db695366a6bb56b
@ailzhang
Contributor

ailzhang commented Jul 9, 2018

We now have nn.functional.interpolate, which does these interpolations except for the bicubic mode: https://pytorch.org/docs/master/nn.html?highlight=interpolate#torch.nn.functional.interpolate
There is another issue open for bicubic: #918.
Closing this for now. Feel free to reopen if you have any questions.
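
For reference, a minimal call; the sizes and modes below are illustrative:

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64)   # N x C x H x W
up = F.interpolate(x, scale_factor=2.5, mode='bilinear', align_corners=False)   # 1 x 3 x 160 x 160
down = F.interpolate(x, size=(48, 48), mode='area')                             # 1 x 3 x 48 x 48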

@ailzhang ailzhang closed this as completed Jul 9, 2018
@aliutkus

aliutkus commented Jul 31, 2018

I feel like this is not what I would call a general interpolation method:
Right now, it looks to me like all of these interpolation functions apply deformations to regularly sampled input. The deformations are either constant (the current interpolate) or varying (the current grid_sample), but the input is always assumed to be regularly sampled.

In the 1-D case, there is still a need for a function that would do interp1d(x, y, xnew), where x are sorted coordinates, y are the corresponding values, and xnew are the points at which you want the linear interpolant evaluated.

This is similar to the scipy interp1d function (a CPU reference is sketched below). Going further, having some interpolation where you could select the dimension along which to interpolate would be great.
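
The scipy semantics being asked for, as a CPU reference point:

import numpy as np
from scipy.interpolate import interp1d

x = np.array([0.0, 1.0, 2.0, 3.0])   # sorted sample coordinates
y = x ** 2                           # corresponding values
f = interp1d(x, y)                   # build the linear interpolant once
f(np.array([0.5, 1.5]))              # -> array([0.5, 2.5])
# interp1d also takes an axis= argument to pick the interpolation dimension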

This doesn't seem to exist in PyTorch. I could code it myself, but then I hit another problem: there is no PyTorch searchsorted.

So I'm not really sure that these developments actually addressed the issue of the lack of general-purpose interpolation functions in PyTorch.

thanks a lot

@soumith soumith reopened this Jul 31, 2018
@soumith
Member

soumith commented Jul 31, 2018

@aliutkus you make a good point. I'm reopening this for us to increase the scope of the feature and get the implementations.

@aliutkus

aliutkus commented Aug 3, 2018

OK, so, I have spent some time over the last few days implementing a PyTorch extension for 1-D interpolation.

To this purpose:

  1. I implemented a pytorch-searchsorted function that looks for desired values in sorted arrays. This may be useful for tasks other than interpolation.
  2. I implemented a pytorch-interp1d function that depends on 1. and achieves parallelized 1-D interpolation on the GPU.

Depending on the desired dimensionality of the problem, the speedup gained by using the GPU can be quite cool:

Solving 100000 interpolation problems: each with 100 observations and 30 desired values
CPU: 8060.260ms, GPU: 70.735ms, error: 0.000000%.

I hope this helps someone!

Cheers

goodlux pushed a commit to goodlux/pytorch that referenced this issue Aug 15, 2018
@jingjin25

jingjin25 commented Oct 18, 2018

Hi all. Just like aliutkus mentioned, the current interpolation in PyTorch always assumes a regular sampling grid.

Besides the 1-D case, I think it is also necessary to extend unstructured interpolation to 2-D, which would be similar to scipy.interpolate.griddata, because sometimes we need forward warping, not just backward warping.

I want to ask whether this feature has been implemented in PyTorch, or whether it would be possible in the future?

@Cuiyirui

Well actually, you can also use np.interp() to process torch data, but the return value is a numpy array, so you then convert it back to a torch tensor.
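
i.e. something like the following round trip; note it runs on the CPU and breaks the autograd graph:

import numpy as np
import torch

xp = torch.linspace(0, 4, 5)        # sample coordinates
fp = xp ** 2                        # sample values
x = torch.tensor([0.2, 1.4, 2.9])   # query points
out = torch.from_numpy(np.interp(x.numpy(), xp.numpy(), fp.numpy()))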

@mileslucas

Hello, I am also interested in interpolation schemes.

In my current work there is a need for 1D interpolation and we currently use scipy's InterpolatedUnivariateSpline interpolator. It seems like the grid_sample method is not quite what we need.

@Haydnspass

Are there updates on this or any way I could help?
I would need to compute cubic spline coefficients on 'calibration' data and then evaluate the function on arbitrary points xq.
So it would be nice to construct the interpolation and then query:

interp_model = interpN(x, y, mode='cubic-spline')
yq = interp_model(xq)

where x is N x D-dimensional and y is N x 1-dimensional

@patrick-kidger

If it helps anyone who got here by Googling (like I did), I've ended up writing my own approach for performing natural cubic spline interpolation (a usage sketch follows the feature list below):

torchcubicspline

It includes support for:

  • Batching
  • Missing values
  • GPU support and backprop via PyTorch
  • The ability to directly evaluate the first derivative of the spline.
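
A usage sketch based on the project's README at the time (the names may have changed since):

import torch
from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline

t = torch.linspace(0, 1, 7)                 # knot locations
x = torch.rand(2, 7, 3)                     # batch of 2 series, 7 knots, 3 channels
coeffs = natural_cubic_spline_coeffs(t, x)  # fit once
spline = NaturalCubicSpline(coeffs)
out = spline.evaluate(torch.tensor(0.4))    # query anywhere in [0, 1]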

@rgommers rgommers added the quansight-nack label Apr 20, 2020
@mruberry mruberry added the module: numpy label Apr 28, 2020
@sbarratt

I implemented this using torch functions, so it's automatically GPU-compatible and fully differentiable. I could probably port it to C++ and make it a bit faster: https://github.com/sbarratt/torch_interpolations

For interpolating 4.5 million points on a 300 x 300 grid, it takes:

PyTorch took 222.531 +/- 7.972 ms
PyTorch CUDA took 12.502 +/- 0.493 ms
SciPy took 421.471 +/- 4.278 ms

(on a 1080 Ti and an i7-8700K.)

@mruberry mruberry added the function request and module: interpolation labels and removed the feature and topic: operator labels Oct 10, 2020

@mruberry
Collaborator

Thank you everyone who participated in this issue. Since it's an older issue and covers many requests, I've filed more modern issues tracking each request. It's helpful when issues map to a specific request.

First, the good news: PyTorch now has torch.searchsorted!

Issues to track the other requests made in this thread:

I believe that captures all the requests? Please let me know if I'm mistaken, and I encourage everyone to continue filing issues requesting NumPy or SciPy (or MATLAB or...) functions that would be helpful when writing PyTorch programs.

I also created a tracking issue for interpolation in PyTorch here #50341, so that we can continue to have one place to review and discuss these related issues and ideas.

Closing this issue in favor of the new specific issues and the tracking issue.

@ZichaoLong

ZichaoLong commented Mar 3, 2021

Here is an implementation of N-D Lagrange interpolation.
Just like scipy.interpolate.interpn, it supports irregularly spaced input grids (the values must still lie on a rectilinear grid).
It supports arbitrary interpolation order.

https://github.com/ZichaoLong/aTEAM/blob/master/nn/modules/Interpolation.py
https://github.com/ZichaoLong/aTEAM/blob/master/nn/functional/interpolation.py

@TobiasJacob

TobiasJacob commented Sep 24, 2021

Edit: @perrette implemented a correct functional form.

@perrette

@TobiasJacob thanks for sharing; however, it seems it does not behave well in edge cases (such as queries that fall exactly on the sample points):

interpolate(torch.arange(5), torch.arange(5)**2, 3) yields 7 instead of the expected 9, which np.interp correctly returns.

@perrette

The above function is definitely wrong.
Here is a two-liner that works, adapted from google/jax#3860 (comment):

def interpolate(x, xp, fp):
    # index of the right-hand neighbour of each query, clipped so every
    # query falls into a valid segment [xp[i - 1], xp[i]]
    i = torch.clip(torch.searchsorted(xp, x, right=True), 1, len(xp) - 1)
    # linear blend of the two neighbouring sample values
    return (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1])
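
Checking it on the earlier failing case:

import torch

xp = torch.arange(5.0)
fp = xp ** 2
interpolate(torch.tensor([3.0]), xp, fp)   # tensor([9.]), matching np.interp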

zasdfgbnm pushed a commit to zasdfgbnm/pytorch that referenced this issue Apr 4, 2022

* Fix ComputeAtRootDomainMap with broadcast in view root domains

Fixes pytorch#1549
@RSKothari

To whomsoever this helps, here is my implementation of scipy griddata with linear mode. Tessellation itself is non-differentiable. Once we find the enclosing triangle, this function computes the barycentric coordinates of each query point and returns the interpolated value.

import torch
from scipy.interpolate import griddata   # used in the test below
from scipy.spatial import Delaunay

class my_griddata():
    def __init__(self, height, width, mode='nearest') -> None:
        self.height = height
        self.width = width
        self.mode = mode  # only 'bilinear' is implemented in forward()

        self.Y_grid, self.X_grid = torch.meshgrid(torch.arange(height),
                                                  torch.arange(width),
                                                  indexing='ij')
        self.X_grid = self.X_grid.reshape(-1)
        self.Y_grid = self.Y_grid.reshape(-1)

    def forward(self, X, Y, value):
        # Tessellate the scattered sample points
        pos = torch.stack([X, Y], dim=-1).detach().cpu().numpy()
        tri = Delaunay(pos, furthest_site=False)

        # Find the corners of each simplex
        corners_X = X[tri.simplices]
        corners_Y = Y[tri.simplices]
        corners_F = value[tri.simplices]

        # Find the simplex ID for each query pixel in the original grid
        # (queries outside the convex hull get ID -1, which is not handled here)
        pos_orig = torch.stack([self.X_grid, self.Y_grid], dim=-1).detach().cpu().numpy()
        simplice_id = tri.find_simplex(pos_orig)

        # X, Y, F values of the 3 enclosing triangle vertices for each
        # pixel in the original grid
        corners_X_pq = corners_X[simplice_id]
        corners_Y_pq = corners_Y[simplice_id]
        corners_F_pq = corners_F[simplice_id]

        if self.mode == 'bilinear':
            x1, y1 = corners_X_pq[:, 0], corners_Y_pq[:, 0]
            x2, y2 = corners_X_pq[:, 1], corners_Y_pq[:, 1]
            x3, y3 = corners_X_pq[:, 2], corners_Y_pq[:, 2]

            # Barycentric coordinates of each query point within its triangle
            lambda1 = ((y2 - y3) * (self.X_grid - x3) + (x3 - x2) * (self.Y_grid - y3)) / \
                      ((y2 - y3) * (x1 - x3) + (x3 - x2) * (y1 - y3))
            lambda2 = ((y3 - y1) * (self.X_grid - x3) + (x1 - x3) * (self.Y_grid - y3)) / \
                      ((y2 - y3) * (x1 - x3) + (x3 - x2) * (y1 - y3))
            lambda3 = 1 - lambda1 - lambda2

            out = lambda1 * corners_F_pq[:, 0] + lambda2 * corners_F_pq[:, 1] + lambda3 * corners_F_pq[:, 2]
        else:
            raise NotImplementedError("Other modes not implemented")

        out = out.reshape(self.height, self.width)
        return out

To test the above implementation:

    H = 256
    W = 256
    
    num_x_pts = 10
    num_y_pts = 10
    
    interp_obj = my_griddata(H, W, mode='bilinear')
       
    xx = torch.linspace(0, W, num_x_pts)
    yy = torch.linspace(0, H, num_y_pts)
    
    Y_grid, X_grid = torch.meshgrid(yy, xx, indexing='ij')
    
    X = X_grid.reshape(-1)
    Y = Y_grid.reshape(-1)
    Z = torch.rand(num_y_pts*num_x_pts, )
    
    out = interp_obj.forward(X, Y, Z)
    
    import matplotlib.pyplot as plt 
    
    fig, axs = plt.subplots(ncols=2, nrows=1)
    
    axs[0].pcolormesh(interp_obj.X_grid.reshape(H, W),
                      interp_obj.Y_grid.reshape(H, W),
                      out,
                      vmin=0,
                      vmax=1)
    
    axs[0].scatter(X, Y, c=Z)
    
    axs[1].scatter(X, Y, c=Z)
    
    points = torch.stack([X, Y], dim=1).numpy()
    points_q = torch.stack([interp_obj.X_grid, interp_obj.Y_grid], dim=1).numpy()
    
    out = griddata(points, Z, points_q,  method='linear')
    out = out.reshape(H, W)
    
    fig, axs = plt.subplots(ncols=2, nrows=1)
    
    axs[0].pcolormesh(interp_obj.X_grid.reshape(H, W),
                      interp_obj.Y_grid.reshape(H, W),
                      out,
                      vmin=0,
                      vmax=1)
    
    axs[0].scatter(X, Y, c=Z)
    
    axs[1].scatter(X, Y, c=Z)
    
    plt.show(block=True)


@LarsDu

LarsDu commented May 25, 2023

To build off of what @perrette provided:

np.interp also clamps out-of-range queries and has the left and right params. I found the following modifications give you effectively identical results to np.interp (minus handling of complex numbers and periodicity):

from __future__ import annotations

import torch
from torch import Tensor


def torch_1d_interp(
    x: Tensor,
    xp: Tensor,
    fp: Tensor,
    left: float | None = None,
    right: float | None = None,
) -> Tensor:
    """One-dimensional linear interpolation for monotonically increasing sample points.

    Returns the one-dimensional piecewise linear interpolant to a function with given discrete data points (xp, fp), evaluated at x.

    Args:
        x: The x-coordinates at which to evaluate the interpolated values.
        xp: 1d sequence of floats. x-coordinates. Must be increasing
        fp: 1d sequence of floats. y-coordinates. Must be same length as xp
        left: Value to return for x < xp[0], default is fp[0]
        right: Value to return for x > xp[-1], default is fp[-1]

    Returns:
        The interpolated values, same shape as x.
    """
    if left is None:
        left = fp[0]

    if right is None:
        right = fp[-1]

    i = torch.clip(torch.searchsorted(xp, x, right=True), 1, len(xp) - 1)

    answer = torch.where(
        x < xp[0],
        left,
        (fp[i - 1] * (xp[i] - x) + fp[i] * (x - xp[i - 1])) / (xp[i] - xp[i - 1]),
    )
    answer = torch.where(x > xp[-1], right, answer)
    return answer
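
A quick check against np.interp, including the out-of-range behaviour:

import numpy as np
import torch

xp = torch.arange(5.0)
fp = xp ** 2
x = torch.tensor([-1.0, 3.0, 10.0])
torch_1d_interp(x, xp, fp)                     # tensor([ 0.,  9., 16.])
np.interp(x.numpy(), xp.numpy(), fp.numpy())   # array([ 0.,  9., 16.])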
