#### Einops Basics

##### `rearrange`

In [None]:
import torch
from einops import rearrange, repeat

# Image geometry: a 3-channel 224x224 input cut into 16x16 patches.
num_channels, height, width = 3, 224, 224
patch_height, patch_width = 16, 16
# Patches per side, derived from the sizes above: 224 // 16 == 14.
num_height, num_width = height // patch_height, width // patch_width

batch_size = 2
in_channels = num_channels
x = torch.randn((batch_size, in_channels, height, width))

# Goal: (B, C, H, W) -> (B, num_patches, embed_dim)
# The pattern splits H into (nh ph) and W into (nw pw), i.e. it views x as
# (B, C, nh, ph, nw, pw), reorders the axes to (B, nh, nw, ph, pw, C) and
# merges them into (B, nh*nw, ph*pw*C).
out = rearrange(
    tensor=x,
    pattern="b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)",
    c=in_channels,
    b=batch_size,
    ph=patch_height,
    pw=patch_width,
)

print(out.shape)  # (b, num_patches, dim = (3*ph*pw) ) -> (2, 14*14, 3*16*16)

# The same operation written with plain reshape / permute.
num_patch_h, num_patch_w = (height // patch_height), (width // patch_width)
b, c, nh, ph, nw, pw = batch_size, in_channels, num_patch_h, patch_height, num_patch_w, patch_width
# Axis indices after the first reshape: b=0, c=1, nh=2, ph=3, nw=4, pw=5.
# The target order (b, nh, nw, ph, pw, c) is therefore permute(0, 2, 4, 3, 5, 1).
# Any other permutation still reshapes to the same final shape but scrambles
# which pixels land in which patch.
out1 = (
    x.reshape(b, c, nh, ph, nw, pw)
    .permute(0, 2, 4, 3, 5, 1)
    .reshape(b, nh * nw, ph * pw * c)
)
print(out1.shape)

# Matching shapes alone do not prove equivalence, so compare the values too.
assert torch.equal(out, out1)

torch.Size([2, 196, 768])
torch.Size([2, 196, 768])


#### `repeat`

In [8]:
# A 0-dim (scalar) tensor repeated with a single count becomes a 1-D tensor
# holding that many copies of the value.
x = torch.tensor(1)
out = x.repeat((5,))
print(*(t.shape for t in (x, out)))
print(x, out)

torch.Size([]) torch.Size([5])
tensor(1) tensor([1, 1, 1, 1, 1])


In [10]:
# Two repeat counts on a 1-D tensor: 2 copies along a new leading dim and
# 2 copies along the existing dim, so shape (3,) becomes (2, 6).
x = torch.tensor((1, 2, 3))
out = x.repeat((2, 2))
print(x.shape, out.shape)
print(x, out, sep="\n")

torch.Size([3]) torch.Size([2, 6])
tensor([1, 2, 3])
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])


In [None]:
# Three repeat counts on a 1-D tensor: two new leading dims of size 2, and the
# original dim tiled twice, giving shape (2, 2, 6).
x = torch.tensor((1, 2, 3))
out = x.repeat([2, 2, 2])
print(x, out, sep="\n")  # as expected!

tensor([1, 2, 3])
tensor([[[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]],

        [[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]]])


In [13]:
# repeat(2, 2) on a length-2 tensor: 2 rows, each holding the pair twice.
x = torch.tensor((1, 2))
out = x.repeat(*(2, 2))
print(*(t.shape for t in (x, out)))

print(x, out, sep="\n")


torch.Size([2]) torch.Size([2, 4])
tensor([1, 2])
tensor([[1, 2, 1, 2],
        [1, 2, 1, 2]])


In [14]:
# repeat(1, 2): a single row (leading dim of size 1) with the pair tiled
# twice along the last dim.
x = torch.tensor((1, 2))
out = x.repeat((1, 2))
print(x.shape, out.shape, sep=" ")

print(x, out, sep="\n")

torch.Size([2]) torch.Size([1, 4])
tensor([1, 2])
tensor([[1, 2, 1, 2]])


In [16]:
# repeat(2, 1): the pair is stacked into 2 rows and left untiled along the
# last dim.
x = torch.tensor((1, 2))
out = x.repeat([2, 1])
print(*(t.shape for t in (x, out)))

print(x, out, sep="\n")

torch.Size([2]) torch.Size([2, 2])
tensor([1, 2])
tensor([[1, 2],
        [1, 2]])


In [17]:
# repeat(2, 2) on a length-3 tensor: 2 rows, each with the triple tiled twice.
x = torch.tensor((1, 2, 3))
out = x.repeat((2, 2))
print(x, out, sep="\n")

tensor([1, 2, 3])
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])


In [18]:
# repeat(2, 1): the triple is stacked into 2 rows, with no tiling along the
# last dim.
x = torch.tensor((1, 2, 3))
out = x.repeat([2, 1])
print(x, out, sep="\n")

tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3]])


In [19]:
# repeat(1, 2): a single row with the triple tiled twice along the last dim.
x = torch.tensor((1, 2, 3))
out = x.repeat(*(1, 2))
print(x, out, sep="\n")

tensor([1, 2, 3])
tensor([[1, 2, 3, 1, 2, 3]])
