In [1]:
import torch

In [2]:
x = torch.arange(0, 9)  # 1-D integer tensor holding 0, 1, ..., 8
x.size()  # equivalent to x.shape

torch.Size([9])

In [3]:
x_3x3_view = x.view(3, 3)        # view() never copies: the new shape must be compatible
                                 # with x's existing strides (always true when x is contiguous)
x_3x3_reshape = x.reshape(3, 3)  # reshape() returns a view when it can and copies only
                                 # when it must (e.g. non-contiguous input); here x is
                                 # contiguous, so no copy is made

print(x_3x3_view)
print(x_3x3_reshape)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])


In [4]:
x = torch.arange(9).view(3, 3)

# Transposing only swaps the strides; the data is not moved, so the result
# is no longer laid out row-major in memory.
y = x.t()  # same as x.T for a 2-D tensor
print(y)
y.is_contiguous()

tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])


False

In [5]:
y.reshape((1, -1))  # works on the non-contiguous y by copying; y.view(1, -1) would raise a RuntimeError

tensor([[0, 3, 6, 1, 4, 7, 2, 5, 8]])

In [6]:
x = torch.arange(9).view(3, 3)
y = x.transpose(0, 1).contiguous()  # materialise the transpose in row-major order

print(y)
y.view(-1)  # view() succeeds now because contiguous() made y contiguous

tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])


tensor([0, 3, 6, 1, 4, 7, 2, 5, 8])

In [7]:
# Concatenate two tensors along an EXISTING dimension (torch.cat).
# All other dimensions must match, and the result has the same number of dims as the inputs.
x1 = torch.rand(2, 5)
x2 = torch.rand(2, 5)

z1 = torch.cat((x1, x2), dim=0)  # rows of x2 placed below x1:      (2, 5) + (2, 5) -> (4, 5)
z2 = torch.cat((x1, x2), dim=1)  # columns of x2 placed after x1:   (2, 5) + (2, 5) -> (2, 10)

print(x1)
print(x2)
print()
print(z1)
print(z2)
print(z1.shape, z2.shape)

tensor([[0.3127, 0.4596, 0.4548, 0.4915, 0.0937],
        [0.3594, 0.9960, 0.6443, 0.9147, 0.7598]])
tensor([[0.6737, 0.3673, 0.1451, 0.4760, 0.5508],
        [0.1113, 0.5673, 0.9159, 0.8427, 0.4523]])

tensor([[0.3127, 0.4596, 0.4548, 0.4915, 0.0937],
        [0.3594, 0.9960, 0.6443, 0.9147, 0.7598],
        [0.6737, 0.3673, 0.1451, 0.4760, 0.5508],
        [0.1113, 0.5673, 0.9159, 0.8427, 0.4523]])
tensor([[0.3127, 0.4596, 0.4548, 0.4915, 0.0937, 0.6737, 0.3673, 0.1451, 0.4760,
         0.5508],
        [0.3594, 0.9960, 0.6443, 0.9147, 0.7598, 0.1113, 0.5673, 0.9159, 0.8427,
         0.4523]])
torch.Size([4, 5]) torch.Size([2, 10])


In [8]:
# Flatten (unroll) x1 into a single 1-D tensor of 2 * 5 = 10 elements.
# Same result as x1.view(-1): x1 is freshly allocated and contiguous, so reshape() makes no copy.
z = x1.reshape(-1)
z

tensor([0.3127, 0.4596, 0.4548, 0.4915, 0.0937, 0.3594, 0.9960, 0.6443, 0.9147,
        0.7598])

In [9]:
batch = 64  # batch size, also used by the next cell
x = torch.rand(batch, 2, 5)
# Keep the leading (batch) axis and flatten everything else: (64, 2, 5) -> (64, 10).
# This is a very common pattern.
z = x.view(x.size(0), -1)
z.shape

torch.Size([64, 10])

In [10]:
# Reorder the axes of x (permute / transpose)
x = torch.rand(batch, 2, 5)
z1 = x.permute(0, 2, 1)  # permute() takes the complete new axis order: (B, 2, 5) -> (B, 5, 2)
z2 = x.transpose(2, 0)   # transpose() swaps exactly two axes (order of the two is irrelevant):
                         # (B, 2, 5) -> (5, 2, B)

print(x.shape, z1.shape, z2.shape)

torch.Size([64, 2, 5]) torch.Size([64, 5, 2]) torch.Size([5, 2, 64])


In [11]:
# Split a tensor into chunks along one dimension (torch.chunk)
z1 = torch.chunk(x, chunks=2, dim=1)  # dim 1 has size 2 -> two chunks of size 1
z2 = torch.chunk(x, chunks=2, dim=2)  # dim 2 has size 5 -> chunks of size 3 and 2

print(z1[0].shape, z1[1].shape)  # show both chunks (previously printed z1[0] twice)
print(z2[0].shape, z2[1].shape)  # each chunk gets ceil(5 / 2) = 3 elements; the last gets the remaining 2

torch.Size([64, 1, 5]) torch.Size([64, 1, 5])
torch.Size([64, 2, 3]) torch.Size([64, 2, 2])


In [12]:
# Insert new axes of size 1 (unsqueeze)
x = torch.arange(10)
x1 = x.unsqueeze(dim=0)   # (10,)    -> (1, 10)
x2 = x.unsqueeze(dim=1)   # (10,)    -> (10, 1)
x3 = x1.unsqueeze(dim=1)  # (1, 10)  -> (1, 1, 10)

print(x.shape, x1.shape, x2.shape, x3.shape)

torch.Size([10]) torch.Size([1, 10]) torch.Size([10, 1]) torch.Size([1, 1, 10])


In [13]:
# Remove size-1 dimensions (squeeze)
z1 = x3.squeeze(0)  # (1, 1, 10) -> (1, 10)
z2 = x3.squeeze(1)  # (1, 1, 10) -> (1, 10)
z3 = x3.squeeze(2)  # dim 2 has size 10, not 1: no error is raised, the tensor is returned
                    # unchanged and the shape stays (1, 1, 10)
z4 = x3.squeeze()   # no dim given: removes every size-1 dimension -> (10,)

print(z1.shape, z2.shape, z3.shape, z4.shape)

torch.Size([1, 10]) torch.Size([1, 10]) torch.Size([1, 1, 10]) torch.Size([10])
