## PyTorch Tensor Slicing and Concatenation

In [1]:
import torch

## Slicing and Concatenation

### Indexing and Slicing

Prepare target tensor.

In [2]:
x = torch.FloatTensor([[[1, 2],
                        [3, 4]],
                       [[5, 6],
                        [7, 8]],
                       [[9, 10],
                        [11, 12]]])

print(x)
print(x.size())

tensor([[[ 1.,  2.],
         [ 3.,  4.]],

        [[ 5.,  6.],
         [ 7.,  8.]],

        [[ 9., 10.],
         [11., 12.]]])
torch.Size([3, 2, 2])


Access to certain dimension.

In [3]:
print(x[0])

tensor([[1., 2.],
        [3., 4.]])


In [4]:
print(x[0, :])

tensor([[1., 2.],
        [3., 4.]])


In [5]:
print(x[0, :, :])

tensor([[1., 2.],
        [3., 4.]])


In [6]:
print(x[-1])

tensor([[ 9., 10.],
        [11., 12.]])


In [7]:
print(x[-1, :, :])

tensor([[ 9., 10.],
        [11., 12.]])


In [8]:
print(x[:, 0])

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]])


In [9]:
print(x[:, 0, :])

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]])


Access by range. Note that the number of dimensions would not be changed.

In [10]:
print(x[1:3, :, :])
print(x[1:3, :, :].size())
# |x| = (3,2,2)

tensor([[[ 5.,  6.],
         [ 7.,  8.]],

        [[ 9., 10.],
         [11., 12.]]])
torch.Size([2, 2, 2])


In [11]:
print(x[:, :1, :])
print(x[:, :1, :].size())

tensor([[[ 1.,  2.]],

        [[ 5.,  6.]],

        [[ 9., 10.]]])
torch.Size([3, 1, 2])


In [12]:
print(x[:, :-1, :])
print(x[:, :-1, :].size())

tensor([[[ 1.,  2.]],

        [[ 5.,  6.]],

        [[ 9., 10.]]])
torch.Size([3, 1, 2])


### split: Split tensor to desirable shapes.

In [13]:
x = torch.FloatTensor(10, 4)
x

tensor([[ 7.5541e-19,  3.0998e-41,  8.2092e+13,  4.3948e-41],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.4013e-45,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.4013e-45,  0.0000e+00, -1.7014e+38,  1.1515e-40],
        [ 4.5919e-41,  4.1478e-43,  1.5835e-43,  0.0000e+00],
        [ 4.7661e-18,  3.0998e-41,  4.7675e-18,  3.0998e-41]])

In [14]:
# splits()
# 텐서를 특정 차원에 대해 원하는 크기로 잘라줌
# 개수에 상관 없이 원하는 크기로 나누기

splits = x.split(4, dim=0)
print(splits)

for s in splits:
    print(s.size())

(tensor([[7.5541e-19, 3.0998e-41, 8.2092e+13, 4.3948e-41],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]), tensor([[ 1.4013e-45,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.4013e-45,  0.0000e+00, -1.7014e+38,  1.1515e-40]]), tensor([[4.5919e-41, 4.1478e-43, 1.5835e-43, 0.0000e+00],
        [4.7661e-18, 3.0998e-41, 4.7675e-18, 3.0998e-41]]))
torch.Size([4, 4])
torch.Size([4, 4])
torch.Size([2, 4])


### chunk: Split tensor to number of chunks.

In [15]:
x = torch.FloatTensor(8, 4)
x

tensor([[0.0000e+00, 3.0998e-41, 1.8788e+31, 1.7220e+22],
        [2.1715e-18, 2.1391e+23, 5.4455e-05, 5.3779e+22],
        [5.3697e-05, 1.0547e-08, 4.2190e-08, 5.4371e+22],
        [1.3311e+22, 1.0930e-05, 2.1707e-18, 1.6678e+19],
        [7.0976e+22, 2.1715e-18, 4.2330e+21, 1.6534e+19],
        [1.1625e+27, 1.4580e-19, 7.1856e+22, 4.3605e+27],
        [1.5766e-19, 7.1856e+22, 4.3605e+27, 1.4580e-19],
        [1.8179e+31, 1.8524e+28, 2.1715e-18, 2.1391e+23]])

In [16]:
# chunk : 첫번째 차원의 크기 8을 최대한 같은 크기로 3등분
chunks = x.chunk(3, dim=0)
print(chunks)

for c in chunks:
    print(c.size())

(tensor([[0.0000e+00, 3.0998e-41, 1.8788e+31, 1.7220e+22],
        [2.1715e-18, 2.1391e+23, 5.4455e-05, 5.3779e+22],
        [5.3697e-05, 1.0547e-08, 4.2190e-08, 5.4371e+22]]), tensor([[1.3311e+22, 1.0930e-05, 2.1707e-18, 1.6678e+19],
        [7.0976e+22, 2.1715e-18, 4.2330e+21, 1.6534e+19],
        [1.1625e+27, 1.4580e-19, 7.1856e+22, 4.3605e+27]]), tensor([[1.5766e-19, 7.1856e+22, 4.3605e+27, 1.4580e-19],
        [1.8179e+31, 1.8524e+28, 2.1715e-18, 2.1391e+23]]))
torch.Size([3, 4])
torch.Size([3, 4])
torch.Size([2, 4])


### index_select: Select elements by using dimension index.

In [17]:
# index_select() : 특정 차원에서 원하는 인덱스(index) 값만 취하는 함수

x = torch.FloatTensor([[[1, 1],
                        [2, 2]],
                       [[3, 3],
                        [4, 4]],
                       [[5, 5],
                        [6, 6]]])

indice = torch.LongTensor([2, 1])

print(x)
print(x.size())

tensor([[[1., 1.],
         [2., 2.]],

        [[3., 3.],
         [4., 4.]],

        [[5., 5.],
         [6., 6.]]])
torch.Size([3, 2, 2])


In [18]:
y = x.index_select(dim=0, index=indice) #첫번째 차원에서 인덱스가 2와 1인것

print(y)
print(y.size())

tensor([[[5., 5.],
         [6., 6.]],

        [[3., 3.],
         [4., 4.]]])
torch.Size([2, 2, 2])


### cat: Concatenation of multiple tensors in the list.

In [19]:
x = torch.FloatTensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
y = torch.FloatTensor([[10, 11, 12],
                       [13, 14, 15],
                       [16, 17, 18]])

print(x)
print()
print(y)
print(x.size(), y.size())

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

tensor([[10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([3, 3]) torch.Size([3, 3])


In [20]:
z = torch.cat([x, y], dim=0)
# dim=0 세로방향
print(z)
print(z.size())

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([6, 3])


In [21]:
z = torch.cat([x, y], dim=-1)

print(z)
print(z.size())

tensor([[ 1.,  2.,  3., 10., 11., 12.],
        [ 4.,  5.,  6., 13., 14., 15.],
        [ 7.,  8.,  9., 16., 17., 18.]])
torch.Size([3, 6])


### stack: Stacking of multiple tensors in the list.

In [22]:
# stack() : 쌓기 (2차원 아님 >> 차원이 증가)
z = torch.stack([x, y])

print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


Or you can specify the dimension. Default is 0.

In [23]:
z = torch.stack([x, y], dim=-1)

print(z)
print(z.size())

tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]])
torch.Size([3, 3, 2])


### Implement 'stack' function by using 'cat'.

In [24]:
print(x)
print(x.size())
print(x.unsqueeze(0))
print(x.unsqueeze(0).size())

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
torch.Size([3, 3])
tensor([[[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]])
torch.Size([1, 3, 3])


In [25]:
print(y)
print(y.size())
print(y.unsqueeze(0))
print(y.unsqueeze(0).size())

tensor([[10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([3, 3])
tensor([[[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([1, 3, 3])


In [26]:
# z = torch.stack([x, y])
z = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)

print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


### Useful Trick: Merge results from iterations

In [27]:
result = []
for i in range(5):
    x = torch.FloatTensor(2, 2)
    result += [x]

result = torch.stack(result)
print(result)
result.size()

tensor([[[5.5077e+03, 4.3949e-41],
         [4.7720e-18, 3.0998e-41]],

        [[5.5077e+03, 4.3949e-41],
         [4.7737e-18, 3.0998e-41]],

        [[5.5077e+03, 4.3949e-41],
         [1.4865e-21, 3.0998e-41]],

        [[5.5077e+03, 4.3949e-41],
         [4.7797e-18, 3.0998e-41]],

        [[5.5077e+03, 4.3949e-41],
         [5.5077e+03, 4.3949e-41]]])


torch.Size([5, 2, 2])

In [28]:
result = []
for i in range(5):
    x = torch.FloatTensor(2, 2)
    print(x)
    result += [x]


tensor([[5.5077e+03, 4.3949e-41],
        [4.7720e-18, 3.0998e-41]])
tensor([[5.5077e+03, 4.3949e-41],
        [5.5077e+03, 4.3949e-41]])
tensor([[0., 0.],
        [0., 0.]])
tensor([[0., 0.],
        [0., 0.]])
tensor([[5.5076e+03, 4.3949e-41],
        [5.5076e+03, 4.3949e-41]])
