## Learn how to use the einops library

In [1]:
import torch
from einops import repeat, rearrange

## Use repeat function

In [2]:
input = torch.randn(2, 3)
print(input, input.shape)

output = repeat(input, 'h w -> h w c', c=3)
print(output, output.shape)

tensor([[-0.0445,  1.2679,  0.0591],
        [ 0.8185, -0.0033, -1.2111]]) torch.Size([2, 3])
tensor([[[-0.0445, -0.0445, -0.0445],
         [ 1.2679,  1.2679,  1.2679],
         [ 0.0591,  0.0591,  0.0591]],

        [[ 0.8185,  0.8185,  0.8185],
         [-0.0033, -0.0033, -0.0033],
         [-1.2111, -1.2111, -1.2111]]]) torch.Size([2, 3, 3])


### repeat a whole matrix

In [12]:
import torch
from einops import repeat

intrinsics_pad = repeat(torch.eye(4), "X Y -> L X Y", L = 2).clone()
print(intrinsics_pad, intrinsics_pad.shape)

idx = torch.arange(5).reshape(1, 5)
idx_repeat = repeat(idx, "B RN -> B DimX RN", DimX=3)
print(idx, "\n", idx_repeat, idx_repeat.shape)

tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]],

        [[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]]) torch.Size([2, 4, 4])
tensor([[0, 1, 2, 3, 4]]) 
 tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]]) torch.Size([1, 3, 5])


### repeat specific axis

In [10]:
pc = torch.randn(1, 2, 3)
print(pc)

### The below three methods are the same
pc = repeat(pc, 'b n c -> (repeat b) n c', repeat=2).contiguous()
# pc = repeat(pc, 'b n c -> (b 2) n c').contiguous()
# pc = repeat(pc, 'b n c -> (2 b) n c').contiguous()
print(pc, pc.shape)

tensor([[[-1.5255, -2.6183,  0.1318],
         [ 0.5025, -0.5046,  0.7587]]])
tensor([[[-1.5255, -2.6183,  0.1318],
         [ 0.5025, -0.5046,  0.7587]],

        [[-1.5255, -2.6183,  0.1318],
         [ 0.5025, -0.5046,  0.7587]]]) torch.Size([2, 2, 3])


In [4]:
class_label = 9

batch_size = 2
input = torch.tensor([class_label])

c_indices = repeat(input, '1 -> b 1', b=batch_size)  # class token
print(c_indices.shape, c_indices.dtype)
print(c_indices)


torch.Size([2, 1]) torch.int64
tensor([[9],
        [9]])


## rearange operation

In [16]:
import torch
from einops import repeat, rearrange

input = torch.arange(18).reshape(1, 3, 2, 3)
print(input, input.shape)

output = rearrange(input, 'b n h w -> b (n h w)')
print(output, output.shape)

output = rearrange(input, 'b n h w -> b (h w) n')
print(output, output.shape)

output2 = repeat(output, 'b g n -> b (g 2) n')
# output2 = repeat(output2, 'b r g n -> b (g r) n')
print(output2, output2.shape)


# output = rearrange(input, 'b n c -> b (c n)')  # it's different from the above one
# print(output, output.shape)

# output = repeat(input, 'b n c -> b (c n 3)')  # it's different from the above one
# print(output, output.shape)

tensor([[[[ 0,  1,  2],
          [ 3,  4,  5]],

         [[ 6,  7,  8],
          [ 9, 10, 11]],

         [[12, 13, 14],
          [15, 16, 17]]]]) torch.Size([1, 3, 2, 3])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17]]) torch.Size([1, 18])
tensor([[[ 0,  6, 12],
         [ 1,  7, 13],
         [ 2,  8, 14],
         [ 3,  9, 15],
         [ 4, 10, 16],
         [ 5, 11, 17]]]) torch.Size([1, 6, 3])
tensor([[[ 0,  6, 12],
         [ 0,  6, 12],
         [ 1,  7, 13],
         [ 1,  7, 13],
         [ 2,  8, 14],
         [ 2,  8, 14],
         [ 3,  9, 15],
         [ 3,  9, 15],
         [ 4, 10, 16],
         [ 4, 10, 16],
         [ 5, 11, 17],
         [ 5, 11, 17]]]) torch.Size([1, 12, 3])


In [None]:
import torch
from einops import repeat, rearrange

input = torch.arange(12).reshape(2, 3, 2)
print(input, input.shape)

output = rearrange(input, 'b n c -> b (n c)')
print(output, output.shape)

output = repeat(input, 'b n c -> b (n c 2)')  # it's different from the above one
print(output, output.shape)

output = rearrange(input, 'b n c -> b (c n)')  # it's different from the above one
print(output, output.shape)

output = repeat(input, 'b n c -> b (c n 3)')  # it's different from the above one
print(output, output.shape)

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]]]) torch.Size([2, 3, 2])
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]]) torch.Size([2, 6])
tensor([[ 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5],
        [ 6,  6,  7,  7,  8,  8,  9,  9, 10, 10, 11, 11]]) torch.Size([2, 12])
tensor([[ 0,  2,  4,  1,  3,  5],
        [ 6,  8, 10,  7,  9, 11]]) torch.Size([2, 6])
tensor([[ 0,  0,  0,  2,  2,  2,  4,  4,  4,  1,  1,  1,  3,  3,  3,  5,  5,  5],
        [ 6,  6,  6,  8,  8,  8, 10, 10, 10,  7,  7,  7,  9,  9,  9, 11, 11, 11]]) torch.Size([2, 18])


In [8]:
import torch
from einops import repeat, rearrange

# input = torch.randn(4, 2)
# print(input, input.shape)

# output = rearrange(input, '(b n) c -> b n c', b=2)
# print(output, output.shape)

# output = input.reshape(2, 2, 2)
# print(output, output.shape)

input = torch.arange(12).reshape(6, 2)
print(input, input.shape)

output = rearrange(input, '(b n) c -> b n c', b=3)
print(output, output.shape)

output = rearrange(input, '(n b) c -> n b c', b=3)
print(output, output.shape)

output = rearrange(input, '(n b) c -> b n c', b=3)
print(output, output.shape)

tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11]]) torch.Size([6, 2])
tensor([[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]]]) torch.Size([3, 2, 2])
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]]]) torch.Size([2, 3, 2])
tensor([[[ 0,  1],
         [ 6,  7]],

        [[ 2,  3],
         [ 8,  9]],

        [[ 4,  5],
         [10, 11]]]) torch.Size([3, 2, 2])


### expand dimensions

In [3]:
import torch
from einops import repeat, rearrange

input = torch.randn(3)
print(input, input.shape)

output = repeat(input, 'h -> () h () ()')
print(output, output.shape)

tensor([-0.2154, -1.8071, -1.5535]) torch.Size([3])
tensor([[[[-0.2154]],

         [[-1.8071]],

         [[-1.5535]]]]) torch.Size([1, 3, 1, 1])


In [4]:
from einops import rearrange

def reshape_first_two_dims(input_tensor):
    # Get the number of dimensions in the input tensor
    num_dims = input_tensor.ndim

    if num_dims < 2:
        # Handle cases where the tensor has fewer than 2 dimensions
        raise ValueError("Input tensor must have at least 2 dimensions.")

    # Create the new shape pattern based on the number of dimensions
    new_shape_pattern = f'{" ".join(["dim"] * num_dims)} -> dim {" ".join(["dim"] * (num_dims - 2))}'
    print(new_shape_pattern)

    # Reshape the tensor
    reshaped_tensor = rearrange(input_tensor, new_shape_pattern)

    return reshaped_tensor

# Example usage:
import numpy as np
import torch

# Create tensors with varying dimensions
tensor1 = np.random.rand(4, 5)
tensor2 = np.random.rand(3, 4, 5)
tensor3 = torch.randn(2, 3, 4, 5)

reshaped_tensor1 = reshape_first_two_dims(tensor1)
reshaped_tensor2 = reshape_first_two_dims(tensor2)
reshaped_tensor3 = reshape_first_two_dims(tensor3)

print(reshaped_tensor1.shape)
print(reshaped_tensor2.shape)
print(reshaped_tensor3.shape)

dim dim -> dim 


EinopsError:  Error while processing rearrange-reduction pattern "dim dim -> dim ".
 Input tensor shape: (4, 5). Additional info: {}.
 Indexing expression contains duplicate dimension "dim"

In [10]:
import torch

from einops import rearrange, repeat

data = torch.arange(24).reshape(2, 3, 4)
print(data)

# data = rearrange(data, 'b n c -> (b n) c')
# data = rearrange(data, 'b n c -> b (n c)')
data = rearrange(data, 'b n c -> (n c) b')
# data = rearrange(data, 'b n c -> b (c n)')
print(data, data.shape)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
tensor([[ 0, 12],
        [ 1, 13],
        [ 2, 14],
        [ 3, 15],
        [ 4, 16],
        [ 5, 17],
        [ 6, 18],
        [ 7, 19],
        [ 8, 20],
        [ 9, 21],
        [10, 22],
        [11, 23]]) torch.Size([12, 2])


In [3]:
input = torch.rand(5, 256, 12)
print(input.shape)

output = rearrange(input, 'b n (r c) -> b n r c', c=3)
print(output.shape)


torch.Size([5, 256, 12])
torch.Size([5, 256, 4, 3])


In [4]:
import torch
from einops import rearrange

input = torch.randn(2, 5, 4, 3)

input = rearrange(input, 'b g m c -> (b g) m c')
print(input.shape)

torch.Size([10, 4, 3])


In [5]:
import torch
from einops import rearrange

input = torch.randn(1, 262144, 3)

output = rearrange(input, 'x (b n) c -> (x b) n c', b=128)
print(output.shape)

torch.Size([128, 2048, 3])
