<a href="https://colab.research.google.com/github/rokosbasilisk/einops_assignment/blob/main/Rearrange_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Lark for parsing the einops expressions
!pip install lark numpy



In [3]:
import numpy as np
from lark import Lark, Transformer

In [4]:
# Sentinel string that PatternTransformer returns in place of a "..." token,
# so that ellipsis entries can be told apart from ordinary axis names.
ELLIPSIS_TOKEN = "<<<ELLIPSIS>>>"

# Lark grammar for einops-style patterns such as "(h w) c -> h w c":
#   start - two axes lists separated by "->" (input side, then output side)
#   axes  - zero or more axis entries
#   axis  - a C-style identifier, the literal 1, a parenthesised group, or "..."
#   group - an axes list wrapped in parentheses
# Whitespace between tokens is ignored.
einops_grammar = r"""
    start: axes "->" axes
    axes: axis*
    axis: NAME | ONE | group | ELLIPSIS
    group: "(" axes ")"
    ELLIPSIS: "..."
    ONE: "1"
    %import common.CNAME -> NAME
    %import common.WS
    %ignore WS
"""

In [5]:
class PatternTransformer(Transformer):
    """Convert the Lark parse tree of a pattern into plain Python containers.

    The result is a dict with keys ``"input"`` and ``"output"``. Each value is
    a list whose entries are axis names (``str``), the string ``'1'``,
    ``ELLIPSIS_TOKEN`` for ``...``, or a tuple for a parenthesised group.
    """

    def start(self, items):
        """Package the two sides of the pattern (input first, output second)."""
        return dict(input=items[0], output=items[1])

    def axes(self, items):
        """An axes list is simply the list of its already-converted entries."""
        return items

    def axis(self, items):
        """Unwrap the single child of an ``axis`` node."""
        return items[0]

    def group(self, items):
        """Represent a parenthesised group as a tuple of its members."""
        # The usual case is one child: the list produced by ``axes``.
        wraps_single_list = len(items) == 1 and isinstance(items[0], list)
        members = items[0] if wraps_single_list else items
        return tuple(members)

    def NAME(self, token):
        """Axis names become plain strings."""
        return str(token)

    def ONE(self, token):
        """The literal ``1`` is kept as the string ``'1'``."""
        return '1'

    def ELLIPSIS(self, token):
        """``...`` is replaced by the module-level sentinel."""
        return ELLIPSIS_TOKEN

In [6]:
## Test cases for the Lark pattern parser

# One pattern per operation listed in the requirements doc, stored as
# (pattern string, expected parsed input, expected parsed output).
# Expected values use ELLIPSIS_TOKEN so they stay in sync with the sentinel.
patterns = [
    ("h w -> w h", ['h', 'w'], ['w', 'h']),                                # Transpose
    ("(h w) c -> h w c", [('h', 'w'), 'c'], ['h', 'w', 'c']),              # Split
    ("a b c -> (a b) c", ['a', 'b', 'c'], [('a', 'b'), 'c']),              # Merge
    ("a b c -> a b 1 c", ['a', 'b', 'c'], ['a', 'b', '1', 'c']),           # Repeat (unit axis insertion)
    ("... h w -> ... (h w)", [ELLIPSIS_TOKEN, 'h', 'w'], [ELLIPSIS_TOKEN, ('h', 'w')]),  # Ellipsis
]

# Build the parser and the transformer once: neither depends on the pattern,
# so re-creating them on every iteration only repeats the grammar compilation.
pattern_parser = Lark(einops_grammar)
pattern_transformer = PatternTransformer()

for pattern_str, expected_input, expected_output in patterns:
    parsed = pattern_transformer.transform(pattern_parser.parse(pattern_str))
    print(f"Testing pattern: {pattern_str}")
    assert parsed['input'] == expected_input, f"Input mismatch: {parsed['input']} != {expected_input}"
    assert parsed['output'] == expected_output, f"Output mismatch: {parsed['output']} != {expected_output}"
    print("✓ Passed\n")


Testing pattern: h w -> w h
✓ Passed

Testing pattern: (h w) c -> h w c
✓ Passed

Testing pattern: a b c -> (a b) c
✓ Passed

Testing pattern: a b c -> a b 1 c
✓ Passed

Testing pattern: ... h w -> ... (h w)
✓ Passed



In [7]:
class EinopsRearranger:
    """Small numpy re-implementation of ``einops.rearrange``.

    Supported pattern features: axis permutation, splitting a dimension with
    a group on the input side (``(h w) c -> h w c``), merging axes with a
    group on the output side, unit axes written as ``1`` and a single
    top-level ellipsis ``...`` on each side.

    The work is done in four steps: resolve every axis size from the tensor
    shape and the user supplied lengths, reshape the input so that every
    elementary axis has its own dimension, permute the dimensions into output
    order, and finally reshape into the grouped output layout.
    """

    # Class-level default so the progress trace is also available on
    # instances that were created without running __init__.
    verbose = True

    def __init__(self, verbose: bool = True):
        """Build the pattern parser.

        Args:
            verbose: when True (the default) every step prints a trace line.
        """
        self.parser = Lark(einops_grammar)
        self.verbose = verbose

    def _log(self, message):
        """Print one line of the progress trace unless tracing is disabled."""
        if self.verbose:
            print(message)

    @staticmethod
    def _flatten_axes(pattern, ellipsis_axes=()):
        """Return the elementary axes of ``pattern`` in order.

        Groups are flattened, unit axes (``'1'``) are skipped because they
        never take part in the permutation, and an ellipsis entry is replaced
        by ``ellipsis_axes`` (nothing by default).
        """
        flat = []
        for item in pattern:
            if isinstance(item, tuple):
                flat.extend(str(member) for member in item if member != '1')
            elif item == ELLIPSIS_TOKEN:
                flat.extend(ellipsis_axes)
            elif item != '1':
                flat.append(str(item))
        return flat

    def _validate_patterns(self, grouped_input, grouped_output, axes_lengths):
        """Reject patterns that the pipeline cannot handle correctly.

        Raises:
            ValueError: more than one ellipsis on a side, an ellipsis or a
                nested group inside a group, duplicate axis names on a side,
                different named axes on the two sides, an ellipsis on only
                one side, or a user length that is not a positive integer.
        """
        for side, pattern in (("input", grouped_input), ("output", grouped_output)):
            if sum(1 for ax in pattern if ax == ELLIPSIS_TOKEN) > 1:
                raise ValueError("Only one ellipsis supported in this simplified approach.")
            for ax in pattern:
                if isinstance(ax, tuple) and any(
                    isinstance(member, tuple) or member == ELLIPSIS_TOKEN for member in ax
                ):
                    raise ValueError(
                        f"Nested groups and ellipsis inside a group are not supported "
                        f"({side} pattern, group {ax})."
                    )
            names = self._flatten_axes(pattern)
            if len(set(names)) != len(names):
                raise ValueError(f"Duplicate axis names in {side} pattern: {names}")

        in_names = set(self._flatten_axes(grouped_input))
        out_names = set(self._flatten_axes(grouped_output))
        if in_names != out_names:
            raise ValueError(
                f"Axes differ between input and output: only in input "
                f"{sorted(in_names - out_names)}, only in output {sorted(out_names - in_names)}"
            )

        in_has_ellipsis = any(ax == ELLIPSIS_TOKEN for ax in grouped_input)
        out_has_ellipsis = any(ax == ELLIPSIS_TOKEN for ax in grouped_output)
        if in_has_ellipsis != out_has_ellipsis:
            raise ValueError("Ellipsis must appear on both sides of the pattern or on neither.")

        for name, length in axes_lengths.items():
            if not isinstance(length, (int, np.integer)) or length <= 0:
                raise ValueError(f"Length of axis '{name}' must be a positive integer, got {length!r}")

    def rearrange(self, tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
        """Rearrange ``tensor`` according to an einops-style ``pattern``.

        Args:
            tensor: the array to rearrange.
            pattern: e.g. ``'(h w) c -> h w c'``.
            **axes_lengths: sizes of named axes; needed to split a dimension
                when the size cannot be inferred (at most one unknown axis
                per group).

        Returns:
            The rearranged array.

        Raises:
            ValueError: if the pattern is inconsistent with itself, with the
                tensor shape, or with ``axes_lengths``.
        """
        self._log("\n[REARRANGE START]")
        shape = list(tensor.shape)
        self._log(f"  Tensor shape: {shape}")
        self._log(f"  Pattern: {pattern}")
        self._log(f"  User axes_lengths: {axes_lengths}")

        # Parse
        parsed = PatternTransformer().transform(self.parser.parse(pattern))
        grouped_input, grouped_output = parsed['input'], parsed['output']
        self._log(f"  Grouped input: {grouped_input}")
        self._log(f"  Grouped output: {grouped_output}")

        # Fail early instead of silently producing wrongly ordered data.
        self._validate_patterns(grouped_input, grouped_output, axes_lengths)

        # Step 1: Build a basic axis_map from user lengths
        axis_map = dict(axes_lengths)

        # Step 2: Assign top-level dims, including ellipsis if any
        self._assign_top_level_dims_ellipsis(tensor, grouped_input, axis_map)
        self._log(f"  After top-level assignment, axis_map: {axis_map}")

        # Step 3: reshape input
        tensor = self._reshape_grouped_input(tensor, grouped_input, axis_map)
        self._log(f"  After input reshape: {tensor.shape}")

        # Step 4: transpose if needed
        tensor = self._maybe_transpose(tensor, grouped_input, grouped_output)
        self._log(f"  After transpose: {tensor.shape}")

        # Step 5: reshape output
        tensor = self._reshape_grouped_output(tensor, grouped_output, axis_map)
        self._log(f"  Final output shape: {tensor.shape}")
        return tensor

    def _assign_top_level_dims_ellipsis(self, tensor, pattern, axis_map):
        """Fill ``axis_map`` with the size of every axis of the input pattern.

        Each top-level entry of ``pattern`` consumes one dimension of
        ``tensor.shape``, except the ellipsis which absorbs all dimensions
        not claimed by the other entries; those sizes are stored as a list
        under ``axis_map['...']``. For a group such as ``(h w)`` at most one
        member may be unknown; it is inferred from the dimension size.

        Raises:
            ValueError: if the number of dimensions does not fit the pattern,
                a group cannot be split, a ``1`` entry meets a dimension that
                is not of size 1, or a size already present in ``axis_map``
                contradicts the tensor shape.
        """
        in_shape = list(tensor.shape)
        shape_len = len(in_shape)

        ellipsis_count = sum(1 for ax in pattern if ax == ELLIPSIS_TOKEN)
        if ellipsis_count > 1:
            raise ValueError("Only one ellipsis supported in this simplified approach.")
        leftover_for_ellipsis = 0
        if ellipsis_count == 1:
            # e.g. shape_len=4, pattern=['...', 'h', 'w'] => '...' absorbs 2 dims
            leftover_for_ellipsis = shape_len - (len(pattern) - 1)
            if leftover_for_ellipsis < 0:
                raise ValueError("Not enough dims for the ellipsis in input pattern.")
        elif len(pattern) != shape_len:
            raise ValueError(
                f"Input pattern has {len(pattern)} axes but the tensor has {shape_len} dims."
            )

        # Index of the tensor dimension the current top-level entry maps to.
        shape_idx = 0
        for top_axis in pattern:
            if top_axis == ELLIPSIS_TOKEN:
                axis_map['...'] = in_shape[shape_idx: shape_idx + leftover_for_ellipsis]
                self._log(f"  Ellipsis => leftover dims {axis_map['...']}")
                shape_idx += leftover_for_ellipsis

            elif isinstance(top_axis, tuple):
                # e.g. (h w)
                total_dim = in_shape[shape_idx]
                known_prod = 1
                unknowns = []
                for nm in top_axis:
                    if nm == '1':
                        # a unit member contributes a factor of 1
                        continue
                    if nm in axis_map:
                        known_prod *= axis_map[nm]
                    else:
                        unknowns.append(nm)
                if len(unknowns) > 1:
                    raise ValueError(f"Too many unknown dims in group {top_axis}")
                if len(unknowns) == 1:
                    if known_prod == 0 or total_dim % known_prod != 0:
                        raise ValueError(f"Can't split {total_dim} among group {top_axis}")
                    axis_map[unknowns[0]] = total_dim // known_prod
                elif known_prod != total_dim:
                    # everything was given by the caller: it has to match exactly
                    raise ValueError(
                        f"Group {top_axis} has total size {known_prod} but the dimension is {total_dim}"
                    )
                shape_idx += 1
            else:
                # single name like 'c' or '1'
                size = in_shape[shape_idx]
                if top_axis == '1':
                    if size != 1:
                        raise ValueError(f"Input axis '1' requires a dimension of size 1, got {size}")
                elif top_axis not in axis_map:
                    axis_map[top_axis] = size
                elif axis_map[top_axis] != size:
                    raise ValueError(
                        f"Axis '{top_axis}' was given length {axis_map[top_axis]} "
                        f"but the tensor dimension is {size}"
                    )
                shape_idx += 1

    def _reshape_grouped_input(self, tensor, pattern, axis_map):
        """Reshape ``tensor`` so every elementary input axis has its own dim.

        Ellipsis dims are copied from ``axis_map['...']``, groups are split
        into their members, and unit axes (``1``) are dropped so that the
        resulting dimensions line up one-to-one with the axes that take part
        in the permutation.

        Raises:
            ValueError: if a required size is missing or a group cannot be
                split.
        """
        new_dims = []
        shape_idx = 0

        leftover_ellipsis = axis_map.get('...', None)

        for top_axis in pattern:
            if top_axis == ELLIPSIS_TOKEN:
                if leftover_ellipsis is None:
                    raise ValueError("Ellipsis mismatch in input reshape.")
                self._log(f"  Expanding ellipsis with leftover dims {leftover_ellipsis}")
                new_dims.extend(leftover_ellipsis)
                # keep shape_idx aligned with the dims the ellipsis covered
                shape_idx += len(leftover_ellipsis)
            elif isinstance(top_axis, tuple):
                total_dim = tensor.shape[shape_idx]
                members = [s for s in top_axis if s != '1']
                unknowns = [s for s in members if s not in axis_map]
                if len(unknowns) > 1:
                    raise ValueError(f"Too many unknowns in group {top_axis}")
                if len(unknowns) == 1:
                    known_prod = 1
                    for s in members:
                        if s in axis_map:
                            known_prod *= axis_map[s]
                    if known_prod == 0 or total_dim % known_prod != 0:
                        raise ValueError(f"Cannot split dimension {total_dim} for group {top_axis}")
                    axis_map[unknowns[0]] = total_dim // known_prod
                # Emit sizes in pattern order, wherever the inferred axis sits.
                dims = [axis_map[s] for s in members]
                self._log(f"  Splitting group {top_axis} => dims={dims}")
                new_dims.extend(dims)
                shape_idx += 1
            else:
                nm = str(top_axis)
                if nm == '1':
                    self._log("  Found '1' -> unit axis dropped (input).")
                else:
                    if nm not in axis_map:
                        raise ValueError(f"Missing axis size for '{nm}' in input reshape.")
                    new_dims.append(axis_map[nm])
                shape_idx += 1

        self._log(f"  Input new_dims: {new_dims}")
        reshaped = tensor.reshape(new_dims)
        return reshaped

    def _maybe_transpose(self, tensor, grouped_input, grouped_output):
        """Permute the dims of ``tensor`` from input order to output order.

        ``tensor`` must have one dimension per elementary input axis, with
        unit axes removed (the layout produced by ``_reshape_grouped_input``).
        The dimensions covered by the ellipsis get placeholder names so they
        are part of the permutation as well.

        Raises:
            ValueError: if the tensor rank does not fit the input pattern or
                the two sides do not contain the same axes.
        """
        named_in = self._flatten_axes(grouped_input)
        has_ellipsis = any(ax == ELLIPSIS_TOKEN for ax in grouped_input)
        # Whatever the named axes do not account for belongs to the ellipsis.
        ellipsis_rank = tensor.ndim - len(named_in) if has_ellipsis else 0
        if ellipsis_rank < 0 or (not has_ellipsis and tensor.ndim != len(named_in)):
            raise ValueError(
                f"Tensor with {tensor.ndim} dims does not match input axes {named_in}"
            )
        # '...0', '...1', ... cannot clash with identifiers from the grammar.
        ellipsis_axes = [f"...{i}" for i in range(ellipsis_rank)]

        flat_in = self._flatten_axes(grouped_input, ellipsis_axes)
        flat_out = self._flatten_axes(grouped_output, ellipsis_axes)

        self._log(f"[Transpose check]\n  in= {flat_in}\n  out={flat_out}")
        if flat_in == flat_out:
            self._log("  No transpose needed.")
            return tensor

        if sorted(flat_in) != sorted(flat_out):
            raise ValueError(f"Input axes {flat_in} and output axes {flat_out} do not match.")

        position = {ax: idx for idx, ax in enumerate(flat_in)}
        perm = [position[ax] for ax in flat_out]
        self._log(f"  Applying transpose perm: {perm}")
        return tensor.transpose(perm)

    def _reshape_grouped_output(self, tensor, pattern, axis_map):
        """Reshape ``tensor`` into the grouped layout of the output pattern.

        An ellipsis expands to the sizes stored in ``axis_map['...']``, a
        group becomes the product of its members, ``1`` becomes a new axis of
        size 1, and a name becomes its size from ``axis_map``.

        Raises:
            ValueError: if the output uses an ellipsis or an axis whose size
                is not present in ``axis_map``.
        """
        leftover_ellipsis = axis_map.get('...', None)
        new_dims = []

        for top_axis in pattern:
            if top_axis == ELLIPSIS_TOKEN:
                if leftover_ellipsis is None:
                    raise ValueError("Ellipsis in output pattern but not in input pattern.")
                self._log(f"  Expanding output ellipsis => {leftover_ellipsis}")
                new_dims.extend(leftover_ellipsis)
            elif isinstance(top_axis, tuple):
                # merge
                product = 1
                for s in top_axis:
                    if s == '1':
                        continue
                    if s not in axis_map:
                        raise ValueError(f"No axis_map for {s} in output merge.")
                    product *= axis_map[s]
                self._log(f"  Merging group {top_axis} => product={product}")
                new_dims.append(product)
            else:
                nm = str(top_axis)
                if nm == '1':
                    self._log("  Found '1' => new axis of size 1 (output).")
                    new_dims.append(1)
                else:
                    if nm not in axis_map:
                        # Inventing a size here would hide a pattern error.
                        raise ValueError(f"Missing axis size for '{nm}' in output reshape.")
                    new_dims.append(axis_map[nm])

        self._log(f"  Output new_dims: {new_dims}")
        return tensor.reshape(new_dims)

In [8]:
# Smoke test: run each of the five supported operations once and print the
# shape of the result.
if __name__ == "__main__":
    rearranger = EinopsRearranger()

    # (label, input shape, pattern, explicit axis lengths)
    demo_cases = (
        ("Transpose:", (3, 4), 'h w -> w h', {}),
        ("Split:", (12, 10), '(h w) c -> h w c', {'h': 3}),
        ("Merge:", (3, 4, 5), 'a b c -> (a b) c', {}),
        ("Repeat:", (3, 1, 5), 'a b c -> a b 1 c', {}),
        ("Ellipsis:", (2, 3, 4, 5), '... h w -> ... (h w)', {}),
    )

    for label, in_shape, pattern, lengths in demo_cases:
        x = np.random.rand(*in_shape)
        print(label, rearranger.rearrange(x, pattern, **lengths).shape)



[REARRANGE START]
  Tensor shape: [3, 4]
  Pattern: h w -> w h
  User axes_lengths: {}
  Grouped input: ['h', 'w']
  Grouped output: ['w', 'h']
  After top-level assignment, axis_map: {'h': 3, 'w': 4}
  Input new_dims: [3, 4]
  After input reshape: (3, 4)
[Transpose check]
  in= ['h', 'w']
  out=['w', 'h']
  Applying transpose perm: [1, 0]
  After transpose: (4, 3)
  Output new_dims: [4, 3]
  Final output shape: (4, 3)
Transpose: (4, 3)

[REARRANGE START]
  Tensor shape: [12, 10]
  Pattern: (h w) c -> h w c
  User axes_lengths: {'h': 3}
  Grouped input: [('h', 'w'), 'c']
  Grouped output: ['h', 'w', 'c']
  After top-level assignment, axis_map: {'h': 3, 'w': 4, 'c': 10}
  Splitting group ('h', 'w') => dims=[3, 4]
  Input new_dims: [3, 4, 10]
  After input reshape: (3, 4, 10)
[Transpose check]
  in= ['h', 'w', 'c']
  out=['h', 'w', 'c']
  No transpose needed.
  After transpose: (3, 4, 10)
  Output new_dims: [3, 4, 10]
  Final output shape: (3, 4, 10)
Split: (3, 4, 10)

[REARRANGE START]

In [9]:
#### Testcases using Einops ###
!pip install einops



In [10]:
import numpy as np
from einops import rearrange as einops_rearrange

def test_transpose(einops_rearranger):
    x = np.random.rand(3, 4)
    my_res = einops_rearranger.rearrange(x, 'h w -> w h')
    ref_res = einops_rearrange(x, 'h w -> w h')
    print("Transpose shape (mine):", my_res.shape)
    print("Transpose shape (einops):", ref_res.shape)
    assert my_res.shape == ref_res.shape, "Shape mismatch!"
    assert np.allclose(my_res, ref_res), "Content mismatch!"

def test_split(einops_rearranger):
    x = np.random.rand(12, 10)
    my_res = einops_rearranger.rearrange(x, '(h w) c -> h w c', h=3)
    ref_res = einops_rearrange(x, '(h w) c -> h w c', h=3)
    print("Split shape (mine):", my_res.shape)
    print("Split shape (einops):", ref_res.shape)
    assert my_res.shape == ref_res.shape, "Shape mismatch!"
    assert np.allclose(my_res, ref_res), "Content mismatch!"

def test_merge(einops_rearranger):
    x = np.random.rand(3, 4, 5)
    my_res = einops_rearranger.rearrange(x, 'a b c -> (a b) c')
    ref_res = einops_rearrange(x, 'a b c -> (a b) c')
    print("Merge shape (mine):", my_res.shape)
    print("Merge shape (einops):", ref_res.shape)
    assert my_res.shape == ref_res.shape
    assert np.allclose(my_res, ref_res)

def test_repeat(einops_rearranger):
    x = np.random.rand(3, 1, 5)
    my_res = einops_rearranger.rearrange(x, 'a b c -> a b 1 c')
    ref_res = einops_rearrange(x, 'a b c -> a b 1 c')
    print("Repeat shape (mine):", my_res.shape)
    print("Repeat shape (einops):", ref_res.shape)
    assert my_res.shape == ref_res.shape
    assert np.allclose(my_res, ref_res)

def test_ellipsis(einops_rearranger):
    x = np.random.rand(2, 3, 4, 5)
    my_res = einops_rearranger.rearrange(x, '... h w -> ... (h w)')
    ref_res = einops_rearrange(x, '... h w -> ... (h w)')
    print("Ellipsis shape (mine):", my_res.shape)
    print("Ellipsis shape (einops):", ref_res.shape)
    assert my_res.shape == ref_res.shape
    assert np.allclose(my_res, ref_res)

# Run every einops comparison test against one shared rearranger instance;
# the first failing assertion aborts the cell before the summary line.
if __name__ == "__main__":
    rearranger = EinopsRearranger()

    comparison_tests = (
        test_transpose,
        test_split,
        test_merge,
        test_repeat,
        test_ellipsis,
    )
    for run_comparison in comparison_tests:
        run_comparison(rearranger)

    print("All tests comparing custom rearranger to einops have PASSED!")


[REARRANGE START]
  Tensor shape: [3, 4]
  Pattern: h w -> w h
  User axes_lengths: {}
  Grouped input: ['h', 'w']
  Grouped output: ['w', 'h']
  After top-level assignment, axis_map: {'h': 3, 'w': 4}
  Input new_dims: [3, 4]
  After input reshape: (3, 4)
[Transpose check]
  in= ['h', 'w']
  out=['w', 'h']
  Applying transpose perm: [1, 0]
  After transpose: (4, 3)
  Output new_dims: [4, 3]
  Final output shape: (4, 3)
Transpose shape (mine): (4, 3)
Transpose shape (einops): (4, 3)

[REARRANGE START]
  Tensor shape: [12, 10]
  Pattern: (h w) c -> h w c
  User axes_lengths: {'h': 3}
  Grouped input: [('h', 'w'), 'c']
  Grouped output: ['h', 'w', 'c']
  After top-level assignment, axis_map: {'h': 3, 'w': 4, 'c': 10}
  Splitting group ('h', 'w') => dims=[3, 4]
  Input new_dims: [3, 4, 10]
  After input reshape: (3, 4, 10)
[Transpose check]
  in= ['h', 'w', 'c']
  out=['h', 'w', 'c']
  No transpose needed.
  After transpose: (3, 4, 10)
  Output new_dims: [3, 4, 10]
  Final output shape: (