In [1]:
import torch

# Show which PyTorch build produced the outputs recorded in this notebook
print(str(torch.__version__))

2.1.0+cu121


# `torch.roll`

Rolls the tensor `input` along the given dimension(s).

Elements shifted beyond the last position are re-introduced at the first position (a negative shift moves elements the other way, wrapping them from the front to the end). If `dims` is `None`, the tensor is flattened before rolling and then restored to its original shape.

## Understanding 1-D data

In [2]:
# 1-D int64 tensor holding 0, 1, ..., 11
input = torch.arange(0, 12)

In [3]:
# Values first, then the shape, each on its own line
print(input, input.shape, sep="\n")

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
torch.Size([12])


In [4]:
# Positive shift: every element moves one slot toward the end; the last one wraps to the front
input.roll(1)

tensor([11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

In [5]:
# Shift by two: the last two elements wrap around to the front
input.roll(2)

tensor([10, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9])

In [6]:
# Negative shift: elements move toward the start; the first two wrap around to the end
input.roll(-2)

tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1])

## Understanding 2-D data

In [7]:
# The same 0..11 values laid out as a 4x3 matrix (4 rows, 3 columns)
input = torch.arange(12).reshape(4, 3)

In [8]:
# Values first, then the shape, each on its own line
print(input, input.shape, sep="\n")

tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
torch.Size([4, 3])


In [9]:
# dims=0 moves whole rows: the last row wraps around to the top
input.roll(shifts=1, dims=0)

tensor([[ 9, 10, 11],
        [ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]])

In [10]:
# Two-row shift: the last two rows wrap around to the top
input.roll(shifts=2, dims=0)

tensor([[ 6,  7,  8],
        [ 9, 10, 11],
        [ 0,  1,  2],
        [ 3,  4,  5]])

In [11]:
# dims=1 rolls within each row: the last column wraps around to the first
input.roll(shifts=1, dims=1)

tensor([[ 2,  0,  1],
        [ 5,  3,  4],
        [ 8,  6,  7],
        [11,  9, 10]])

## Understanding 3-D data

In [12]:
# 0..23 arranged as two 4x3 slices
input = torch.arange(24).reshape(2, 4, 3)

In [13]:
# Values first, then the shape, each on its own line
print(input, input.shape, sep="\n")

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],

        [[12, 13, 14],
         [15, 16, 17],
         [18, 19, 20],
         [21, 22, 23]]])
torch.Size([2, 4, 3])


In [14]:
# dims=1 rolls the rows of each 4x3 slice independently; the last row of each slice wraps to its top
input.roll(shifts=1, dims=1)

tensor([[[ 9, 10, 11],
         [ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[21, 22, 23],
         [12, 13, 14],
         [15, 16, 17],
         [18, 19, 20]]])

# Behaving like pandas `shift`: fill the wrapped-around positions with NaN

In [32]:
# Roll along dim 1, then cast to float32: integer tensors cannot hold NaN
rolled_tensor = input.roll(shifts=1, dims=1).to(torch.float32)
print(rolled_tensor)

tensor([[[ 9., 10., 11.],
         [ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.]],

        [[21., 22., 23.],
         [12., 13., 14.],
         [15., 16., 17.],
         [18., 19., 20.]]])


In [33]:
# Positions along dim 1 that wrapped around after shifts=1 (only position 0)
index = torch.arange(0, 1, dtype=torch.long)

In [34]:
# Display the index tensor (an int64 tensor is printed without a dtype suffix)
print(str(index))

tensor([0])


In [35]:
# A tensor built from Python ints defaults to int64, which is what index_fill requires for its index
torch.tensor((0, 2)).dtype

torch.int64

In [36]:
# np.nan is a plain Python float, so float("nan") (or torch.nan) works equally well as a fill value.
import numpy as np

In [38]:
# Overwrite the wrapped-around positions (given by `index`) along dim 1 with NaN,
# mimicking pandas `shift`. float("nan") is used instead of np.nan so this cell
# does not depend on a mid-notebook NumPy import.
nan_rolled_tensor = torch.index_fill(rolled_tensor, dim=1, index=index, value=float("nan"))
print(nan_rolled_tensor)

tensor([[[nan, nan, nan],
         [ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.]],

        [[nan, nan, nan],
         [12., 13., 14.],
         [15., 16., 17.],
         [18., 19., 20.]]])
