In [None]:
import numpy as np
import re
from typing import List, Dict, Tuple, Any, Optional

class EinopsError(ValueError):
    """Signals an invalid pattern or invalid arguments in an einops-like operation.

    Subclasses ValueError, so handlers written for ValueError also catch it.
    """

In [None]:
def _parse_composition(expr: str) -> List[Dict[str, Any]]:
    """
    Parses a single side of the einops pattern (e.g., 'a b (c d)').
    """
    if not isinstance(expr, str):
        raise EinopsError(f"Expression must be a string, got {type(expr)}")

    expr = expr.strip()
    if not expr: return []

    if not re.fullmatch(r"[a-zA-Z0-9_()\.\s]*", expr):
        invalid_chars = set(re.sub(r"[a-zA-Z0-9_()\.\s]", "", expr))
        raise EinopsError(f"Invalid characters found in expression '{expr}': {invalid_chars}")

    if '()' in expr.replace(" ", ""):
         raise EinopsError(f"Empty parentheses '()' found in expression '{expr}'")

    parsed_components = []
    ellipsis_found = False
    current_pos = 0
    all_identifiers = set() # Track NAMED identifiers

    while current_pos < len(expr):
        expr_slice = expr[current_pos:]
        # <<< FIX: Regex Order Changed - (1) comes BEFORE ([a-zA-Z0-9_]+) >>>
        # Groups: 1:Ellipsis, 2:Group, 3:Literal 1, 4:Named Axis
        match = re.match(r"\s*(?:(\.\.\.)|(\([^\)]+\))|(1)|([a-zA-Z0-9_]+))\s*", expr_slice)

        if not match:
            unmatched_part = expr_slice.strip();
            if unmatched_part: raise EinopsError(f"Could not parse token at: '{unmatched_part[:20]}...' in expression '{expr}'. Check syntax.")
            else: break

        token_str = match.group(0).strip()
        component = None

        if match.group(1): # Ellipsis (Group 1)
             if ellipsis_found: raise EinopsError(f"Multiple ellipses found in '{expr}'")
             ellipsis_found = True; component = {'type': 'ellipsis'}
        elif match.group(2): # Potential Group (Group 2)
            group_expr = match.group(2)
            group_content = group_expr[1:-1].strip()
            if not group_content: raise EinopsError(f"Empty parentheses '()' found in '{expr}'")
            if not re.fullmatch(r"[a-zA-Z0-9_\s]+", group_content) or '(' in group_content or ')' in group_content:
                 if '(' in group_content or ')' in group_content: raise EinopsError(f"Invalid group content: '{group_content}'. Nested parentheses are not allowed within groups for rearrange.")
                 else: raise EinopsError(f"Invalid group content: '{group_content}'. Contains disallowed characters.")
            group_names = group_content.split()
            if len(set(group_names)) != len(group_names): raise EinopsError(f"Duplicate identifiers found within group: '{group_content}'")
            for name in group_names:
                 if name == '1': raise EinopsError(f"Literal '1' cannot be identifier in group: '{group_content}'")
                 if name in all_identifiers: raise EinopsError(f"Identifier '{name}' appears multiple times in '{expr}'")
                 all_identifiers.add(name)
            component = {'type': 'group', 'names': group_names}
        elif match.group(3): # Literal '1' (Group 3) - Matched before named axis now
             component = {'type': 'axis', 'name': '1'}
        elif match.group(4): # Axis Name (Group 4)
            name = match.group(4)
            if name in all_identifiers: raise EinopsError(f"Identifier '{name}' appears multiple times in '{expr}'")
            all_identifiers.add(name)
            component = {'type': 'axis', 'name': name}
        else: raise EinopsError(f"Internal parsing error near: {token_str}")

        if component: parsed_components.append(component)
        current_pos += match.end()

    return parsed_components


def _parse_pattern(pattern: str) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Parses the complete einops pattern string 'left -> right'.

    Each side is parsed with `_parse_composition`; afterwards two cross-side
    rules are enforced:
      * the literal '1' may not appear as a top-level axis on the right side;
      * the ellipsis must appear on both sides or on neither.

    Output axes that do not exist on the input side are NOT rejected here:
    this function has no access to the caller's axis lengths, so it cannot
    tell a requested repetition from a mistake.

    Args:
        pattern: Full pattern containing exactly one '->'.

    Returns:
        `(parsed_left, parsed_right)`, each a list of component dicts as
        produced by `_parse_composition`.

    Raises:
        EinopsError: If `pattern` is not a string, does not contain exactly
            one '->', either side fails to parse, or a cross-side rule is
            violated.
    """
    if not isinstance(pattern, str):
        raise EinopsError(f"Pattern must be string, got {type(pattern)}")
    if '->' not in pattern:
        raise EinopsError(f"Pattern must contain '->', got: '{pattern}'")
    parts = pattern.split('->')
    if len(parts) != 2:
        raise EinopsError(f"Pattern must contain exactly one '->', got: '{pattern}'")
    left_expr, right_expr = parts

    try:
        parsed_left = _parse_composition(left_expr)
        parsed_right = _parse_composition(right_expr)
    except EinopsError as e:
        raise EinopsError(f"Error parsing expression in pattern '{pattern}': {e}") from e

    # --- Cross-Side Validation ---
    # Only top-level axes need checking: '1' inside a group is already
    # rejected by _parse_composition.
    for comp in parsed_right:
        if comp['type'] == 'axis' and comp['name'] == '1':
            raise EinopsError(f"Literal '1' not allowed on RHS: '{pattern}'")

    left_has_ellipsis = any(comp['type'] == 'ellipsis' for comp in parsed_left)
    right_has_ellipsis = any(comp['type'] == 'ellipsis' for comp in parsed_right)
    if left_has_ellipsis != right_has_ellipsis:
        side_with = "input" if left_has_ellipsis else "output"
        side_without = "output" if left_has_ellipsis else "input"
        raise EinopsError(f"Ellipsis found in {side_with} but not {side_without}: '{pattern}'")

    return parsed_left, parsed_right

In [None]:
def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths: int) -> np.ndarray:
    """
    Rearranges a tensor according to the provided einops-style pattern.

    The pattern has the form 'left -> right'. The left side names the axes of
    `tensor`; the right side describes the result. Supported operations:
      * transposition:           'h w -> w h'
      * splitting an axis:       '(h w) c -> h w c' with h=... or w=...
      * merging axes:            'a b c -> (a b) c'
      * ellipsis for the rest:   '... h w -> ... (h w)'
      * repetition of a '1':     'a 1 c -> a b c' with b=4 (each new output
        axis consumes the next unused literal '1' of the input, in order)
      * dropping a '1':          'a 1 c -> a c' (literal '1' input axes that
        no repetition consumes are removed from the result)

    Processing steps: parse the pattern, resolve every axis length from the
    tensor shape and `axes_lengths`, reshape so each elementary axis has its
    own dimension, transpose into output order, broadcast the repeated axes,
    and finally reshape to merge groups.

    Args:
        tensor: Input array.
        pattern: Rearrangement pattern 'left -> right'.
        **axes_lengths: Lengths of named axes. Required for new (repeated)
            output axes and for all but one axis of each input group;
            optional (but checked against the shape) for plain input axes.

    Returns:
        The rearranged array. No explicit copy is made, so the result may
        share memory with `tensor`; when repetition is used the data passes
        through `np.broadcast_to` and the result may be a read-only view.

    Raises:
        TypeError: If `tensor` is not a NumPy ndarray.
        EinopsError: If the pattern is invalid, does not fit the tensor's
            shape, leaves a named input axis unused, or if `axes_lengths`
            is missing, inconsistent, or contains unused names.
    """
    if not isinstance(tensor, np.ndarray):
        raise TypeError(f"Input tensor must be a NumPy ndarray, got {type(tensor)}")

    try:
        parsed_left, parsed_right = _parse_pattern(pattern)
    except EinopsError as e:
        # Composition errors arrive already prefixed with "Error parsing
        # expression in pattern ..."; do not wrap those a second time.
        if str(e).startswith("Error parsing"):
            raise
        raise EinopsError(f"Error parsing pattern '{pattern}': {e}") from e

    input_shape = tensor.shape
    input_ndim = tensor.ndim

    # --- Input Analysis ---
    input_axes: Dict[str, int] = {}                # elementary axis name -> length
    input_composition: List[Tuple[str, Any]] = []  # (component type, payload) per LHS component
    current_dim_index = 0
    ellipsis_ndim = -1                             # -1 means "no ellipsis in the pattern"
    temp_anon_axis_count = 0

    has_input_ellipsis = any(comp['type'] == 'ellipsis' for comp in parsed_left)
    num_explicit_input_dims = sum(1 for comp in parsed_left if comp['type'] in ('axis', 'group'))
    pattern_dims_str = ' '.join(
        comp['name'] if comp['type'] == 'axis' else str(comp.get('names', '...'))
        for comp in parsed_left
    )

    if has_input_ellipsis:
        if input_ndim < num_explicit_input_dims:
            raise EinopsError(f"Input pattern '{pattern_dims_str}' "
                              f"with ellipsis expects at least {num_explicit_input_dims} non-ellipsis dimensions, "
                              f"but tensor shape {input_shape} only has {input_ndim} total dimensions.")
        ellipsis_ndim = input_ndim - num_explicit_input_dims
    elif input_ndim != num_explicit_input_dims:
        raise EinopsError(f"Pattern '{pattern}' expects {num_explicit_input_dims} dimensions for the input ('{pattern_dims_str}'), "
                          f"but tensor shape {input_shape} has {input_ndim}")

    for comp in parsed_left:
        comp_type = comp['type']
        if comp_type == 'ellipsis':
            input_composition.append(('ellipsis', ellipsis_ndim))
            current_dim_index += ellipsis_ndim
            continue

        if current_dim_index >= input_ndim:
            raise EinopsError(f"Internal error: Dimension index out of bounds.")

        dim_size = input_shape[current_dim_index]

        if comp_type == 'axis':
            name = comp['name']
            if name == '1':
                if dim_size != 1:
                    raise EinopsError(f"Axis for '1' must have size 1, got {dim_size} at dim {current_dim_index}")
                # Every literal '1' gets a unique internal name so it can be
                # tracked through the transpose as a repetition source.
                anon_name = f'1__{temp_anon_axis_count}'
                temp_anon_axis_count += 1
                input_axes[anon_name] = 1
                input_composition.append(('axis', anon_name))
            else:
                if name in axes_lengths and axes_lengths[name] != dim_size:
                    raise EinopsError(f"Provided length for axis '{name}' ({axes_lengths[name]}) conflicts with tensor shape ({dim_size}) at dim {current_dim_index}")
                input_axes[name] = dim_size
                input_composition.append(('axis', name))
        elif comp_type == 'group':
            names = comp['names']
            input_composition.append(('group', names))
            known_size_product, unknown_axes = 1, []
            for name in names:
                if name in axes_lengths:
                    size = axes_lengths[name]
                    if not isinstance(size, int) or size <= 0:
                        raise EinopsError(f"Size for '{name}' must be positive int, got {size}")
                    input_axes[name] = size
                    known_size_product *= size
                else:
                    unknown_axes.append(name)

            if not unknown_axes:
                if known_size_product != dim_size:
                    raise EinopsError(f"Dimension size mismatch for group {names}: product {known_size_product} != tensor dim {dim_size}")
            elif len(unknown_axes) == 1:
                # Exactly one unknown length can be inferred from the dimension.
                unknown_axis = unknown_axes[0]
                if dim_size % known_size_product != 0:
                    raise EinopsError(f"Cannot infer size for axis '{unknown_axis}' in group {names}: dimension size {dim_size} is not divisible by the product of known axes ({known_size_product})")
                inferred_size = dim_size // known_size_product
                if inferred_size <= 0:
                    raise EinopsError(f"Inferred size for axis '{unknown_axis}' not positive ({inferred_size})")
                input_axes[unknown_axis] = inferred_size
            else:
                raise EinopsError(f"Cannot infer sizes for {unknown_axes} in group {names}. Provide lengths.")
        else:
            raise EinopsError(f"Internal error: Unknown component type '{comp_type}'")

        current_dim_index += 1

    if current_dim_index != input_ndim:
        raise EinopsError(f"Internal error: Consumed {current_dim_index} dims, tensor has {input_ndim}.")

    # --- Intermediate Reshape (Decomposition) ---
    # Give every elementary axis its own dimension; the ellipsis contributes
    # a single '...' entry to the name list but `ellipsis_ndim` dimensions.
    decomposition_shape, decomposition_axes_names = [], []
    temp_input_shape_iter = iter(input_shape)
    for comp_type, data in input_composition:
        if comp_type == 'ellipsis':
            for _ in range(data):
                decomposition_shape.append(next(temp_input_shape_iter))
            if data > 0:
                decomposition_axes_names.append('...')
        elif comp_type == 'axis':
            decomposition_shape.append(next(temp_input_shape_iter))
            decomposition_axes_names.append(data)
        elif comp_type == 'group':
            next(temp_input_shape_iter)  # Consume the original (merged) dim
            for name in data:
                decomposition_shape.append(input_axes[name])
                decomposition_axes_names.append(name)
    try:
        decomposed_tensor = tensor.reshape(decomposition_shape)
    except ValueError as e:
        raise EinopsError(f"Internal error during decomposition reshape: {e}.") from e

    # --- Output Analysis & Transposition/Repetition Prep ---
    final_shape, transpose_axis_order, repeat_recipe = [], [], []
    processed_input_logical_axes = set()
    available_anon_sources = [name for name in decomposition_axes_names if name.startswith('1__')]
    anon_sources_used = {name: False for name in available_anon_sources}
    repetition_target_names = set()

    # Logical name -> index in the decomposed tensor (a list of indices for '...').
    source_indices_map: Dict[str, Any] = {}
    next_source_idx = 0
    for name in decomposition_axes_names:
        if name == '...':
            indices = list(range(next_source_idx, next_source_idx + max(ellipsis_ndim, 0)))
            source_indices_map['...'] = indices
            next_source_idx += len(indices)
        else:
            source_indices_map[name] = next_source_idx
            next_source_idx += 1

    def _claim_unit_source() -> Optional[str]:
        """Reserve and return the next unused literal-'1' input axis, if any."""
        for anon in available_anon_sources:
            if not anon_sources_used[anon]:
                anon_sources_used[anon] = True
                return anon
        return None

    for comp in parsed_right:
        comp_type = comp['type']
        if comp_type == 'ellipsis':
            transpose_axis_order.append('...')
            final_shape.extend(decomposition_shape[i] for i in source_indices_map.get('...', []))
            processed_input_logical_axes.add('...')
        elif comp_type == 'axis':
            name = comp['name']
            if name not in input_axes:  # Repetition
                repetition_target_names.add(name)
                if name not in axes_lengths:
                    raise EinopsError(f"Length for new axis '{name}' must be provided.")
                source_1_name = _claim_unit_source()
                if source_1_name is None:
                    raise EinopsError(f"No available '1' source for repeating '{name}'.")
                target_size = axes_lengths[name]
                if not isinstance(target_size, int) or target_size <= 0:
                    raise EinopsError(f"Length for '{name}' must be positive.")
                input_axes[name] = target_size
                transpose_axis_order.append(source_1_name)
                repeat_recipe.append({'source_name': source_1_name, 'repeats': target_size, 'target_name': name})
                final_shape.append(target_size)
                processed_input_logical_axes.add(source_1_name)
            else:  # Existing axis
                if name.startswith('1__'):
                    raise EinopsError(f"Cannot reuse internal '{name}'.")
                transpose_axis_order.append(name)
                final_shape.append(input_axes[name])
                processed_input_logical_axes.add(name)
        elif comp_type == 'group':
            group_names = comp['names']
            merged_size = 1
            transpose_src = []
            for group_name in group_names:
                if group_name not in input_axes:  # Repetition in group
                    repetition_target_names.add(group_name)
                    if group_name not in axes_lengths:
                        raise EinopsError(f"Length for new axis '{group_name}' in group {group_names} must be provided.")
                    source_1_name = _claim_unit_source()
                    if source_1_name is None:
                        raise EinopsError(f"No available '1' source for repeating '{group_name}' in group.")
                    target_size = axes_lengths[group_name]
                    if not isinstance(target_size, int) or target_size <= 0:
                        raise EinopsError(f"Length for '{group_name}' must be positive.")
                    input_axes[group_name] = target_size
                    transpose_src.append(source_1_name)
                    repeat_recipe.append({'source_name': source_1_name, 'repeats': target_size, 'target_name': group_name})
                    processed_input_logical_axes.add(source_1_name)
                    merged_size *= target_size
                else:  # Existing axis in group
                    if group_name.startswith('1__'):
                        raise EinopsError(f"Cannot reuse internal '{group_name}' in group.")
                    transpose_src.append(group_name)
                    processed_input_logical_axes.add(group_name)
                    merged_size *= input_axes[group_name]
            transpose_axis_order.extend(transpose_src)
            final_shape.append(merged_size)

    # --- Check Input Axes Used ---
    # Literal '1' axes are exempt: leaving one unused simply drops it.
    required_logicals = set(decomposition_axes_names)
    unused_named = {ax for ax in (required_logicals - processed_input_logical_axes)
                    if not ax.startswith('1__') and ax != '...'}
    if unused_named:
        raise EinopsError(f"Input axes {unused_named} were not used in output pattern '{pattern}'")

    # --- Check Extra axes_lengths ---
    provided_keys = set(axes_lengths.keys())
    axes_defined_by_shape = set(input_axes.keys()) - set(available_anon_sources) - repetition_target_names
    extraneous_keys = set()
    for k in provided_keys:
        if k not in axes_defined_by_shape and k not in repetition_target_names:
            is_split_axis = any(comp['type'] == 'group' and k in comp['names'] for comp in parsed_left)
            if not is_split_axis:
                extraneous_keys.add(k)
    if extraneous_keys:
        raise EinopsError(f"Provided `axes_lengths` for unused axes: {extraneous_keys}. Pattern: '{pattern}'")

    # --- Transposition ---
    permutation = []
    for logical_name in transpose_axis_order:
        if logical_name == '...':
            permutation.extend(source_indices_map.get('...', []))
        else:
            permutation.append(source_indices_map[logical_name])
    # Literal '1' axes that no repetition consumed are moved to the end so the
    # permutation covers every dimension; having length 1, they vanish in the
    # final reshape without affecting element order.
    for anon in available_anon_sources:
        if not anon_sources_used[anon]:
            permutation.append(source_indices_map[anon])
    try:
        transposed_tensor = decomposed_tensor.transpose(permutation)
    except ValueError as e:
        raise EinopsError(f"Transpose error: {e}.") from e

    # --- Repetition ---
    final_tensor_pre_reshape = transposed_tensor
    if repeat_recipe:
        broadcast_target_shape = list(transposed_tensor.shape)
        # Decomposed-axis index -> its position after the transpose.
        transposed_position = {src_idx: pos for pos, src_idx in enumerate(permutation)}

        for repeat_info in repeat_recipe:
            source_name, repeats = repeat_info['source_name'], repeat_info['repeats']
            t_idx = transposed_position.get(source_indices_map.get(source_name))
            if t_idx is None:
                raise EinopsError(f"Internal error: Cannot find index for repeat source '{source_name}'.")
            if broadcast_target_shape[t_idx] == 1:
                broadcast_target_shape[t_idx] = repeats
            elif broadcast_target_shape[t_idx] != repeats:
                raise EinopsError(f"Internal error during repeat prep for {source_name}")
        try:
            final_tensor_pre_reshape = np.broadcast_to(transposed_tensor, broadcast_target_shape)
        except ValueError as e:
            raise EinopsError(f"Broadcast error: {e}.") from e

    # --- Final Reshape ---
    if np.prod(final_tensor_pre_reshape.shape) != np.prod(final_shape):
        raise EinopsError(f"Internal count mismatch pre-reshape: {final_tensor_pre_reshape.shape} vs {final_shape}")

    try:
        result = final_tensor_pre_reshape.reshape(final_shape)
    except ValueError as e:
        raise EinopsError(f"Final reshape failed: {e}. From {final_tensor_pre_reshape.shape} to {final_shape}.") from e

    return result

In [None]:
# Demonstration of rearrange() on small arrays; every section prints its
# input (or input shape) followed by the result.
print("--- Example Usage ---")

# Transpose
print("Transpose:")
x_t = np.arange(3 * 4).reshape(3, 4)
print("Input:\n", x_t)
result_t = rearrange(x_t, 'h w -> w h')
print("Result (h w -> w h):\n", result_t)
print("-" * 10)

# Split an axis
print("Split Axis:")
x_s = np.arange(12 * 10).reshape(12, 10)
print("Input:\n", x_s[:2, :5])
try:
    result_s = rearrange(x_s, '(h w) c -> h w c', h=3)
    print("Result ((h w) c -> h w c, h=3):\n", result_s.shape)
    print(result_s[0, :2, :5])
except EinopsError as err:
    print("Error:", err)
try:
    result_s_infer = rearrange(x_s, '(h w) c -> h w c', h=4)
    print("Result ((h w) c -> h w c, h=4, w inferred):\n", result_s_infer.shape)
except EinopsError as err:
    print("Error:", err)
print("-" * 10)

# Merge axes
print("Merge Axes:")
x_m = np.arange(3 * 4 * 5).reshape(3, 4, 5)
print("Input shape:", x_m.shape)
result_m = rearrange(x_m, 'a b c -> (a b) c')
print("Result (a b c -> (a b) c) shape:", result_m.shape)
print("-" * 10)

# Repeat an axis
print("Repeat Axis:")
x_r = np.arange(3 * 1 * 5).reshape(3, 1, 5)
print("Input shape:", x_r.shape)
try:
    result_r = rearrange(x_r, 'a 1 c -> a b c', b=4)
    print("Result (a 1 c -> a b c, b=4) shape:", result_r.shape)
    print("Original slice [0, 0, :]:", x_r[0, 0, :])
    print("Repeated slice [0, 0, :]:", result_r[0, 0, :])
    print("Repeated slice [0, 1, :]:", result_r[0, 1, :])
    print("Repeated slice [0, 3, :]:", result_r[0, 3, :])
except EinopsError as err:
    print("Error:", err)
print("-" * 10)

# Ellipsis
print("Ellipsis:")
x_e = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
print("Input shape:", x_e.shape)
result_e1 = rearrange(x_e, 'b h w c -> b c h w')
print("Result (b h w c -> b c h w) shape:", result_e1.shape)
result_e2 = rearrange(x_e, 'b h w c -> b (h w) c')
print("Result (b h w c -> b (h w) c) shape:", result_e2.shape)
result_e3 = rearrange(x_e, '... h w -> ... (h w)')
print("Result (... h w -> ... (h w)) shape:", result_e3.shape)
x_e_prefix = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
print("Input shape (prefix ellipsis):", x_e_prefix.shape)
result_e4 = rearrange(x_e_prefix, '... h w c -> ... c h w')
print("Result (... h w c -> ... c h w) shape:", result_e4.shape)
x_e_suffix = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
print("Input shape (suffix ellipsis):", x_e_suffix.shape)
result_e5 = rearrange(x_e_suffix, 'b h w ... -> b w h ...')
print("Result (b h w ... -> b w h ...) shape:", result_e5.shape)
print("-" * 10)

# Complex Example
print("Complex:")
x_c = np.arange(2 * 6 * 10).reshape(2, 6, 10)
print("Input shape:", x_c.shape)
try:
    result_c = rearrange(x_c, 'b (h ph) (w pw) -> b h w (ph pw)', ph=2, pw=5)
    print("Result (b (h ph) (w pw) -> b h w (ph pw), ph=2, pw=5) shape:", result_c.shape)
except EinopsError as err:
    print("Error:", err)
print("-" * 10)

# Repetition in merge
print("Repetition in Merge:")
x_r_m = np.arange(3 * 1 * 5).reshape(3, 1, 5)
print("Input shape:", x_r_m.shape)
try:
    result_r_m = rearrange(x_r_m, 'a 1 c -> a (b c)', b=4)
    print("Result (a 1 c -> a (b c), b=4) shape:", result_r_m.shape)
except EinopsError as err:
    print("Error:", err)
print("-" * 10)

# Multiple Repetitions
print("Multiple Repetitions:")
x_r_m2 = np.ones((10, 1, 20, 1))
print("Input shape:", x_r_m2.shape)
try:
    result_r_m2 = rearrange(x_r_m2, 'a 1 b 1 -> a r1 b r2', r1=5, r2=3)
    print("Result (a 1 b 1 -> a r1 b r2, r1=5, r2=3) shape:", result_r_m2.shape)
except EinopsError as err:
    print("Error:", err)
print("-" * 10)

--- Example Usage ---
Transpose:
Input:
 [[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
Result (h w -> w h):
 [[ 0  4  8]
 [ 1  5  9]
 [ 2  6 10]
 [ 3  7 11]]
----------
Split Axis:
Input:
 [[ 0  1  2  3  4]
 [10 11 12 13 14]]
Result ((h w) c -> h w c, h=3):
 (3, 4, 10)
[[ 0  1  2  3  4]
 [10 11 12 13 14]]
Result ((h w) c -> h w c, h=4, w inferred):
 (4, 3, 10)
----------
Merge Axes:
Input shape: (3, 4, 5)
Result (a b c -> (a b) c) shape: (12, 5)
----------
Repeat Axis:
Input shape: (3, 1, 5)
Result (a 1 c -> a b c, b=4) shape: (3, 4, 5)
Original slice [0, 0, :]: [0 1 2 3 4]
Repeated slice [0, 0, :]: [0 1 2 3 4]
Repeated slice [0, 1, :]: [0 1 2 3 4]
Repeated slice [0, 3, :]: [0 1 2 3 4]
----------
Ellipsis:
Input shape: (2, 3, 4, 5)
Result (b h w c -> b c h w) shape: (2, 5, 3, 4)
Result (b h w c -> b (h w) c) shape: (2, 12, 5)
Result (... h w -> ... (h w)) shape: (2, 3, 20)
Input shape (prefix ellipsis): (2, 3, 4, 5, 6)
Result (... h w c -> ... c h w) shape: (2, 3, 6, 4, 5)
Input shape 

In [None]:
import unittest
import numpy.testing as npt

In [None]:
class TestRearrangeBasic(unittest.TestCase):
    """Checks transpose, merge and split patterns against plain NumPy equivalents."""

    def test_identity(self):
        """A pattern with identical sides returns the data unchanged."""
        source = np.arange(2 * 3 * 4).reshape(2, 3, 4)
        actual = rearrange(source, 'a b c -> a b c')
        npt.assert_array_equal(actual, source)

    def test_transpose_2d(self):
        """Swapping two axes matches ndarray.T."""
        source = np.arange(3 * 4).reshape(3, 4)
        actual = rearrange(source, 'h w -> w h')
        npt.assert_array_equal(actual, source.T)

    def test_transpose_3d(self):
        """b h w -> b w h matches np.transpose with axes (0, 2, 1)."""
        source = np.arange(2 * 3 * 4).reshape(2, 3, 4)
        want = np.transpose(source, (0, 2, 1))
        actual = rearrange(source, 'b h w -> b w h')
        npt.assert_array_equal(actual, want)

    def test_transpose_4d(self):
        """b h w c -> b c h w matches np.transpose with axes (0, 3, 1, 2)."""
        source = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
        want = np.transpose(source, (0, 3, 1, 2))
        actual = rearrange(source, 'b h w c -> b c h w')
        npt.assert_array_equal(actual, want)

    def test_merge_2_axes(self):
        """Merging the two leading axes matches a plain reshape."""
        source = np.arange(3 * 4 * 5).reshape(3, 4, 5)
        want = source.reshape(12, 5)
        actual = rearrange(source, 'a b c -> (a b) c')
        npt.assert_array_equal(actual, want)

    def test_merge_3_axes(self):
        """Merging two inner axes, then three leading axes, matches reshape."""
        source = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

        inner_want = source.reshape(2, 12, 5)
        inner_actual = rearrange(source, 'a b c d -> a (b c) d')
        npt.assert_array_equal(inner_actual, inner_want)

        leading_want = source.reshape(2 * 3 * 4, 5)
        leading_actual = rearrange(source, 'a b c d -> (a b c) d')
        npt.assert_array_equal(leading_actual, leading_want)

    def test_split_1_axis(self):
        """Splitting one axis gives the same result whichever length is supplied."""
        source = np.arange(12 * 5).reshape(12, 5)
        want = source.reshape(3, 4, 5)

        given_h = rearrange(source, '(h w) c -> h w c', h=3)
        npt.assert_array_equal(given_h, want)

        given_w = rearrange(source, '(h w) c -> h w c', w=4)  # h is inferred
        npt.assert_array_equal(given_w, want)

    def test_split_2_axes(self):
        """Two groups are split independently, each with one inferred length."""
        source = np.arange(12 * 20).reshape(12, 20)
        want = source.reshape(3, 4, 5, 4)
        actual = rearrange(source, '(h w1) (w2 c) -> h w1 w2 c', h=3, c=4)
        npt.assert_array_equal(actual, want)

    def test_split_merge_combined(self):
        """Split (h w), then merge c and w in swapped order."""
        source = np.arange(2 * 12 * 10).reshape(2, 12, 10)  # b (h w) c
        # b h w c -> b h c w -> b h (c w)
        want = source.reshape(2, 3, 4, 10).transpose(0, 1, 3, 2).reshape(2, 3, 40)
        actual = rearrange(source, 'b (h w) c -> b h (c w)', h=3)
        npt.assert_array_equal(actual, want)

In [None]:
class TestRearrangeAdvanced(unittest.TestCase):
    """Checks repetition of literal '1' axes, ellipsis handling and combined patterns."""

    def test_repetition(self):
        """A literal '1' axis renamed on the output is broadcast to the given length."""
        source = np.arange(3 * 1 * 5).reshape(3, 1, 5)
        want = np.broadcast_to(source, (3, 4, 5))
        actual = rearrange(source, 'a 1 c -> a b c', b=4)
        npt.assert_array_equal(actual, want)

    def test_repetition_in_merge(self):
        """A repeated axis may appear inside an output group."""
        source = np.arange(3 * 1 * 5).reshape(3, 1, 5)  # a 1 c
        # Target: a (b c), where b is the repeated axis
        want = np.broadcast_to(source, (3, 4, 5)).reshape(3, 4 * 5)
        actual = rearrange(source, 'a 1 c -> a (b c)', b=4)
        npt.assert_array_equal(actual, want)

    def test_repetition_multiple_1s(self):
        """Two literal '1' axes are repeated independently."""
        source = np.ones((10, 1, 20, 1))  # a 1 b 1
        actual = rearrange(source, 'a 1 b 1 -> a r1 b r2', r1=5, r2=3)
        self.assertEqual(actual.shape, (10, 5, 20, 3))
        # Every (r1, r2) slice must equal the single original slice.
        original_slice = source[:, 0, :, 0]
        npt.assert_array_equal(actual[:, 0, :, 0], original_slice)
        npt.assert_array_equal(actual[:, 1, :, 1], original_slice)
        npt.assert_array_equal(actual[:, 4, :, 2], original_slice)

    def test_ellipsis_basic_transpose(self):
        """Leading ellipsis keeps the batch dims while the last two are swapped."""
        source = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
        want = source.transpose(0, 1, 3, 2)  # ... h w -> ... w h
        actual = rearrange(source, '... h w -> ... w h')
        npt.assert_array_equal(actual, want)

    def test_ellipsis_basic_merge(self):
        """Leading ellipsis combined with merging the last two axes."""
        source = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
        want = source.reshape(2, 3, 20)  # ... h w -> ... (h w)
        actual = rearrange(source, '... h w -> ... (h w)')
        npt.assert_array_equal(actual, want)

    def test_ellipsis_basic_split(self):
        """Leading ellipsis combined with splitting the last axis."""
        source = np.arange(2 * 3 * 12).reshape(2, 3, 12)  # ... (h w)
        want = source.reshape(2, 3, 4, 3)  # ... h w
        actual = rearrange(source, '... (h w) -> ... h w', w=3)
        npt.assert_array_equal(actual, want)

    def test_ellipsis_middle(self):
        """An ellipsis between named axes moves as one unit."""
        source = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)  # a ... d e
        want = source.transpose(0, 4, 1, 2, 3)  # a e ... d
        actual = rearrange(source, 'a ... d e -> a e ... d')
        npt.assert_array_equal(actual, want)

    def test_ellipsis_split_merge(self):
        """Split (h w), then merge d with w (the pattern names every axis explicitly)."""
        source = np.arange(2 * 3 * 12 * 5).reshape(2, 3, 12, 5)  # b c (h w) d
        # b c h w d -> b c h d w -> b c h (d w)
        want = source.reshape(2, 3, 4, 3, 5).transpose(0, 1, 2, 4, 3).reshape(2, 3, 4, 15)
        actual = rearrange(source, 'b c (h w) d -> b c h (d w)', w=3)
        npt.assert_array_equal(actual, want)

    def test_complex_split_merge_transpose(self):
        """Patch extraction: split both spatial axes, then flatten patch and channel."""
        batch, channels = 2, 3
        ph, pw = 4, 4
        grid_h = grid_w = 12 // 4  # 3 patches per side
        source = np.arange(2 * 12 * 12 * 3).reshape(2, 12, 12, 3)  # B (H pH) (W pW) C
        # B H pH W pW C -> B H W pH pW C -> B H W (pH pW C)
        want = source.reshape(batch, grid_h, ph, grid_w, pw, channels)
        want = want.transpose(0, 1, 3, 2, 4, 5)
        want = want.reshape(batch, grid_h, grid_w, ph * pw * channels)
        actual = rearrange(source, 'b (h ph) (w pw) c -> b h w (ph pw c)', ph=ph, pw=pw)
        npt.assert_array_equal(actual, want)

In [None]:
# Cell 8: Unit Tests - Error Handling
class TestRearrangeErrors(unittest.TestCase):
    """Checks that invalid calls raise EinopsError with the expected message."""

    def test_dimension_mismatch_rank(self):
        """The pattern's number of input axes must equal the tensor rank."""
        matrix = np.zeros((2, 3))
        with self.assertRaisesRegex(EinopsError, r"Pattern.*expects 3 dimensions.*input.*'a b c'.*tensor shape \(2, 3\) has 2"):
            rearrange(matrix, 'a b c -> a b c')
        with self.assertRaisesRegex(EinopsError, r"Pattern.*expects 1 dimensions.*input.*'a'.*tensor shape \(2, 3\) has 2"):
            rearrange(matrix, 'a -> a')

        cube = np.zeros((2, 3, 4))
        with self.assertRaisesRegex(EinopsError, r"Pattern.*expects 2 dimensions.*input.*'a b'.*tensor shape \(2, 3, 4\) has 3"):
            rearrange(cube, 'a b -> b a')

    def test_dimension_mismatch_size_split(self):
        """Group lengths must multiply to, or evenly divide, the dimension."""
        divisible = np.zeros((12, 10))
        # All lengths given but their product is wrong.
        with self.assertRaisesRegex(EinopsError, r"Dimension size mismatch for group \['h', 'w'\]: product 15 != tensor dim 12"):
            rearrange(divisible, '(h w) c -> h w c', h=5, w=3)

        indivisible = np.zeros((13, 10))
        # One length missing and the dimension is not divisible by the known one.
        with self.assertRaisesRegex(EinopsError, r"Cannot infer size for axis 'w' in group \['h', 'w'\]: dimension size 13 is not divisible by the product of known axes \(3\)"):
            rearrange(indivisible, '(h w) c -> h w c', h=3)

    def test_dimension_mismatch_size_axis(self):
        """A length supplied for a plain axis must agree with the shape."""
        matrix = np.zeros((3, 4))
        with self.assertRaisesRegex(EinopsError, r"Provided length for axis 'h' \(5\).*conflicts with tensor shape \(3\)"):
            rearrange(matrix, 'h w -> w h', h=5)

    def test_missing_axes_lengths_split(self):
        """A group with two unknown lengths cannot be inferred."""
        matrix = np.zeros((12, 10))
        with self.assertRaisesRegex(EinopsError, r"Cannot infer sizes for \['h', 'w'\] in group \['h', 'w'\]"):
            rearrange(matrix, '(h w) c -> h w c')

    def test_missing_axes_lengths_repeat(self):
        """A repeated axis requires its length."""
        cube = np.zeros((3, 1, 5))
        with self.assertRaisesRegex(EinopsError, r"Length for new axis 'b'.*must be provided"):
            rearrange(cube, 'a 1 c -> a b c')

    def test_extra_axes_lengths(self):
        """Unknown names are rejected, conflicting lengths too; matching ones pass."""
        matrix = np.zeros((3, 4))
        with self.assertRaisesRegex(EinopsError, r"Provided `axes_lengths` for unused axes: {'d'}"):
            rearrange(matrix, 'h w -> w h', d=5)
        with self.assertRaisesRegex(EinopsError, r"Provided length for axis 'h' \(5\).*conflicts with tensor shape \(3\)"):
            rearrange(matrix, 'h w -> w h', h=5)
        rearrange(matrix, 'h w -> w h', h=3, w=4)  # Lengths agree with the shape: OK

    def test_incorrect_repetition_source_size(self):
        """The dimension matched by a literal '1' must have length 1."""
        cube = np.zeros((3, 2, 5))
        with self.assertRaisesRegex(EinopsError, r"Axis for '1' must have size 1, got 2"):
            rearrange(cube, 'a 1 c -> a b c', b=4)

    def test_input_axis_not_used(self):
        """Every named input axis must appear on the output side."""
        cube = np.zeros((2, 3, 4))
        with self.assertRaisesRegex(EinopsError, r"Input axes {'c'} were not used"):
            rearrange(cube, 'a b c -> a b')

    def test_output_axis_not_in_input(self):
        """An unknown output axis is treated as a repetition that lacks a length."""
        cube = np.zeros((2, 3, 4))
        with self.assertRaisesRegex(EinopsError, r"Length for new axis 'd'.*must be provided"):
            rearrange(cube, 'a b c -> a b d')

In [None]:
def run_tests():
    """Collect the three rearrange test cases into one suite and run it verbosely."""
    suite = unittest.TestSuite()
    loader = unittest.TestLoader()
    for test_case in (TestRearrangeBasic, TestRearrangeAdvanced, TestRearrangeErrors):
        suite.addTests(loader.loadTestsFromTestCase(test_case))
    unittest.TextTestRunner(verbosity=2).run(suite)
run_tests()

test_identity (__main__.TestRearrangeBasic.test_identity) ... ok
test_merge_2_axes (__main__.TestRearrangeBasic.test_merge_2_axes) ... ok
test_merge_3_axes (__main__.TestRearrangeBasic.test_merge_3_axes) ... ok
test_split_1_axis (__main__.TestRearrangeBasic.test_split_1_axis) ... ok
test_split_2_axes (__main__.TestRearrangeBasic.test_split_2_axes) ... ok
test_split_merge_combined (__main__.TestRearrangeBasic.test_split_merge_combined) ... ok
test_transpose_2d (__main__.TestRearrangeBasic.test_transpose_2d) ... ok
test_transpose_3d (__main__.TestRearrangeBasic.test_transpose_3d) ... ok
test_transpose_4d (__main__.TestRearrangeBasic.test_transpose_4d) ... ok
test_complex_split_merge_transpose (__main__.TestRearrangeAdvanced.test_complex_split_merge_transpose) ... ok
test_ellipsis_basic_merge (__main__.TestRearrangeAdvanced.test_ellipsis_basic_merge) ... ok
test_ellipsis_basic_split (__main__.TestRearrangeAdvanced.test_ellipsis_basic_split) ... ok
test_ellipsis_basic_transpose (__main__.T