<a href="https://colab.research.google.com/github/rraghavkaushik/Sarvam-Assignment/blob/main/notebooks/einops_rearrange_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Quick Guide for the different sections

### There are three sections in this notebook,
### a) Implementation of einops from scratch
### b) Tests
### c) Comparison with einops

# Implementation of einops from scratch

In [1]:
import numpy as np

In [2]:
import re

In [3]:
def validate_pattern(pattern: str):

    if '->' not in pattern:
        raise ValueError("Pattern must contain '->'")

    lhs, rhs = pattern.split('->')

    for side, name in [(lhs, "input"), (rhs, "output")]:
        if side.count('(') != side.count(')'):
            raise ValueError(f"Unbalanced parentheses in {name} pattern: '{side}'")

        if side.count('...') > 1:
            raise ValueError(f"Too many ellipses in {name} pattern: '{side}'")

        invalid = re.findall(r'[^\w\s\.\(\)]', side)
        if invalid:
            raise ValueError(f"Invalid characters in {name} pattern: {set(invalid)}")

        axes = flatten_axes(find_axes(side))
        axes_wo_ellipsis = [ax for ax in axes if ax != '...']
        if len(axes_wo_ellipsis) != len(set(axes_wo_ellipsis)):
            raise ValueError(f"Duplicate axes found in {name} pattern: {axes_wo_ellipsis}")

    return True

In [4]:
def parse(pattern: str):
  if '->' not in pattern:
    raise ValueError("Pattern does not ->, make sure that the pattern is right")
  input = pattern.split('->')[0]
  output = pattern.split('->')[1]
  input = find_axes(input)
  output = find_axes(output)
  return input, output

In [5]:
def find_axes(pattern):
    return re.findall(r'\.\.\.|[\w]+|\([^\)]+\)', pattern)

In [6]:
def flatten_axes(axes):
    flat = []
    for ax in axes:
        if ax.startswith("(") and ax.endswith(")"):
            # flat += re.findall(r'\w+', ax)
            flat += re.findall(r'\w+|\.\.\.', ax)
        else:
            flat.append(ax)
    return flat

In [7]:
def transpose(tensor, input, output):
  for i in output:
    if i not in input:
      raise ValueError(f"Output axis '{ax}' not found in input axes")

  perm = [input.index(ax) for ax in output]
  return tensor.transpose(perm)


In [8]:
def is_transpose(pattern) -> bool:
  return ('...' not in pattern) and ('(' not in pattern) and (')' not in pattern)

In [9]:
from typing import List, Tuple, Dict

def split_axes(ip_ax: str, tensor_shape: int, axes_lengths: Dict[str, int], shape: Dict[str, int]) -> List[str]:

    axes = re.findall(r'\w+', ip_ax)
    known_ax = [i for i in axes if i in axes_lengths]
    known_dim = [axes_lengths[i] for i in known_ax]
    unknown_ax = [i for i in axes if i not in axes_lengths]

    prod_known = int(np.prod(known_dim)) if known_dim else 1

    if len(unknown_ax) == 0:
        sizes = [axes_lengths[i] for i in axes]
    elif len(unknown_ax) == 1:
        remaining = tensor_shape // prod_known
        axes_lengths[unknown_ax[0]] = remaining
        sizes = [axes_lengths[i] for i in axes]
    else:
        raise ValueError(
            f"Cannot infer sizes for multiple unknown axes {unknown_ax}. "
            f"Known axes: {known_ax}. Please specify all but one."
        )

    if np.prod(sizes) != tensor_shape:
        raise ValueError(
            f"Mismatch in shapes: cannot split dimension {tensor_shape} into {sizes} "
            f"(from axes: {axes})"
        )

    for i, s in zip(axes, sizes):
        shape[i] = s

    return axes


In [10]:
# def merge_axis(op_ax, axes_shape):

#   ax_split = re.findall(r'\w+', op_ax)
#   return int(np.prod([axes_shape[i] for i in ax_split]))

def merge_axis(op_ax, axes_shape, ellipsis_ax=[]):

    tokens = re.findall(r'\w+|\.{3}', op_ax)
    resolved_axes = []
    for token in tokens:
        if token == '...':
            resolved_axes.extend(ellipsis_ax)
        else:
            resolved_axes.append(token)
    return int(np.prod([axes_shape[i] for i in resolved_axes]))


In [11]:
def repeat_new_axes(tensor, input_axes, output, axes_shape, axes_lengths):
  new_ax = [i for i in output if i not in input_axes]
  # print(new_ax)

  for i in new_ax:
    if i.startswith('(') or i == '...':
      continue
    if i not in axes_lengths:
      raise ValueError(f"The axis {i} needs a specified length")

    tensor = np.expand_dims(tensor, axis=-1)
    tensor = np.repeat(tensor, axes_lengths[i], axis=-1)

    axes_shape[i] = axes_lengths[i]
    input_axes.append(i)

    target_index = output.index(i)
    tensor = np.moveaxis(tensor, -1, target_index)
    input_axes = input_axes[:-1]
    input_axes.insert(target_index, i)

  return tensor, input_axes

In [12]:
def ellipsis(input, tensor_shape, axes):
  # n = len(axes)
  n_ellipsis = len(tensor_shape) - len(axes)

  if n_ellipsis < 0:
    raise ValueError("Not enough dimensions in input tensor")

  ellipsis_dim = tensor_shape[:n_ellipsis]
  ellipsis_ax = [f'_ellipsis_{i}' for i in range(n_ellipsis)]
  return ellipsis_dim, ellipsis_ax, n_ellipsis

In [13]:
def extract_axes(expr, ellipsis_ax):

    '''Hint: added this to handle cases where ellipsis occur in the merge axis operation,
    and also when the case becomes too complex with the transpose part coming in as well'''
    expr = expr.strip("()")
    return [ax if ax != '...' else ell for ax in expr.split() for ell in (ellipsis_ax if ax == '...' else [ax])]


In [45]:
def rearrange(tensor: np.array, pattern: str, **axes_lengths) -> np.ndarray:
    input, output = parse(pattern)
    # print(input, output)

    # if is_transpose(pattern):
        # print(transpose(tensor, input, output))
        # return transpose(tensor, input, output)

    # print("...")
    pattern_axes = set(input + output)
    extra_keys = set(axes_lengths.keys()) - pattern_axes
    if extra_keys:
      raise ValueError(f"Extra keys in axes_lengths not used in pattern: {extra_keys}")

    axes_shape = {}
    input_axes = []
    # output_axes = []
    ellipsis_dim = []
    axes = []

    for ip_ax in input:
        if ip_ax == '...':
          continue
        elif ip_ax.startswith("(") and ip_ax.endswith(")"):
          axes += re.findall(r'\w+', ip_ax)
        else:
          axes.append(ip_ax)
    # print(axes)

    tensor_shape = list(tensor.shape)
    # print(tensor_shape)

    '''ellipsis_dim = []
    ellipsis_ax = []
    n_ellipsis = 0
    if '...' in pattern:
      ellipsis_dim, ellipsis_ax, n_ellipsis = ellipsis(input, tensor_shape, axes)'''
    # else:
    #   ellipsis_dim = []
    #   ellipsis_ax = []
    #   n_ellipsis = 0

    n_ellipsis = len(tensor_shape) - len(axes)
    if '...' in pattern:
      if n_ellipsis < 0:
        raise ValueError("Too many axoes in pattern for input tensor")

    ellipsis_ax = [f'_ellipsis_{i}' for i in range(n_ellipsis)]
    ellipsis_dim = tensor_shape.copy()
    if '...' in pattern:
      for i in axes:
        ind = tensor_shape.index(axes_lengths.get(i, None)) if i in axes_lengths else None
        if ind is not None:
          ellipsis_dim.pop(i)
    # idx = n_ellipsis
    idx = 0
    for ip_ax in input:
        if ip_ax == '...':
          # continue
          # if len( ellipsis_ax) > 0:
          # for ax, dim in zip(ellipsis_ax, ellipsis_dim): # changed ellipses_ax to ellipsis_ax
          #   axes_shape[ax] = dim
          # input_axes.extend(ellipsis_ax) # changed ellipses_ax to ellipsis_ax
          # else:
            # continue
          for i in range(n_ellipsis):
            axes_shape[ellipsis_ax[i]] = tensor_shape[idx]
            input_axes.append(ellipsis_ax[i])
            idx += 1

        elif ip_ax.startswith('(') and ip_ax.endswith(')'):
          ax_split = split_axes(ip_ax, tensor_shape[idx], axes_lengths, axes_shape)
          input_axes.extend(ax_split)
          idx += 1

        # else:
        #   axes_shape[ip_ax] = tensor_shape[idx]
        #   input_axes.append(ip_ax)
        #   idx += 1
        else:
          if idx < len(tensor_shape):
            axes_shape[ip_ax] = tensor_shape[idx]
            input_axes.append(ip_ax)
            idx += 1
          else:
            raise ValueError(f"Input pattern '{pattern}' expects more dimensions than the input tensor has.")

    # print(axes_shape)
    # print("Inferred Input Axes:", input_axes)
    # tensor = tensor.reshape()
    tensor_reshaped = tensor.reshape([axes_shape[i] for i in input_axes])
    # print(tensor_reshaped)

    new_axes_to_repeat = set(output) - set(input_axes)

    if new_axes_to_repeat:
      tensor_reshaped, input_axes = repeat_new_axes(tensor_reshaped, input_axes, output, axes_shape, axes_lengths)

    # print('...', input_axes, output)
    # print(tensor_reshaped, input_axes)

    final_axes_order = []
    for op_ax in output:
        if op_ax == '...':
            final_axes_order.extend(ellipsis_ax)
        elif op_ax.startswith('(') and op_ax.endswith(')'):
            final_axes_order.extend(extract_axes(op_ax, ellipsis_ax))
            # continue
        else:
            final_axes_order.append(op_ax)

    if input_axes != final_axes_order and not new_axes_to_repeat:
        perm = [input_axes.index(ax) for ax in final_axes_order]
        tensor_reshaped = tensor_reshaped.transpose(perm)
        input_axes = final_axes_order

    output_ax_dim = []
    for op_ax in output:
        if op_ax == '...':
          # continue
          output_ax_dim.extend([axes_shape[i] for i in ellipsis_ax])
        elif op_ax.startswith('(') and op_ax.endswith(')'):
          # output_ax_dim.append(merge_axis(op_ax, axes_shape))
          output_ax_dim.append(merge_axis(op_ax, axes_shape, ellipsis_ax))
        else:
          output_ax_dim.append(axes_shape[op_ax])

    # print("final output shape:", output_ax_dim)

    return tensor_reshaped.reshape(output_ax_dim)

# Tests

**Pattern: "a b c d -> (...) d" can't be handled by both einops and my implementation.**




In [52]:
import einops
x = np.random.rand(2, 3, 4, 5)
y = einops.rearrange(x, "a b c d -> (...) d")

y_ = rearrange(x, "a b c d -> (...) d")

print(y.shape, y_.shape)

EinopsError:  Error while processing rearrange-reduction pattern "a b c d -> (...) d".
 Input tensor shape: (2, 3, 4, 5). Additional info: {}.
 Ellipsis found in right side, but not left side of a pattern a b c d -> (...) d

**Case: Two split axes condition handling, with more than one of the axes being given in axes_shapes**

In [15]:
x = np.random.randn(6, 20)
x = rearrange(x, '(a b) (c d) -> a b c d', a=2, b =3, d=4)
print(x.shape)

['(a b)', '(c d)'] ['a', 'b', 'c', 'd']
(2, 3, 5, 4)


**Case: Transpose + ellipses + merge axes**

In [16]:
x = np.random.rand(2, 3, 4, 5, 6)
y5 = rearrange(x, "a b ... -> b (...) a")

['a', 'b', '...'] ['b', '(...)', 'a']


In [62]:
y5.shape

(3, 120, 2)

**Case: Only ellipsis condition**

In [17]:
y6 = rearrange(x, "... ->  (...)")
y6.shape

['...'] ['(...)']


(720,)

**Case: Merge multiple axes**

In [18]:
# x = np.random.rand(12, 5, 10)
y7 = rearrange(x, "a b c d e -> (a b c) (d e)")

['a', 'b', 'c', 'd', 'e'] ['(a b c)', '(d e)']


In [19]:
y7.shape

(24, 30)

**Case: Testing out time differences to run the code, when a seperate transpose() is used and not used to handle simple transpose cases**

In [20]:
# import time

# start = time.time()

x = np.random.rand(2, 12, 4)
result = rearrange(x, 'h w c -> (h w) c')

# end = time.time()

# print(f"time taken is {end - start} seconds")

['h', 'w', 'c'] ['(h w)', 'c']


for this case:

x = np.random.rand(2, 12, 4)

result = rearrange(x, 'h w c -> w h c', h=3)

With transpose() and without transpose(), the time difference is very minimum.

**Case: test for split and merge axes**

In [23]:
x = np.random.rand(12, 5)
result = rearrange(x, '(a b) c -> a (b c)', a=2)


['(a b)', 'c'] ['a', '(b c)']


**Case: test for repeating axes**

In [34]:
import einops

# Repeat an axis
x = np.random.rand(3, 1, 5)
result = rearrange(x, 'a 1 c -> a b c', b=2)
# r = einops.rearrange(x, 'a 1 c -> a b c', b =2)

['a', '1', 'c'] ['a', 'b', 'c']


In [35]:
result.shape

(3, 2, 5)

**Case: handle batch dimensions**

In [36]:
# Handle batch dimensions
x = np.random.rand(2, 3, 4, 5)
result = rearrange(x, '... h w -> ... (h w)')

['...', 'h', 'w'] ['...', '(h w)']


In [37]:
# Handle batch dimensions
x = np.random.rand(2, 3, 4, 5)
result = rearrange(x, 'h w ... -> (h w) ...')

['h', 'w', '...'] ['(h w)', '...']


In [38]:
result.shape

(6, 4, 5)

### Testing for cases of mismatch in output shapes

In [39]:
import torch

In [40]:
x = torch.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
y1 = rearrange(x, 'a b c d e -> (a b) c d e')
y1_expected = x.reshape(2 * 3, 4, 5, 6)
assert torch.equal(y1, y1_expected), "Test 1 failed"

['a', 'b', 'c', 'd', 'e'] ['(a b)', 'c', 'd', 'e']


In [41]:
y2 = rearrange(x, 'a b c d e -> a b (c d) e')
y2_expected = x.reshape(2, 3, 4 * 5, 6)
assert torch.equal(y2, y2_expected), "Test 2 failed"

['a', 'b', 'c', 'd', 'e'] ['a', 'b', '(c d)', 'e']


In [42]:
y3 = rearrange(x, 'a b c d e -> a b c d e')
assert torch.equal(y3, x), "Test 3 failed"

['a', 'b', 'c', 'd', 'e'] ['a', 'b', 'c', 'd', 'e']


In [43]:
def test():
  x = torch.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)

  y1 = rearrange(x, "... c d e -> ... (c d) e")
  y1_expected = x.reshape(2, 3, 4 * 5, 6)
  assert torch.equal(y1, y1_expected), "Test 1 failed"

  print(y1.shape, y1_expected.shape)

  y2 = rearrange(x, 'a b c d e -> b (c d e) a')
  y2_ = rearrange(x, "a b ... -> b (...) a")
  assert torch.equal(y2, y2_), "Test 2 failed"

  print(y2.shape, y2_.shape)

  y3 = rearrange(x, "a b c d e -> b (a c d) e")
  y3_ = rearrange(x, "a b ... e -> b (a ...) e")
  assert torch.equal(y3, y3_), "Test 3 failed"

  print(y3.shape, y3_.shape)

In [44]:
test()

['...', 'c', 'd', 'e'] ['...', '(c d)', 'e']
torch.Size([2, 3, 20, 6]) torch.Size([2, 3, 20, 6])
['a', 'b', 'c', 'd', 'e'] ['b', '(c d e)', 'a']
['a', 'b', '...'] ['b', '(...)', 'a']
torch.Size([3, 120, 2]) torch.Size([3, 120, 2])
['a', 'b', 'c', 'd', 'e'] ['b', '(a c d)', 'e']
['a', 'b', '...', 'e'] ['b', '(a ...)', 'e']
torch.Size([3, 40, 6]) torch.Size([3, 40, 6])


### Test for different edge case input patterns

In [47]:
import numpy as np

invalid_patterns = [
    "a b -> (a b",
    "a b -> a * b",
    "(a b)) -> a b",
    "a b) -> (a b)",
    "a b -> a b)",
    "a b -> a c",
    "... a -> b ...",
    "(a b) -> c d",
    "(a b c) -> a b c",
    "a a -> a",
    "a b -> a b b",
    "(a b) c -> a b c",
    "(a b) -> a b",
    "... -> ... ...",
    "a b ... -> ... ... b",
    "a-b -> b",
    "a$ -> a",
    "123 -> 321",
]

def test_invalid_patterns():
    x = np.random.rand(2, 3, 4)

    for i, pattern in enumerate(invalid_patterns):
        try:
            print(f"Testing invalid pattern {i+1}: '{pattern}'")
            rearrange(x, pattern)
            raise AssertionError(f"Pattern {i+1} did not raise an error: '{pattern}'")
        except Exception as e:
            print(f"Caught expected error: {e.__class__.__name__}: {e}")

    print("All invalid patterns correctly raised errors!")

test_invalid_patterns()


Testing invalid pattern 1: 'a b -> (a b'
Caught expected error: ValueError: cannot reshape array of size 24 into shape (2,3)
Testing invalid pattern 2: 'a b -> a * b'
Caught expected error: ValueError: cannot reshape array of size 24 into shape (2,3)
Testing invalid pattern 3: '(a b)) -> a b'
Caught expected error: ValueError: Cannot infer sizes for multiple unknown axes ['a', 'b']. Known axes: []. Please specify all but one.
Testing invalid pattern 4: 'a b) -> (a b)'
Caught expected error: ValueError: cannot reshape array of size 24 into shape (2,3)
Testing invalid pattern 5: 'a b -> a b)'
Caught expected error: ValueError: cannot reshape array of size 24 into shape (2,3)
Testing invalid pattern 6: 'a b -> a c'
Caught expected error: ValueError: cannot reshape array of size 24 into shape (2,3)
Testing invalid pattern 7: '... a -> b ...'
Caught expected error: ValueError: The axis b needs a specified length
Testing invalid pattern 8: '(a b) -> c d'
Caught expected error: ValueError: Ca

### Test for edge case rearrange pattern cases

In [48]:
creative_rearrange_patterns = [
    ("a b -> b a"),

    ("a b c d -> (a b) (c d)"),
    ("(a b) (c d) -> a b c d"),

    ("... h w -> (...) (h w)"),
    ("b ... h w -> b (...) (h w)"),
    ("b ... h w -> (b ...) h w"),

    ("a b c d e -> e d c b a"),
    ("(a b) c (d e) -> a b d e c"),

    ("a b c d -> a (b c) d"),
    ("a b c d -> (a d) b c"),

    ("a 1 c -> a b c",),
    ("a 1 c -> a (1 c)"),

    ("(a b c) -> a b c",),
    ("(a b) (c d) -> a b c d"),

    ("... x y z -> ... (x y z)"),
    ("a ... b -> a (...) b"),

    ("a b c -> a b c"),
    ("... -> ..."),

    ("a b c d e -> (a b c) (d e)"),

    ("a b -> a b 1"),
    ("(a b) -> a b",),

    ("a b c d -> (...) d"),

    ("a b c d -> (a b c) d"),
    ("a b c d -> a (b c d)"),

    ("a b c d e f -> (a b) (c d) (e f)"),

    ("a b c d -> b c d a"),
    ("a b c d -> d a b c"),

    ("a b c d -> a (b d) c"),
]

def test_rearrange_patterns():
    x = np.random.rand(2, 3, 4, 5, 6)
    for i, pattern in enumerate(creative_rearrange_patterns):
        try:
            print(f"Testing pattern {i+1}: '{pattern}'")
            y = rearrange(x, pattern, a=2, b=3, c=4, d=5, e=6)
            print(f"Pattern {i+1} succeeded. Output shape: {y.shape}")
        except Exception as e:
            print(f"Pattern {i+1} failed with error: {e}")

test_rearrange_patterns()


Testing pattern 1: 'a b -> b a'
Pattern 1 failed with error: Extra keys in axes_lengths not used in pattern: {'c', 'd', 'e'}
Testing pattern 2: 'a b c d -> (a b) (c d)'
Pattern 2 failed with error: Extra keys in axes_lengths not used in pattern: {'e'}
Testing pattern 3: '(a b) (c d) -> a b c d'
Pattern 3 failed with error: Extra keys in axes_lengths not used in pattern: {'e'}
Testing pattern 4: '... h w -> (...) (h w)'
Pattern 4 failed with error: Extra keys in axes_lengths not used in pattern: {'b', 'c', 'd', 'a', 'e'}
Testing pattern 5: 'b ... h w -> b (...) (h w)'
Pattern 5 failed with error: Extra keys in axes_lengths not used in pattern: {'c', 'a', 'd', 'e'}
Testing pattern 6: 'b ... h w -> (b ...) h w'
Pattern 6 failed with error: Extra keys in axes_lengths not used in pattern: {'c', 'a', 'd', 'e'}
Testing pattern 7: 'a b c d e -> e d c b a'
Pattern 7 succeeded. Output shape: (6, 5, 4, 3, 2)
Testing pattern 8: '(a b) c (d e) -> a b d e c'
Pattern 8 failed with error: Mismatch in 

# Comparison with einops
(Cases taken from einops repo)



In [49]:
import einops

In [55]:
def compare(x, pattern, **axes_lengths):

  try:
    y = einops.rearrange(x, pattern, **axes_lengths)
    y_ = rearrange(x, pattern, **axes_lengths)
    if y.shape != y_.shape:
      print(f"Match in Shapes of y-{y.shape} vs y_-{y_.shape}")
    else:
      print("Shapes matches, test succesful!")
  except Exception as e1:
      try:
          rearrange(x, pattern, **axes_lengths)
          print(f"Custom function raised no error, but einops did: {e1}")
      except Exception:
              print(f"Both raised errors")

In [56]:
x = np.random.rand(2, 3, 4, 5, 6)
pattern = "...->..."
compare(x, pattern)

Shapes matches, test succesful!


In [57]:
compare(x, 'a b c d e-> a b c d e')

Shapes matches, test succesful!


In [58]:
compare(np.random.rand(2, 3, 4, 5, 6, 7), 'a b c d e ...-> ... a b c d e')

Shapes matches, test succesful!


In [60]:
compare(np.random.rand(2, 3, 4, 5, 6, 7), "a b c d e ...-> ... a b c d e")

Shapes matches, test succesful!


In [62]:
compare(np.random.rand(2, 3, 4 ,5, 6, 7), '... a b c d e -> ... a b c d e')

Shapes matches, test succesful!


In [63]:
compare(x, "a ... e-> a ... e")

Shapes matches, test succesful!


In [64]:
compare(x, "a ... -> a ... ")

Shapes matches, test succesful!


In [65]:
compare(x, "a ... c d e -> a (...) c d e")

Shapes matches, test succesful!


### pattern-wise comparison

In [68]:
def case_comp(x, pattern1, pattern2):
  y = rearrange(x, pattern1)
  y_ = rearrange(x, pattern2)
  if y.shape == y_.shape:
    print("results match!")
  else:
    print("Test failed... Mismatch in outputs")

In [67]:
p = ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... ")
compare(x, p[0])
compare(x, p[1])

Shapes matches, test succesful!
Shapes matches, test succesful!


In [69]:
p = ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e")
case_comp(x, p[0], p[1])

results match!


In [70]:
p = ("a b c d e -> a b c d e", "... -> ... ")
case_comp(x, p[0], p[1])

results match!


In [71]:
p = ("a b c d e -> (a b c d e)", "... ->  (...)")
case_comp(x, p[0], p[1])

results match!


In [72]:
p = ("a b c d e -> b (c d e) a", "a b ... -> b (...) a")
case_comp(x, p[0], p[1])

results match!


In [73]:
p = ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e")
case_comp(x, p[0], p[1])

results match!
