In [3]:
import torch
from torch.nn import functional as F

## indexing

In [2]:
# Draw a 5x3 tensor of uniform samples in [0, 1); no seed is set in the
# visible cells, so the values differ from run to run.
x = torch.rand((5, 3))
x

tensor([[0.2786, 0.3781, 0.0827],
        [0.7273, 0.4358, 0.6287],
        [0.0186, 0.4699, 0.6678],
        [0.8362, 0.8000, 0.5068],
        [0.0313, 0.6464, 0.5300]])

In [3]:
# Full shape of x, reported as a torch.Size.
x.size()

torch.Size([5, 3])

In [4]:
# Length of axis 1 (the second axis) as a plain Python int.
x.size(1)

3

In [5]:
# First row of x: index 0 on the leading axis, everything on the trailing one.
x[0, :]

tensor([0.2786, 0.3781, 0.0827])

In [6]:
# Length-10 vector of zeros that the index_add_ cell accumulates into.
p = torch.zeros((10,))
p

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [7]:
# Integer indices with repeats: 3 appears three times, 8 twice, 5 once.
y = torch.tensor((3, 8, 3, 3, 8, 5))
y

tensor([3, 8, 3, 3, 8, 5])

In [8]:
# Count how often each index in y occurs by adding a 1 at position y[i] for
# every i (index_add_ accumulates repeated indices instead of overwriting).
# Accumulate into fresh zeros shaped like p rather than into the existing p in
# place: the in-place form added the counts again on every re-run of this
# cell, so its output depended on how many times it had been executed.
p = torch.zeros_like(p).index_add_(0, y, torch.ones(y.shape[0], dtype=p.dtype))
p

tensor([0., 0., 0., 3., 0., 1., 0., 0., 2., 0.])

## dimension manipulation

In [9]:
# Rebind x to a 2x4 integer tensor whose values lie in {0, 1, 2}.
x = torch.tensor(((0, 1, 1, 1), (2, 2, 0, 0)))
print(x.size())
x

torch.Size([2, 4])


tensor([[0, 1, 1, 1],
        [2, 2, 0, 0]])

In [17]:
# One-hot encode x along a new trailing axis. num_classes=-1 is the default and
# infers the class count as one more than the largest value found in x.
y = F.one_hot(x, num_classes=-1)
print(y.size())
y

torch.Size([2, 4, 3])


tensor([[[1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 1, 0]],

        [[0, 0, 1],
         [0, 0, 1],
         [1, 0, 0],
         [1, 0, 0]]])

In [18]:
# Build an axis permutation for y that moves the last axis to position 1 and
# keeps the remaining axes in their original relative order.
order = list(range(y.ndim))
order.insert(1, order.pop())
order

[0, 2, 1]

In [24]:
# Reorder the axes of y according to `order`.
t = y.permute(*order)
print(t.size())
t

torch.Size([2, 3, 4])


tensor([[[1, 0, 0, 0],
         [0, 1, 1, 1],
         [0, 0, 0, 0]],

        [[0, 0, 1, 1],
         [0, 0, 0, 0],
         [1, 1, 0, 0]]])

In [20]:
# Length-4 float vector; its length matches the last axis of t.
s = torch.tensor((0.1, 0.2, 0.3, 0.4))
s

tensor([0.1000, 0.2000, 0.3000, 0.4000])

In [21]:
# Convert t to float, then add s; s is broadcast across the leading axes of t.
torch.add(t.float(), s)

tensor([[[1.1000, 0.2000, 0.3000, 0.4000],
         [0.1000, 1.2000, 1.3000, 1.4000],
         [0.1000, 0.2000, 0.3000, 0.4000]],

        [[0.1000, 0.2000, 1.3000, 1.4000],
         [0.1000, 0.2000, 0.3000, 0.4000],
         [1.1000, 1.2000, 0.3000, 0.4000]]])

In [22]:
# Linear interpolation halfway (weight 0.5) from t, as float, towards s.
torch.lerp(t.float(), s, 0.5)

tensor([[[0.5500, 0.1000, 0.1500, 0.2000],
         [0.0500, 0.6000, 0.6500, 0.7000],
         [0.0500, 0.1000, 0.1500, 0.2000]],

        [[0.0500, 0.1000, 0.6500, 0.7000],
         [0.0500, 0.1000, 0.1500, 0.2000],
         [0.5500, 0.6000, 0.1500, 0.2000]]])

## stack vs cat

In [7]:
# Rebind x to a 5x3 tensor of zeros.
x = torch.zeros((5, 3))
x

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [8]:
# Rebind y to a 5x3 tensor of ones, the same shape as x.
y = torch.ones((5, 3))
y

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [21]:
# Concatenate along the existing axis 0: two (5, 3) inputs give one (10, 3) tensor.
c = torch.cat([x, y], dim=0)
c

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [19]:
# Stack along a new leading axis: two (5, 3) inputs give one (2, 5, 3) tensor.
s = torch.stack([x, y], dim=0)
s

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

In [20]:
# Average over the stacked axis, i.e. the element-wise mean of the two inputs.
s.mean(dim=0)

tensor([[0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000]])

## argsort

In [22]:
# 5x3 tensor of uniform random values whose rows are reordered further down.
a = torch.rand(5, 3)
a

tensor([[0.3620, 0.7000, 0.6432],
        [0.9925, 0.1571, 0.0627],
        [0.5447, 0.4857, 0.4290],
        [0.6885, 0.8240, 0.6484],
        [0.5910, 0.9039, 0.5217]])

In [30]:
# 5x3x2 tensor of uniform random values; its leading axis has the same length as a's.
b = torch.rand(5, 3, 2)
b

tensor([[[0.0432, 0.4259],
         [0.3029, 0.3995],
         [0.7006, 0.7181]],

        [[0.2982, 0.3586],
         [0.9122, 0.4464],
         [0.7311, 0.8628]],

        [[0.8964, 0.2868],
         [0.9219, 0.2947],
         [0.1656, 0.2994]],

        [[0.6551, 0.0606],
         [0.8325, 0.3129],
         [0.1951, 0.7462]],

        [[0.4631, 0.7060],
         [0.9481, 0.1694],
         [0.1584, 0.3764]]])

In [32]:
# Rebind x to five integer sort keys, one per row of a and b.
x = torch.tensor((10, 5, 0, 20, 30))
x

tensor([10,  5,  0, 20, 30])

In [33]:
# Indices that would arrange x in ascending order.
s = x.argsort()
s

tensor([2, 1, 0, 3, 4])

In [34]:
# Reorder the rows (axis 0) of a using the sort indices in s.
a.index_select(0, s)

tensor([[0.5447, 0.4857, 0.4290],
        [0.9925, 0.1571, 0.0627],
        [0.3620, 0.7000, 0.6432],
        [0.6885, 0.8240, 0.6484],
        [0.5910, 0.9039, 0.5217]])

In [36]:
# Same reordering on b: only axis 0 is permuted, trailing axes are untouched.
b.index_select(0, s)

tensor([[[0.8964, 0.2868],
         [0.9219, 0.2947],
         [0.1656, 0.2994]],

        [[0.2982, 0.3586],
         [0.9122, 0.4464],
         [0.7311, 0.8628]],

        [[0.0432, 0.4259],
         [0.3029, 0.3995],
         [0.7006, 0.7181]],

        [[0.6551, 0.0606],
         [0.8325, 0.3129],
         [0.1951, 0.7462]],

        [[0.4631, 0.7060],
         [0.9481, 0.1694],
         [0.1584, 0.3764]]])

## equality

In [41]:
# Two separately constructed tensors holding the same values.
a = torch.tensor((10, 5, 0, 20, 30))
b = torch.tensor((10, 5, 0, 20, 30))

# Element-wise comparison yields a boolean tensor of the same shape.
x = a.eq(b)

print(x)                # per-element results
print(x.all())          # reduced to a 0-dim boolean tensor
print(x.all().item())   # converted to a plain Python bool

tensor([True, True, True, True, True])
tensor(True)
True


## patching

In [21]:
# Rebind x to a 4-D tensor of zeros with shape (2, 2, 3, 3).
x = torch.zeros(2, 2, 3, 3)
x

tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])

In [22]:
# Rebind y to a 4-D tensor of ones with the same (2, 2, 3, 3) shape as x.
y = torch.ones(2, 2, 3, 3)
y

tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])

In [23]:
# Copy y so that later in-place writes to t do not modify y.
t = torch.clone(y)
t

tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])

In [24]:
# Overwrite the top-left 2x2 corner of every 3x3 slice in t with the matching
# corner of x; the Ellipsis stands for the two leading axes.
t[..., :2, :2] = x[..., :2, :2]
t

tensor([[[[0., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]],

         [[0., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]],

         [[0., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]]])