In [1]:
import torch

In [7]:
# `view` never copies: it reinterprets the existing storage, so it only works when the
# tensor's strides are compatible with the new shape (e.g. a contiguous tensor).
# `reshape` returns a view when it can and falls back to a copy otherwise.
x = torch.arange(9) # (9,)

m1 = x.view(3,3)     # view of x's storage; raises on stride-incompatible input (e.g. x.t())
m2 = x.reshape(3,3)  # also a view here, because x is contiguous

# x is contiguous, so neither call copied: all three tensors share the same memory
assert m1.data_ptr() == m2.data_ptr() == x.data_ptr()

print(m1)
print(m2)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])


### transpose
`x.t()` — transpose of a 2-D tensor (swaps dims 0 and 1).

In [10]:
# 3x3 matrix holding 0..8 in row-major order
m = torch.arange(9).reshape(3,3)
m_t = m.t()  # transpose: m_t[i, j] == m[j, i]

print(m)

print(m_t)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])


### concatenate
`torch.cat((x1, x2), dim=d)` — join tensors along the existing axis `d`; all other dimensions must match.

In [18]:
# fixed seed so the random values below (and the printed output) are reproducible
torch.manual_seed(0)

x1 = torch.rand((2,5))
x2 = torch.rand((2,5))

print(x1)
print(x2)
print('-'*60)

# dim=0 stacks rows: (2,5) + (2,5) -> (4,5); the column counts must match
z1 = torch.cat((x1,x2), dim=0)
print(z1)

# dim=1 appends columns: (2,5) + (2,5) -> (2,10); the row counts must match
z2 = torch.cat((x1,x2), dim=1)
print(z2)

assert z1.shape == (4, 5) and z2.shape == (2, 10)

tensor([[0.3651, 0.6042, 0.3732, 0.9928, 0.1846],
        [0.2577, 0.4603, 0.0600, 0.2057, 0.2658]])
tensor([[0.9991, 0.1942, 0.7372, 0.4284, 0.7426],
        [0.6220, 0.4031, 0.6327, 0.1817, 0.6182]])
------------------------------------------------------------
tensor([[0.3651, 0.6042, 0.3732, 0.9928, 0.1846],
        [0.2577, 0.4603, 0.0600, 0.2057, 0.2658],
        [0.9991, 0.1942, 0.7372, 0.4284, 0.7426],
        [0.6220, 0.4031, 0.6327, 0.1817, 0.6182]])
tensor([[0.3651, 0.6042, 0.3732, 0.9928, 0.1846, 0.9991, 0.1942, 0.7372, 0.4284,
         0.7426],
        [0.2577, 0.4603, 0.0600, 0.2057, 0.2658, 0.6220, 0.4031, 0.6327, 0.1817,
         0.6182]])


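### unroll (flatten)
Collapse several axes into one with `reshape`; passing `-1` lets PyTorch infer that axis's size.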

In [27]:
# --- a single (2, 5) tensor: collapse both axes into one ---
x = torch.rand((2,5))
z = x.reshape(-1)  # -1 asks PyTorch to infer the length: 2 * 5 = 10
print("x shape:",x.shape)
print("unrolled:",z.shape)

print("-"*60)

# --- a batch of (2, 5) tensors: keep axis 0, collapse axes 1 and 2 ---
batch = 64
x = torch.rand((batch, 2, 5))
z = x.reshape(batch,-1)  # (64, 2, 5) -> (64, 10)
print("x shape: ",x.shape)
print("unrolled axis 1,2:",z.shape)

x shape: torch.Size([2, 5])
unrolled: torch.Size([10])
------------------------------------------------------------
x shape:  torch.Size([64, 2, 5])
unrolled axis 1,2: torch.Size([64, 10])


### switch axes
`x.permute(*dims)` reorders the axes: output axis `i` is input axis `dims[i]`.

In [29]:
batch = 64
x = torch.rand((batch, 2, 5))

# permute(*dims): output axis i is input axis dims[i].
# (0, 2, 1) keeps the batch axis first and swaps the last two -> same as x.transpose(1, 2)
z = x.permute(0,2,1)

for tensor in (x, z):
    print(tensor.shape)

torch.Size([64, 2, 5])
torch.Size([64, 5, 2])


In [36]:
x = torch.arange(10)   # (10,)
row = x.unsqueeze(0)   # new leading axis  -> (1, 10)
col = x.unsqueeze(1)   # new trailing axis -> (10, 1)

print("x shape:",x.shape)
print("x.unsqueeze(0):",row.shape)
print("x.unsqueeze(1):",col.shape)

x shape: torch.Size([10])
x.unsqueeze(0): torch.Size([1, 10])
x.unsqueeze(1): torch.Size([10, 1])


In [45]:
# add two size-1 axes one step at a time
x = torch.arange(10)  # (10,)
x = x.unsqueeze(0)    # (1, 10)
x = x.unsqueeze(1)    # (1, 1, 10)
print(x.shape)

torch.Size([1, 1, 10])


In [48]:
x = torch.rand((2,2))  # only the shape matters here; the values are irrelevant
print(x.shape)

# insert a size-1 axis in the middle: (2, 2) -> (2, 1, 2); displayed as the cell's output
x.unsqueeze(dim=1).shape

torch.Size([2, 2])


torch.Size([2, 1, 2])

In [61]:
x = torch.arange(10).unsqueeze(0).unsqueeze(1)  # (1, 1, 10)
print(x.shape)

# squeeze only axis 0; a bare x.squeeze() would drop every size-1 axis -> (10,)
z = torch.squeeze(x, 0)
print(z.shape)

torch.Size([1, 1, 10])
torch.Size([1, 10])
