<a href="https://colab.research.google.com/github/vulcan-332/mini_einops/blob/main/Einops.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Define Pattern Parser

In [None]:
import re
from typing import List, Tuple

def parse_pattern(pattern: str) -> Tuple[List[str], List[str]]:
  """
  Split a rearrange pattern such as '(h w) c -> h w c' into left and right tokens.

  Args:
      pattern: Pattern string containing exactly one '->' separator.

  Returns:
      A pair (left_tokens, right_tokens); a parenthesized group is kept
      together as a single token, e.g. (['(h w)', 'c'], ['h', 'w', 'c']).

  Raises:
      ValueError: If the pattern does not contain exactly one '->'.
  """
  arrow_count = pattern.count('->')
  if arrow_count == 0:
    raise ValueError("Pattern must contain '->' to separate left and right parts.")
  if arrow_count > 1:
    # Without this check the two-way unpacking below fails with an opaque
    # "too many values to unpack" message.
    raise ValueError(
        f"Pattern must contain exactly one '->', but found {arrow_count}: '{pattern}'"
    )

  left_part, right_part = pattern.split('->')
  return _tokenize(left_part.strip()), _tokenize(right_part.strip())

def _tokenize(part: str) -> List[str]:
  pattern = r'\([^)]*\)|\S+'
  return re.findall(pattern, part)
# Quick check: the parenthesized group on the left is kept as a single token.
parse_pattern('(h w) c -> h w c')

(['(h w)', 'c'], ['h', 'w', 'c'])

Input Side

In [None]:
from typing import Dict, List
import numpy as np

def infer_left_axes(left_tokens: List[str], tensor: np.ndarray, axes_lengths: Dict[str, int]) -> List[int]:
    """
    Resolve the size of every elementary axis named on the left side of a pattern.

    Each token consumes input dimensions as follows:
      - a plain name such as "b" consumes one dimension;
      - a group such as "(c h)" consumes one dimension and splits it into its sub-axes;
      - "..." consumes every dimension that is not needed by the tokens after it.

    Example: tokens ["b", "(c h)", "w"], a tensor of shape (2, 12, 3) and
    axes_lengths {"c": 4} give [2, 4, 3, 3].

    Args:
        left_tokens: Tokens of the left side of the pattern.
        tensor: Input array; only its ``shape`` is read.
        axes_lengths: Known axis sizes by name. It is updated IN PLACE with every
            size inferred here, so the caller can read the inferred sizes back.

    Returns:
        The sizes of the elementary axes in left-to-right order.

    Raises:
        ValueError: If the pattern needs more or fewer dimensions than the tensor
            has, if a user-specified size contradicts the tensor, if a group has
            more than one unknown sub-axis, or if a group's sub-axis sizes cannot
            multiply to the size of the dimension being split.
    """
    shape = tensor.shape
    n_dims = len(shape)
    dim_ptr = 0  # index of the next unconsumed input dimension

    final_expanded_dims = []

    for idx, token in enumerate(left_tokens):
        if token == '...':
            # Every later non-ellipsis token consumes exactly one dimension, so
            # leave that many at the end and take what lies in between.
            n_after = sum(1 for later in left_tokens[idx + 1:] if later != '...')
            # If there are too few dims, take nothing here and let the later
            # token report the shortage under its own name.
            end = max(dim_ptr, n_dims - n_after)
            final_expanded_dims.extend(shape[dim_ptr:end])
            dim_ptr = end
        elif token.startswith('(') and token.endswith(')'):
            if dim_ptr >= n_dims:
                raise ValueError(f"Pattern expects more dims than input shape provides. Problem with token '{token}'")

            subdims = token[1:-1].split()  # "(c h)" -> ["c", "h"]
            merged_dim_size = shape[dim_ptr]
            dim_ptr += 1

            unknown_subdims = [sd for sd in subdims if sd not in axes_lengths]
            if len(unknown_subdims) > 1:
                # Solving for several unknowns at once is not supported.
                raise ValueError("Too many unknown dims in the merged group. Provide enough axes_lengths.")

            known_product = 1
            for sd in subdims:
                if sd in axes_lengths:
                    known_product *= axes_lengths[sd]

            if unknown_subdims:
                # The single unknown sub-axis takes whatever factor is left, which
                # only makes sense when the known sizes divide the dimension.
                if known_product <= 0 or merged_dim_size % known_product != 0:
                    raise ValueError(
                        f"Cannot split dimension of size {merged_dim_size} for group '{token}': "
                        f"known axes multiply to {known_product}."
                    )
                axes_lengths[unknown_subdims[0]] = merged_dim_size // known_product
            elif known_product != merged_dim_size:
                raise ValueError(
                    f"Mismatch for group '{token}': axes multiply to {known_product}, "
                    f"but tensor has {merged_dim_size}."
                )

            final_expanded_dims.extend(axes_lengths[sd] for sd in subdims)
        else:
            # A plain axis name like 'b'
            if dim_ptr >= n_dims:
                raise ValueError(f"Pattern expects more dims than input shape provides. Problem with token '{token}'")

            current_dim_size = shape[dim_ptr]
            dim_ptr += 1

            # A user-specified size for this axis acts as a constraint to verify.
            if token in axes_lengths:
                if axes_lengths[token] != current_dim_size:
                    raise ValueError(
                        f"Mismatch for axis '{token}': pattern expects {axes_lengths[token]}, but tensor has {current_dim_size}."
                    )
            else:
                axes_lengths[token] = current_dim_size

            final_expanded_dims.append(current_dim_size)

    if dim_ptr < n_dims:
        # Dimensions left over: the pattern names too few axes and has no ellipsis.
        raise ValueError(f"Not all dimensions in input tensor are covered by the left pattern. "
                         f"Unused dims start at index {dim_ptr}")

    return final_expanded_dims


Output Side

In [None]:
# Placeholder function. The actual output-side logic is implemented inside rearrange() below.

def infer_right_axes(
    right_tokens: List[str],
    axes_lengths: Dict[str, int],
    expanded_input_ndims: int
):
    """
    Placeholder for planning the output side of a pattern.

    The intended job is to turn right tokens such as ["b", "...", "(c h)", "w"]
    into a description of the output: how to permute the input axes, which axes
    to merge, and which axes to repeat. None of that is implemented yet; the
    arguments are currently ignored.

    Returns:
        A tuple (reorder_map, merges, repeats) of three empty lists.
    """
    # Intended design, not yet implemented:
    #   1. Build one "output instruction" per new axis.
    #   2. An ellipsis keeps a run of input axes in their existing order.
    #   3. A group like "(c h)" merges or splits axes depending on context.
    # A real implementation also needs a proper mapping from axis names to
    # input indices instead of assuming both sides use the same order.
    reorder_map, merges, repeats = [], [], []
    return reorder_map, merges, repeats


Rearrange function


In [None]:
from typing import Any
import numpy as np

def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths: Any) -> np.ndarray:
    """
    Reshape, transpose, merge, split and repeat axes of ``tensor`` as described by ``pattern``.

    Supported forms (ellipsis is not supported yet):
      - transpose:  "b c h -> h c b"
      - merge:      "b c h -> b (c h)"
      - split:      "b (c h) w -> b c h w" with e.g. c=4
      - repeat:     "b c -> b c r" with r=2; the new axis may also sit inside a
                    group, e.g. "b c -> b (c r)"

    Args:
        tensor: Input array.
        pattern: Pattern of the form "<left axes> -> <right axes>".
        **axes_lengths: Sizes of axes that cannot be inferred from the tensor:
            one side of every split, and every axis that appears only on the right.

    Returns:
        The rearranged array.

    Raises:
        NotImplementedError: If either side of the pattern contains "...".
        ValueError: If the pattern does not match the tensor's shape, names an
            axis twice on one side, drops a left-side axis, or introduces a
            right-side axis without a length in ``axes_lengths``.
    """
    left_tokens, right_tokens = parse_pattern(pattern)

    # Resolving the left side also records every inferred size in axes_lengths.
    left_expanded = infer_left_axes(left_tokens, tensor, axes_lengths)
    tensor = tensor.reshape(tuple(left_expanded))

    input_axes = _flatten_axis_names(left_tokens)
    output_axes = _flatten_axis_names(right_tokens)

    for side, names in (("left", input_axes), ("right", output_axes)):
        if len(set(names)) != len(names):
            raise ValueError(f"Axis names must be unique on the {side} side of pattern '{pattern}'.")

    position = {name: i for i, name in enumerate(input_axes)}

    dropped = [name for name in input_axes if name not in output_axes]
    if dropped:
        raise ValueError(f"Axes {dropped} appear on the left side of '{pattern}' but not on the right.")

    missing = [name for name in output_axes if name not in position and name not in axes_lengths]
    if missing:
        raise ValueError(
            f"Axes {missing} appear only on the right side of '{pattern}'; "
            f"pass their lengths as keyword arguments."
        )

    # Bring the existing axes into the order in which the right side lists them.
    perm = [position[name] for name in output_axes if name in position]
    tensor = np.transpose(tensor, perm)

    # Insert a length-1 axis for every new name, then repeat it to its full
    # length. Doing this on the flat axis list (before merging groups) lets a
    # new axis live inside a group as well.
    expanded_shape = [axes_lengths[name] if name in position else 1 for name in output_axes]
    tensor = tensor.reshape(tuple(expanded_shape))
    for axis, name in enumerate(output_axes):
        if name not in position:
            tensor = np.repeat(tensor, repeats=axes_lengths[name], axis=axis)

    # Finally merge the axes of each right-side group into a single dimension.
    final_shape = []
    for token in right_tokens:
        if token.startswith('(') and token.endswith(')'):
            size = 1
            for name in token[1:-1].split():
                size *= axes_lengths[name]
            final_shape.append(size)
        else:
            final_shape.append(axes_lengths[token])

    return tensor.reshape(tuple(final_shape))


def _flatten_axis_names(tokens: List[str]) -> List[str]:
    """Expand groups like '(c h)' into their axis names, returning one flat list."""
    names = []
    for token in tokens:
        if token.startswith('(') and token.endswith(')'):
            names.extend(token[1:-1].split())
        elif token == '...':
            raise NotImplementedError("Ellipsis not yet supported")
        else:
            names.append(token)
    return names


Unit Test

In [None]:
import unittest

class TestRearrange(unittest.TestCase):
    """Checks the output shape and the element placement of rearrange()."""

    def test_basic_transpose(self):
        x = np.random.randn(2, 3, 4)
        y = rearrange(x, "b c h -> h c b")
        self.assertEqual(y.shape, (4, 3, 2))
        # A shape check alone would also pass for a wrongly permuted result.
        np.testing.assert_array_equal(y, np.transpose(x, (2, 1, 0)))

    def test_transpose_middle(self):
        x = np.random.randn(2, 3, 4)
        y = rearrange(x, "b c h -> b h c")
        self.assertEqual(y.shape, (2, 4, 3))
        np.testing.assert_array_equal(y, np.transpose(x, (0, 2, 1)))

    def test_merge_dims(self):
        x = np.random.randn(2, 3, 4)
        y = rearrange(x, "b c h -> b (c h)")
        self.assertEqual(y.shape, (2, 12))
        np.testing.assert_array_equal(y, x.reshape(2, 12))

    def test_split_dims(self):
        x = np.random.randn(2, 12, 3)
        y = rearrange(x, "b (c h) w -> b c h w", c=4)
        self.assertEqual(y.shape, (2, 4, 3, 3))
        np.testing.assert_array_equal(y, x.reshape(2, 4, 3, 3))

    def test_repeat_axis(self):
        x = np.random.randn(2, 3)
        y = rearrange(x, "b c -> b c r", r=2)
        self.assertEqual(y.shape, (2, 3, 2))
        # Every slice along the new axis must be a copy of the input.
        for k in range(2):
            np.testing.assert_array_equal(y[:, :, k], x)

# Build a suite from the test case and run it in-process with per-test output
# (verbosity=2); the returned result object becomes the cell's value.
suite = unittest.TestLoader().loadTestsFromTestCase(TestRearrange)
unittest.TextTestRunner(verbosity=2).run(suite)

test_basic_transpose (__main__.TestRearrange.test_basic_transpose) ... ok
test_merge_dims (__main__.TestRearrange.test_merge_dims) ... ok
test_repeat_axis (__main__.TestRearrange.test_repeat_axis) ... ok
test_split_dims (__main__.TestRearrange.test_split_dims) ... ok
test_transpose_middle (__main__.TestRearrange.test_transpose_middle) ... ok

----------------------------------------------------------------------
Ran 5 tests in 0.011s

OK


<unittest.runner.TextTestResult run=5 errors=0 failures=0>