<a href="https://colab.research.google.com/github/rokosbasilisk/einops_assignment/blob/main/Rearrange_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Lark for parsing the einops expressions
!pip install lark numpy



In [1]:
import numpy as np
from lark import Lark, Transformer

In [2]:
# https://github.com/lark-parser/lark/blob/master/docs/how_to_use.md

# Sentinel string that PatternTransformer emits in place of a literal '...' axis.
# The rearrange code detects the ellipsis by comparing axes against this value.
ELLIPSIS_TOKEN = "<<<ELLIPSIS>>>"

# Lark grammar for einops patterns of the form "<input axes> -> <output axes>".
# Each side is a whitespace-separated list of axes; an axis is a Python-style
# name, the literal 1, an ellipsis, or a parenthesised group of axes.
# The grammar is recursive (a group contains `axes`), so it also accepts nested
# groups such as "((a b) c)"; accepting a pattern here does not imply that the
# rearrange implementation supports it.
einops_grammar = r"""
    start: axes "->" axes        # transformation from input axes to output axes
    axes: axis*                  # a list of zero or more axes
    axis: NAME | ONE | group | ELLIPSIS   # an axis can be a name, '1', a grouped axis, or '...'
    group: "(" axes ")"          # a group is a parenthesized list of axes
    ELLIPSIS: "..."              # literal ellipsis
    ONE: "1"                     # literal '1' (singleton dimension)
    %import common.CNAME -> NAME   # import valid Python-like names as axis names
    %import common.WS              # import whitespace
    %ignore WS                    # ignore all whitespace in the input
"""


In [3]:
class PatternTransformer(Transformer):
    """Convert the Lark parse tree of an einops pattern into plain Python values.

    ``transform`` returns ``{"input": [...], "output": [...]}``. Each list
    element is one of:
      * an axis name (``str``),
      * the string ``'1'`` for an anonymous singleton axis,
      * ``ELLIPSIS_TOKEN`` for ``...``,
      * a ``tuple`` holding the members of a parenthesised group.
    """

    # ---- rule callbacks: receive the already-transformed children ----

    def start(self, items):
        # start: axes "->" axes  -- only the two `axes` results arrive here.
        input_axes = items[0]
        output_axes = items[1]
        return dict(input=input_axes, output=output_axes)

    def axes(self, items):
        # Zero or more axes: the children list is already the desired result.
        return items

    def axis(self, items):
        # An axis wraps exactly one alternative (name, '1', group or ellipsis).
        return items[0]

    def group(self, items):
        # A group normally arrives as a single child: the list built by `axes`.
        # Unwrap that list so the group becomes a flat tuple of its members.
        wraps_axes_list = len(items) == 1 and isinstance(items[0], list)
        members = items[0] if wraps_axes_list else items
        return tuple(members)

    # ---- terminal callbacks: receive the raw token ----

    def NAME(self, token):
        # Convert the token into a plain string axis name.
        return str(token)

    def ONE(self, token):
        # Literal 1 -> singleton-axis marker.
        return '1'

    def ELLIPSIS(self, token):
        # Literal '...' -> module-wide sentinel used by the rearrange code.
        return ELLIPSIS_TOKEN

In [4]:
## TestCases for LARK Pattern Parser

# One pattern per case from the requirements doc:
# (pattern string, expected parsed input axes, expected parsed output axes).
# ELLIPSIS_TOKEN is used instead of repeating its literal value so these
# expectations stay in sync with the sentinel definition.
patterns = [
    ("h w -> w h", ['h', 'w'], ['w', 'h']),                                             # Transpose
    ("(h w) c -> h w c", [('h', 'w'), 'c'], ['h', 'w', 'c']),                           # Split
    ("a b c -> (a b) c", ['a', 'b', 'c'], [('a', 'b'), 'c']),                           # Merge
    ("a b c -> a b 1 c", ['a', 'b', 'c'], ['a', 'b', '1', 'c']),                        # Repeat (insert singleton axis)
    ("... h w -> ... (h w)", [ELLIPSIS_TOKEN, 'h', 'w'], [ELLIPSIS_TOKEN, ('h', 'w')]),  # Ellipsis
]

# Build the parser and transformer once: constructing Lark processes the whole
# grammar, so rebuilding it for every pattern was wasted work. The transformer
# holds no state, so a single instance can be reused.
pattern_parser = Lark(einops_grammar)
pattern_transformer = PatternTransformer()

for pattern_str, expected_input, expected_output in patterns:
    parsed = pattern_transformer.transform(pattern_parser.parse(pattern_str))
    print(f"Testing pattern: {pattern_str}")
    assert parsed['input'] == expected_input, f"Input mismatch: {parsed['input']} != {expected_input}"
    assert parsed['output'] == expected_output, f"Output mismatch: {parsed['output']} != {expected_output}"
    print("✓ Passed\n")


Testing pattern: h w -> w h
✓ Passed

Testing pattern: (h w) c -> h w c
✓ Passed

Testing pattern: a b c -> (a b) c
✓ Passed

Testing pattern: a b c -> a b 1 c
✓ Passed

Testing pattern: ... h w -> ... (h w)
✓ Passed



In [5]:
class EinopsRearranger:
    """
    Lightweight, numpy-only re-implementation of ``einops.rearrange``.

    Pipeline (see ``rearrange``):
      1. Parse the pattern with the Lark grammar. ``PatternTransformer`` turns
         the tree into lists of axis names, '1', the ellipsis sentinel, and
         tuples for parenthesised groups.
      2. Validate the pattern:
         - at most one ellipsis per side, and an ellipsis on both sides or neither;
         - no nested groups, and no ellipsis inside an input group;
         - no duplicate axis names;
         - the same named axes on both sides.
      3. Infer every axis length from the tensor shape plus the user-supplied
         ``axes_lengths`` (at most one unknown member per input group).
      4. Reshape the input so every elementary axis has its own dimension.
      5. Drop input '1' axes and transpose into the output axis order. Ellipsis
         dimensions take part in the permutation.
      6. Reshape into the output layout, merging groups and inserting '1' axes.

    Set ``debug=True`` to print intermediate state. Pure numpy: no autograd and
    no torch tensors.
    """

    def __init__(self, debug=False):
        # Build the parser once per instance and reuse it for every call.
        self.parser = Lark(einops_grammar)
        self.debug = debug

    def rearrange(self, tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
        """Rearrange ``tensor`` according to an einops-style ``pattern``.

        Args:
            tensor: input array (anything ``np.asarray`` accepts).
            pattern: e.g. ``'b (h w) c -> b c h w'`` or ``'... h w -> ... (h w)'``.
            **axes_lengths: sizes for axes that cannot be inferred from the
                shape alone (e.g. one factor of a split group).

        Returns:
            The rearranged array.

        Raises:
            ValueError: if the pattern is inconsistent with itself, with the
                tensor's rank/shape, or with ``axes_lengths``. Syntax errors
                from the Lark parser propagate unchanged.
        """
        tensor = np.asarray(tensor)
        if self.debug:
            print("\n[REARRANGE START]")
            print(f"  Tensor shape: {list(tensor.shape)}")
            print(f"  Pattern: {pattern}")
            print(f"  User axes_lengths: {axes_lengths}")

        parsed = PatternTransformer().transform(self.parser.parse(pattern))
        grouped_input, grouped_output = parsed['input'], parsed['output']
        if self.debug:
            print(f"  Grouped input: {grouped_input}")
            print(f"  Grouped output: {grouped_output}")

        input_names = self._validate_pattern(grouped_input, grouped_output)
        unused = sorted(set(axes_lengths) - input_names)
        if unused:
            raise ValueError(f"Axis lengths given for axes not in pattern: {unused}")

        axis_map = dict(axes_lengths)
        self._assign_top_level_dims_ellipsis(tensor, grouped_input, axis_map)
        if self.debug:
            print(f"  After top-level assignment, axis_map: {axis_map}")

        tensor = self._reshape_grouped_input(tensor, grouped_input, axis_map)
        if self.debug:
            print(f"  After input reshape: {tensor.shape}")

        tensor = self._maybe_transpose(tensor, grouped_input, grouped_output)
        if self.debug:
            print(f"  After transpose: {tensor.shape}")

        tensor = self._reshape_grouped_output(tensor, grouped_output, axis_map)
        if self.debug:
            print(f"  Final output shape: {tensor.shape}")
        return tensor

    @staticmethod
    def _validate_pattern(grouped_input, grouped_output):
        """Reject patterns the implementation cannot execute correctly.

        Returns:
            The set of named axes (excluding '1' and the ellipsis) on the input side.

        Raises:
            ValueError: on nested groups, an ellipsis inside an input group,
                more than one ellipsis per side, an ellipsis on one side only,
                duplicate names, or differing axis names between sides.
        """
        def collect(pattern, side):
            names = []
            ellipses = 0
            for top_axis in pattern:
                members = top_axis if isinstance(top_axis, tuple) else (top_axis,)
                for ax in members:
                    if isinstance(ax, tuple):
                        raise ValueError(f"Nested groups are not supported ({side} side): {top_axis}")
                    if ax == ELLIPSIS_TOKEN:
                        if side == "input" and isinstance(top_axis, tuple):
                            # A group containing '...' on the input side cannot be split.
                            raise ValueError("Ellipsis inside a group is not supported on the input side.")
                        ellipses += 1
                    elif ax != '1':
                        names.append(ax)
            if ellipses > 1:
                raise ValueError("Only one ellipsis supported.")
            seen, duplicates = set(), set()
            for name in names:
                if name in seen:
                    duplicates.add(name)
                seen.add(name)
            if duplicates:
                raise ValueError(f"Duplicate axis names on {side} side: {sorted(duplicates)}")
            return seen, ellipses

        in_names, in_ellipsis = collect(grouped_input, "input")
        out_names, out_ellipsis = collect(grouped_output, "output")
        if in_ellipsis != out_ellipsis:
            raise ValueError("Ellipsis must appear on both sides of the pattern or on neither.")
        if in_names != out_names:
            raise ValueError(
                f"Axis names differ between sides: only in input {sorted(in_names - out_names)}, "
                f"only in output {sorted(out_names - in_names)}"
            )
        return in_names

    def _assign_top_level_dims_ellipsis(self, tensor, pattern, axis_map):
        """Fill ``axis_map`` (in place) with the length of every input axis.

        Ellipsis dims are stored as a list under the key ``'...'``. Lengths
        already present in ``axis_map`` (user-supplied) are checked against the
        tensor shape rather than overwritten.

        Raises:
            ValueError: on a rank mismatch, an input '1' axis whose length is
                not 1, a conflicting user length, or an unsplittable group.
        """
        in_shape = list(tensor.shape)
        shape_len = len(in_shape)
        ellipsis_count = sum(1 for ax in pattern if ax == ELLIPSIS_TOKEN)

        if ellipsis_count > 1:
            raise ValueError("Only one ellipsis supported.")
        if ellipsis_count:
            # The ellipsis absorbs every dim not claimed by another top-level axis.
            leftover_for_ellipsis = shape_len - (len(pattern) - 1)
            if leftover_for_ellipsis < 0:
                raise ValueError("Not enough dims for ellipsis.")
        else:
            leftover_for_ellipsis = 0
            if len(pattern) != shape_len:
                raise ValueError(
                    f"Pattern has {len(pattern)} input axes but tensor has {shape_len} dims."
                )

        shape_idx = 0
        for top_axis in pattern:
            if top_axis == ELLIPSIS_TOKEN:
                axis_map['...'] = in_shape[shape_idx: shape_idx + leftover_for_ellipsis]
                if self.debug:
                    print(f"  Ellipsis => leftover dims {axis_map['...']}")
                shape_idx += leftover_for_ellipsis
                continue

            total_dim = in_shape[shape_idx]
            shape_idx += 1
            if isinstance(top_axis, tuple):
                self._infer_group(top_axis, total_dim, axis_map)
            elif top_axis == '1':
                if total_dim != 1:
                    raise ValueError(f"Axis '1' must have length 1, got {total_dim}")
            elif top_axis in axis_map:
                if axis_map[top_axis] != total_dim:
                    raise ValueError(
                        f"Axis '{top_axis}' was given length {axis_map[top_axis]} "
                        f"but the tensor dimension is {total_dim}"
                    )
            else:
                axis_map[top_axis] = total_dim

    @staticmethod
    def _infer_group(group, total_dim, axis_map):
        """Solve for the (at most one) unknown member of an input group of length ``total_dim``."""
        known_prod = 1
        unknowns = []
        for name in group:
            if name == '1':
                continue  # anonymous singleton contributes a factor of 1
            if name in axis_map:
                known_prod *= axis_map[name]
            else:
                unknowns.append(name)
        if len(unknowns) > 1:
            raise ValueError(f"Too many unknown dims in group {group}")
        if unknowns:
            # known_prod == 0 would divide by zero below; treat as unsplittable.
            if known_prod == 0 or total_dim % known_prod != 0:
                raise ValueError(f"Can't split {total_dim} among group {group}")
            axis_map[unknowns[0]] = total_dim // known_prod
        elif known_prod != total_dim:
            raise ValueError(
                f"Group {group} has product {known_prod}, but the dimension is {total_dim}"
            )

    def _reshape_grouped_input(self, tensor, pattern, axis_map):
        """Reshape so every elementary input axis (including '1') gets its own dim.

        Expects ``axis_map`` to have been filled by ``_assign_top_level_dims_ellipsis``.
        """
        new_dims = []
        for top_axis in pattern:
            if top_axis == ELLIPSIS_TOKEN:
                leftover_ellipsis = axis_map.get('...', None)
                if leftover_ellipsis is None:
                    raise ValueError("Ellipsis mismatch in input reshape.")
                if self.debug:
                    print(f"  Expanding ellipsis with leftover dims {leftover_ellipsis}")
                new_dims.extend(leftover_ellipsis)
            elif isinstance(top_axis, tuple):
                dims = [1 if s == '1' else axis_map[s] for s in top_axis]
                if self.debug:
                    print(f"  Splitting group {top_axis} => dims={dims}")
                new_dims.extend(dims)
            elif top_axis == '1':
                if self.debug:
                    print("  Found '1' -> new axis of size 1 (input).")
                new_dims.append(1)
            else:
                new_dims.append(axis_map[top_axis])

        if self.debug:
            print(f"  Input new_dims: {new_dims}")
        return tensor.reshape(new_dims)

    @staticmethod
    def _elementary_axes(pattern, n_ellipsis_dims):
        """Flatten a grouped pattern into one name per elementary axis.

        The ellipsis (top level or inside a group) expands to ``n_ellipsis_dims``
        synthetic names ``f"{ELLIPSIS_TOKEN}{i}"``; these cannot collide with
        real axis names, which the grammar restricts to identifiers. '1' entries
        are kept as ``'1'``.
        """
        flat = []

        def visit(ax):
            if ax == ELLIPSIS_TOKEN:
                flat.extend(f"{ELLIPSIS_TOKEN}{i}" for i in range(n_ellipsis_dims))
            else:
                flat.append(str(ax))

        for top_axis in pattern:
            if isinstance(top_axis, tuple):
                for member in top_axis:
                    visit(member)
            else:
                visit(top_axis)
        return flat

    def _maybe_transpose(self, tensor, grouped_input, grouped_output):
        """Permute the elementary input axes into output order.

        ``tensor`` must be the result of ``_reshape_grouped_input``, with one dim
        per elementary input axis plus the ellipsis dims. Input '1' axes are
        dropped here because they have length 1; output '1' axes are inserted
        later by the output reshape.

        Raises:
            ValueError: if the non-'1' axes differ between the two sides.
        """
        # Every non-ellipsis elementary input axis owns exactly one dim, so the
        # remainder of tensor.ndim is the number of ellipsis dims.
        n_ellipsis_dims = tensor.ndim - len(self._elementary_axes(grouped_input, 0))
        flat_in = self._elementary_axes(grouped_input, n_ellipsis_dims)
        flat_out = self._elementary_axes(grouped_output, n_ellipsis_dims)

        if self.debug:
            print(f"[Transpose check]\n  in= {flat_in}\n  out={flat_out}")

        keep = [i for i, ax in enumerate(flat_in) if ax != '1']
        kept_in = [flat_in[i] for i in keep]
        kept_out = [ax for ax in flat_out if ax != '1']
        if sorted(kept_in) != sorted(kept_out):
            # Previously this silently skipped the transpose and returned
            # wrongly ordered data; fail loudly instead.
            raise ValueError(f"Mismatch in axis sets: {kept_in} vs {kept_out}")

        if len(keep) != tensor.ndim:
            tensor = tensor.reshape([tensor.shape[i] for i in keep])

        position = {ax: i for i, ax in enumerate(kept_in)}
        perm = [position[ax] for ax in kept_out]
        if perm == list(range(len(perm))):
            if self.debug:
                print("  No transpose needed.")
            return tensor

        if self.debug:
            print(f"  Applying transpose perm: {perm}")
        return tensor.transpose(perm)

    def _reshape_grouped_output(self, tensor, pattern, axis_map):
        """Reshape the transposed tensor into the output layout.

        Groups become the product of their members' lengths, including
        ``(...)`` as the product of the ellipsis dims. '1' inserts a
        singleton dim, and a top-level ellipsis re-expands to its original dims.
        """
        leftover_ellipsis = axis_map.get('...', [])

        def member_len(name):
            if name == ELLIPSIS_TOKEN:
                size = 1
                for dim in leftover_ellipsis:
                    size *= dim
                return size
            return 1 if name == '1' else axis_map[name]

        new_dims = []
        for top_axis in pattern:
            if top_axis == ELLIPSIS_TOKEN:
                if self.debug:
                    print(f"  Expanding output ellipsis => {leftover_ellipsis}")
                new_dims.extend(leftover_ellipsis)
            elif isinstance(top_axis, tuple):
                product = 1
                for s in top_axis:
                    product *= member_len(s)
                if self.debug:
                    print(f"  Merging group {top_axis} => product={product}")
                new_dims.append(product)
            elif top_axis == '1':
                if self.debug:
                    print("  Found '1' => new axis of size 1 (output).")
                new_dims.append(1)
            else:
                new_dims.append(axis_map[top_axis])

        if self.debug:
            print(f"  Output new_dims: {new_dims}")
        return tensor.reshape(new_dims)


### Tests for the five cases listed in the assignment document

In [6]:
if __name__ == "__main__":
    rearranger = EinopsRearranger(debug=True)

    # One row per assignment case: (printed label, input shape, pattern, explicit axis lengths).
    demo_cases = [
        ("Transpose:", (3, 4), 'h w -> w h', {}),
        ("Split:", (12, 10), '(h w) c -> h w c', {'h': 3}),
        ("Merge:", (3, 4, 5), 'a b c -> (a b) c', {}),
        ("Repeat:", (3, 1, 5), 'a b c -> a b 1 c', {}),
        ("Ellipsis:", (2, 3, 4, 5), '... h w -> ... (h w)', {}),
    ]

    for demo_label, demo_shape, demo_pattern, demo_lengths in demo_cases:
        x = np.random.rand(*demo_shape)
        print(demo_label, rearranger.rearrange(x, demo_pattern, **demo_lengths).shape)



[REARRANGE START]
  Tensor shape: [3, 4]
  Pattern: h w -> w h
  User axes_lengths: {}
  Grouped input: ['h', 'w']
  Grouped output: ['w', 'h']
  After top-level assignment, axis_map: {'h': 3, 'w': 4}
  Input new_dims: [3, 4]
  After input reshape: (3, 4)
[Transpose check]
  in= ['h', 'w']
  out=['w', 'h']
  Applying transpose perm: [1, 0]
  After transpose: (4, 3)
  Output new_dims: [4, 3]
  Final output shape: (4, 3)
Transpose: (4, 3)

[REARRANGE START]
  Tensor shape: [12, 10]
  Pattern: (h w) c -> h w c
  User axes_lengths: {'h': 3}
  Grouped input: [('h', 'w'), 'c']
  Grouped output: ['h', 'w', 'c']
  After top-level assignment, axis_map: {'h': 3, 'w': 4, 'c': 10}
  Splitting group ('h', 'w') => dims=[3, 4]
  Input new_dims: [3, 4, 10]
  After input reshape: (3, 4, 10)
[Transpose check]
  in= ['h', 'w', 'c']
  out=['h', 'w', 'c']
  No transpose needed.
  After transpose: (3, 4, 10)
  Output new_dims: [3, 4, 10]
  Final output shape: (3, 4, 10)
Split: (3, 4, 10)

[REARRANGE START]

### Test cases: comparing against the reference einops implementation ###

In [21]:
!pip install einops



In [7]:
import numpy as np
from einops import rearrange as einops_rearrange

# Seeded generator so a failing comparison can be reproduced exactly.
_TEST_RNG = np.random.default_rng(0)


def _check_against_einops(einops_rearranger, label, shape, pattern, **axes_lengths):
    """Run ``pattern`` through both implementations and assert identical results.

    Args:
        einops_rearranger: the custom EinopsRearranger under test.
        label: case name used in the printed output and assertion messages.
        shape: shape of the random input array.
        pattern: einops pattern string.
        **axes_lengths: explicit axis sizes forwarded to both implementations.

    Raises:
        AssertionError: if shapes or contents differ.
    """
    x = _TEST_RNG.random(shape)
    my_res = einops_rearranger.rearrange(x, pattern, **axes_lengths)
    ref_res = einops_rearrange(x, pattern, **axes_lengths)
    print(f"{label} shape (mine):", my_res.shape)
    print(f"{label} shape (einops):", ref_res.shape)
    assert my_res.shape == ref_res.shape, f"{label}: shape mismatch {my_res.shape} != {ref_res.shape}"
    # rearrange only moves data, so results must match exactly, not just approximately.
    assert np.array_equal(my_res, ref_res), f"{label}: content mismatch!"


def test_transpose(einops_rearranger):
    _check_against_einops(einops_rearranger, "Transpose", (3, 4), 'h w -> w h')


def test_split(einops_rearranger):
    _check_against_einops(einops_rearranger, "Split", (12, 10), '(h w) c -> h w c', h=3)


def test_merge(einops_rearranger):
    _check_against_einops(einops_rearranger, "Merge", (3, 4, 5), 'a b c -> (a b) c')


def test_repeat(einops_rearranger):
    _check_against_einops(einops_rearranger, "Repeat", (3, 1, 5), 'a b c -> a b 1 c')


def test_ellipsis(einops_rearranger):
    _check_against_einops(einops_rearranger, "Ellipsis", (2, 3, 4, 5), '... h w -> ... (h w)')


if __name__ == "__main__":
    rearranger = EinopsRearranger()

    test_transpose(rearranger)
    test_split(rearranger)
    test_merge(rearranger)
    test_repeat(rearranger)
    test_ellipsis(rearranger)

    print("All tests comparing custom rearranger to einops have PASSED!")

Transpose shape (mine): (4, 3)
Transpose shape (einops): (4, 3)
Split shape (mine): (3, 4, 10)
Split shape (einops): (3, 4, 10)
Merge shape (mine): (12, 5)
Merge shape (einops): (12, 5)
Repeat shape (mine): (3, 1, 1, 5)
Repeat shape (einops): (3, 1, 1, 5)
Ellipsis shape (mine): (2, 3, 20)
Ellipsis shape (einops): (2, 3, 20)
All tests comparing custom rearranger to einops have PASSED!
