# Implement rearrange function from einops library from scratch


### Setup

Loading Libraries

In [None]:
import re
import numpy as np
from math import prod
import json
import torch

Some initial utility functions

In [None]:
def _tokenize(pattern_str):
    """
    Tokenize a pattern string into a list of items:
    - '...' (ellipsis)
    - '(...)' (group)
    - '\w+' (single axis, e.g. 'b', 'h', 'h1')
    """

    tokens = re.findall(r'\.\.\.|\([\w\s\-]+\)|-?\w+', pattern_str)
    return tokens

def to_numpy_array(input_data):
    """
    Converts the input data to a NumPy array.

    Supports:
        - NumPy array (returned as-is, not copied)
        - list (converted with np.array)
        - PyTorch tensor (detached, moved to CPU, then converted)

    Parameters:
        input_data: The data to convert.

    Returns:
        numpy.ndarray: The data as a NumPy array.

    Raises:
        TypeError: If input data is of an unsupported type.
    """
    if isinstance(input_data, np.ndarray):
        return input_data
    if isinstance(input_data, list):
        return np.array(input_data)
    if isinstance(input_data, torch.Tensor):
        # Tensor.numpy() refuses tensors that require grad or are stored on a
        # non-CPU device, so detach and move to CPU before converting.
        return input_data.detach().cpu().numpy()
    raise TypeError(f"Unsupported input type: {type(input_data)}. Expected list, NumPy array, or PyTorch tensor.")

def clean_singletons_in_parentheses(pattern):
    """
    Cleans up singletons in parentheses in einops rearrange patterns.

    Every '1' inside a parenthesised group is dropped, because a factor of 1
    does not change the size of the group. What remains decides the result:
      - two or more names: the group is kept, e.g. '(h 1 w)' -> '(h w)'
      - exactly one name:  the parentheses are removed, '(c 1 1)' -> 'c'
      - no name at all:    the group is a unit axis, '(1 1)' -> '1'

    Parameters:
        pattern (str): The rearrange pattern to clean.

    Returns:
        str: The cleaned pattern with runs of whitespace collapsed to a
        single space and no leading/trailing whitespace.

    For example,
    Input: b (c 1 1) h w -> b c h w
    Output:  b c h w -> b c h w
    """
    # Matches one non-nested parenthesised group and captures its content.
    paren_regex = re.compile(r'\(([^()]+)\)')

    def replace_function(match):
        kept = [dim for dim in match.group(1).split() if dim != '1']
        if len(kept) > 1:
            return '(' + ' '.join(kept) + ')'
        # A group made only of 1s still occupies one axis of length 1, so it
        # must not vanish from the pattern.
        bare = kept[0] if kept else '1'
        # Pad with spaces so the unwrapped token cannot fuse with a neighbour
        # that was written directly against the parenthesis; the extra
        # whitespace is collapsed below.
        return f' {bare} '

    cleaned_pattern = paren_regex.sub(replace_function, pattern)

    # Remove redundant spaces and ensure clean formatting
    return re.sub(r'\s+', ' ', cleaned_pattern).strip()

def unexpected_chars_checker(s):
    """
    Check for unexpected characters in the pattern.

    The pattern is split into whitespace-separated tokens after '->' and
    parentheses have been turned into separators. Each token must be a run
    of letters, a run of letters followed by a number without a leading zero
    (h1, h11), the ellipsis '...', or the literal '1'.

    Parameters:
        s (str): The full rearrange pattern.

    Returns:
        None: The function returns nothing when every token is valid.

    Raises:
        ValueError: If unexpected characters are found in the pattern.
    """
    # Matches h1, h11, a, b, ..., or 1
    valid_pattern = re.compile(r'^([a-zA-Z]+[1-9][0-9]*|[a-zA-Z]+|\.\.\.|1)$')

    # Replace the separators with spaces instead of deleting them: deleting
    # would glue neighbouring tokens together ('b->9' would be checked as the
    # valid-looking 'b9') and let invalid tokens through.
    separators_to_space = str.maketrans('()', '  ')
    tokens = s.replace('->', ' ').translate(separators_to_space).split()

    unexpected_chars = [token for token in tokens if not valid_pattern.match(token)]

    if unexpected_chars:
        raise ValueError(f"Unexpected characters found in pattern: {unexpected_chars}")

### Pattern Parsing

In [None]:
class Validator:
    """
    Validates a rearrange pattern against an array and builds the axis
    mappings consumed by the later transformation steps.

    Construction converts the input to a NumPy array, rejects empty arrays
    and malformed tokens, removes '1' entries inside parentheses and
    tokenizes both sides of the pattern. `validate_and_return` then runs the
    remaining checks and produces the mappings.
    """

    def __init__(self, array, pattern, **kwargs):
        """
        Parameters:
            array: list, NumPy array or PyTorch tensor to rearrange.
            pattern (str): Pattern of the form 'input -> output'.
            kwargs: Axis sizes supplied by the caller; stored on the
                instance but not otherwise used by this class.

        Raises:
            TypeError: If `array` has an unsupported type.
            ValueError: If the array is empty or the pattern is malformed.
        """
        self.array = to_numpy_array(array)
        self._is_empty_array()
        unexpected_chars_checker(pattern)
        self.pattern = clean_singletons_in_parentheses(pattern)
        self.kwargs = kwargs
        self.array_shape = self.array.shape

        self.input_str, self.output_str = self._parse_pattern()
        self.input_tokens = _tokenize(self.input_str)
        self.output_tokens = _tokenize(self.output_str)

    def _is_empty_array(self):
        """
        Checks if the stored NumPy array is empty.

        Raises:
            ValueError: If the array has no elements.
        """
        if self.array.size == 0:
            raise ValueError("The input NumPy array is empty. Please provide a valid array.")

    def _parse_pattern(self):
        """
        Parses the pattern and splits it into input and output parts.

        Returns:
            tuple: The stripped input and output parts of the pattern.

        Raises:
            ValueError: If the pattern does not contain exactly one '->'.
        """
        try:
            input_str, output_str = self.pattern.split('->')
        except ValueError:
            raise ValueError("Pattern must be in the form 'input -> output'") from None
        return input_str.strip(), output_str.strip()

    def stripped_order(self, d):
        """
        Removes all parentheses from a list of tokens and returns the flat
        list of names.

        ["a", "(b c)"] -> ["a", "b", "c"]
        """
        # split() without an argument also copes with an empty token list
        # (it yields [] rather than ['']).
        return " ".join(d).replace("(", "").replace(")", "").split()

    def ellipsis_checker(self):
        """
        Checks that neither side of the pattern has more than one ellipsis.

        Raises:
            ValueError: If the input or the output contains several '...'.
        """
        if self.input_tokens.count('...') > 1 or self.output_tokens.count('...') > 1:
            raise ValueError("Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor")

    def identifier_match_checker(self):
        """
        Checks for discrepancies between identifiers in the input and output
        pattern. The literal '1' is ignored because it may appear any number
        of times on either side.

        Raises:
            ValueError: If an identifier is duplicated on one side, or
                appears on only one side of the pattern.
        """
        input_names = [name for name in self.stripped_order(self.input_tokens) if name != '1']
        output_names = [name for name in self.stripped_order(self.output_tokens) if name != '1']

        if len(set(input_names)) != len(input_names):
            raise ValueError(f"Input pattern {self.input_str} contains duplicate dimension")

        if len(set(output_names)) != len(output_names):
            raise ValueError(f"Output pattern {self.output_str} contains duplicate dimension")

        missing_in_output = set(input_names) - set(output_names)
        extra_in_output = set(output_names) - set(input_names)

        if missing_in_output:
            raise ValueError(f"Identifiers only on one side of expression (should be on both): {missing_in_output}")
        if extra_in_output:
            raise ValueError(f"Identifiers only on one side of expression (should be on both): {extra_in_output}")

    def input_token_mapper(self):
        """
        Map every input token to the array axis (or axes) it names.

        Tokens are consumed left to right. A plain token or a group takes one
        axis, a literal '1' takes one axis that must have length 1 and is
        stored as 'singleton_<n>', and the ellipsis takes however many axes
        are left over, at the position where it appears in the pattern.
        Entries are inserted in pattern order.

        Input -
            array = np.random.randn(2, 3, 4)
            input_tokens = ["a", "b", "c"]
        Output -
            input_tokens_mapping = {
                'a': 0,
                'b': 1,
                'c': 2
            }
            input_tokens_shape_mapping = {
                'a': 2,
                'b': 3,
                'c': 4
            }

        For the ellipsis the first mapping holds the list of axes and the
        second holds the list of their sizes.

        Raises:
            ValueError: If the token count does not fit the array rank, or a
                '1' token names an axis whose length is not 1.
        """
        ellipsis_count = self.input_tokens.count('...')
        non_ellipsis_tokens = [tok for tok in self.input_tokens if tok != '...']
        ndim = len(self.array_shape)

        if ellipsis_count > 1:
            raise ValueError("Pattern can have at most one ellipsis ('...').")

        # Check if token count matches array dimensions (unless ellipsis exists)
        if ellipsis_count == 0 and len(self.input_tokens) != ndim:
            raise ValueError(
                f"Number of input tokens ({len(self.input_tokens)}) must match the array dimensions ({ndim}) unless using ellipsis ('...')."
            )
        if ellipsis_count == 1 and len(non_ellipsis_tokens) > ndim:
            raise ValueError(
                f"Pattern with ellipsis ('...') must not have more explicit tokens ({len(non_ellipsis_tokens)}) than array dimensions ({ndim})."
            )

        # Number of axes absorbed by the ellipsis (0 when there is none).
        ellipsis_dims = ndim - len(non_ellipsis_tokens)

        input_tokens_mapping = {}
        input_tokens_shape_mapping = {}
        singleton_count = 0
        next_axis = 0

        for token in self.input_tokens:
            if token == '...':
                # The ellipsis covers the axes at its own position, not
                # always the leading ones.
                axes = list(range(next_axis, next_axis + ellipsis_dims))
                input_tokens_mapping[token] = axes
                input_tokens_shape_mapping[token] = [self.array_shape[axis] for axis in axes]
                next_axis += ellipsis_dims
                continue

            size = self.array_shape[next_axis]
            if token == '1':
                # Validate that the corresponding dimension is 1
                if size != 1:
                    raise ValueError(
                        f"Dimension for token '1' must be 1, but got {size} at index {next_axis}."
                    )
                singleton_count += 1
                key = "singleton_" + str(singleton_count)
            else:
                key = token

            input_tokens_mapping[key] = next_axis
            input_tokens_shape_mapping[key] = size
            next_axis += 1

        self.input_tokens_mapping = input_tokens_mapping
        self.input_tokens_shape_mapping = input_tokens_shape_mapping

    def output_tokens_mapper(self):
        """
        Create the mapping from each output token to its position in the
        output part of the pattern. Literal '1' tokens are stored as
        'singleton_<n>', numbered in order of appearance.

        The output side is re-read from `self.pattern` because
        `empty_ellipsis_checker` may have rewritten the pattern.

        Input -
            output_tokens = ["a", "b", "c"]
        Output -
            output_tokens_mapping = {
                'a': 0,
                'b': 1,
                'c': 2
            }
        """
        _, output_str = self.pattern.split('->')
        self.output_str = output_str.strip()
        self.output_tokens = _tokenize(output_str)

        output_tokens_mapping = {}
        singleton_count = 0
        for position, token in enumerate(self.output_tokens):
            if token == '1':
                # Fixed dimension (explicitly adds dimension of size 1)
                singleton_count += 1
                token = "singleton_" + str(singleton_count)
            output_tokens_mapping[token] = position

        self.output_tokens_mapping = output_tokens_mapping

    def empty_ellipsis_checker(self):
        """
        An ellipsis in the input may cover no axes at all. In that case it is
        removed from both input mappings and from `self.pattern`, so the
        later steps never see it. `self.input_tokens` and `self.input_str`
        are left as they were.
        """
        if '...' in self.input_tokens_mapping and len(self.input_tokens_mapping['...']) == 0:
            del self.input_tokens_mapping['...']
            del self.input_tokens_shape_mapping['...']
            self.pattern = self.pattern.replace('...', '').strip()

    def validate_and_return(self):
        """
        Validates the pattern, tokens, and array shape, then returns mappings.

        Returns:
            tuple: (array, input_tokens_mapping, input_tokens_shape_mapping,
            output_tokens_mapping).

        Raises:
            ValueError: If any of the checks fails.
        """
        self.ellipsis_checker()
        self.identifier_match_checker()
        self.input_token_mapper()
        self.empty_ellipsis_checker()
        self.output_tokens_mapper()

        return (
            self.array,
            self.input_tokens_mapping,
            self.input_tokens_shape_mapping,
            self.output_tokens_mapping,
        )

Utility functions used after the mappings have been created

In [None]:
def tokens_from_paranthesis(input_set):
    """
    Collect the identifiers that appear inside parentheses.

    Only items containing both '(' and ')' are considered; for each of them
    the text between every pair of parentheses is split on whitespace.

    Parameters:
        input_set: Iterable of token strings, e.g. {'b', '(h w)'}.

    Returns:
        set: The names found inside parentheses, e.g. {'h', 'w'}.
    """
    grouped_items = (entry for entry in input_set if '(' in entry and ')' in entry)
    return {
        name
        for entry in grouped_items
        for inner_text in re.findall(r'\((.*?)\)', entry)
        for name in inner_text.split()
    }

def check_extra_arguments(input_mapping, **kwargs):
    """
    Reject keyword arguments that do not name an axis inside a parenthesised
    group of the input pattern.

    Parameters:
        input_mapping (dict): Input token mapping; only its keys are read.
        kwargs: Axis sizes supplied by the caller.

    Raises:
        ValueError: If a keyword is not one of the names found inside the
            parentheses of the input tokens.
    """
    allowed_tokens = tokens_from_paranthesis(set(input_mapping))
    extra_tokens = set(kwargs) - allowed_tokens

    if not extra_tokens:
        return

    raise ValueError(f"Extra arguments provided: {extra_tokens}. "
                     f"Allowed arguments are: {list(allowed_tokens)}.")

In [None]:
def get_additional_args(input_mapping, shape_mapping, **kwargs):
    """
    Resolves the size of every axis named inside a parenthesised input
    token, checking the supplied sizes against the array and inferring the
    one size that may be left out per group.

    For each key of `input_mapping` that contains parentheses:
      - all names supplied: their product must equal the grouped axis size;
      - exactly one name missing: it is inferred as size // product, which
        must divide evenly;
      - more than one name missing: the sizes cannot be determined.

    Parameters:
        input_mapping (dict): Input token mapping; only its keys are read.
        shape_mapping (dict): Maps each input token to the size of its axis.
        kwargs: Sizes for names inside parentheses.

    Returns:
        dict: All arguments (the supplied ones plus the inferred ones).

    Raises:
        ValueError: If a supplied size is not positive, the sizes do not
            match the grouped axis, or the missing sizes cannot be inferred.
    """
    inferred_args = {}

    for token in input_mapping:
        if not ('(' in token and ')' in token):
            continue

        inner_tokens = token.strip('()').split()
        expected_shape = shape_mapping[token]

        provided_names = [name for name in inner_tokens if name in kwargs]
        missing_names = [name for name in inner_tokens if name not in kwargs]
        provided_args = [kwargs[name] for name in provided_names]

        # A zero size would make the divisibility test below divide by zero.
        for name, size in zip(provided_names, provided_args):
            if size <= 0:
                raise ValueError(f"Size of axis '{name}' must be positive, got {size}.")

        product = prod(provided_args)

        if len(missing_names) > 1:
            # One equation, several unknowns: fail here instead of guessing
            # a value for the first missing name.
            raise ValueError(f"Could not infer sizes for {set(missing_names)}.")

        if len(missing_names) == 1:
            if expected_shape % product != 0:
                raise ValueError(f"Could not infer sizes for {set(missing_names)}.")
            inferred_args[missing_names[0]] = expected_shape // product
        elif product != expected_shape:
            raise ValueError(f"Product of arguments {inner_tokens} ({provided_args}) does not match "
                             f"the expected shape {expected_shape} for token {token}.")

    # Combine original and inferred arguments
    return {**kwargs, **inferred_args}

### Input Pattern based Transformations

In [None]:
def input_based_transformation(array, input_mapping, input_shape_mapping, **kwargs):
    """
    Reshapes the input array so that every axis named in a parenthesised
    input token becomes an axis of its own.

    The new shape is assembled by walking `input_mapping` in insertion
    order: the ellipsis contributes the sizes of the axes it covers, a group
    contributes the sizes of its components from `kwargs`, and any other
    token contributes the size of its axis.

    Parameters:
    - array: np.ndarray, the input array to transform.
    - input_mapping: dict, maps each input token to its axis index (a list
      of indices for '...').
    - input_shape_mapping: dict, maps each input token to its size; accepted
      for interface compatibility and not read by this function.
    - kwargs: sizes of the dimensions inside parentheses, e.g., h1, h, w1, w.

    Returns:
    - np.ndarray: the reshaped array.

    Raises:
    - ValueError: if the size of a grouped dimension is missing from kwargs.
    """
    original_shape = array.shape
    new_shape = []

    for key, index in input_mapping.items():
        if key == '...':  # Handle ellipsis
            new_shape.extend(original_shape[axis] for axis in index)
        elif '(' in key and ')' in key:  # Handle parentheses
            components = key.strip('()').split()

            if not all(dim in kwargs for dim in components):
                raise ValueError(f"Missing dimensions {components} for expanding '{key}'.")

            # Replace the single dimension with the expanded dimensions
            new_shape.extend(kwargs[dim] for dim in components)
        else:
            new_shape.append(original_shape[index])

    # Pass the shape as one tuple: reshape(*new_shape) fails when new_shape
    # is empty, because reshape needs at least one argument.
    return array.reshape(tuple(new_shape))

Update token mapping based on input transformations

In [None]:
def update_input_tokens_mapping(input_tokens_mapping, **kwargs):
    """
    Rebuilds the token-to-axis mapping for the array whose grouped axes have
    been expanded into separate axes.

    Tokens are visited in insertion order and numbered consecutively: a
    group yields one entry per component, the ellipsis keeps one list entry
    covering as many axes as before, and every other token takes one axis.

    Parameters:
    - input_tokens_mapping: dict, maps tokens to indices in the array (a
      list of indices for '...').
    - kwargs: dict, sizes of the dimensions inside parentheses; only the
      presence of each name is checked.

    Returns:
    - dict: Expanded token mapping with consecutive indices for all tokens,
      including the ellipsis.

    Raises:
    - ValueError: if a component of a group is missing from kwargs.
    """
    expanded_tokens_mapping = {}
    current_index = 0  # Next free axis of the expanded array

    for token, index in input_tokens_mapping.items():
        if token == '...':  # Handle ellipsis
            # Renumber from the running index instead of copying the old
            # indices: a group expanded earlier shifts every later axis.
            width = len(index)
            expanded_tokens_mapping[token] = list(range(current_index, current_index + width))
            current_index += width
        elif '(' in token and ')' in token:  # Handle grouped tokens like '(h w)'
            components = token.strip('()').split()

            if not all(dim in kwargs for dim in components):
                raise ValueError(f"Missing dimensions {components} for expanding '{token}'.")

            for dim in components:
                expanded_tokens_mapping[dim] = current_index
                current_index += 1
        else:
            expanded_tokens_mapping[token] = current_index
            current_index += 1

    return expanded_tokens_mapping

### Output Pattern-based Transformations

A class that performs all output-pattern-related transformations

In [None]:
from math import prod

class Output_Transformations:
    """
    Applies the output side of a pattern to an already expanded array:
    drops unit axes the output does not mention, appends unit axes the
    output introduces, permutes the axes and finally merges grouped axes.

    Parameters:
        array: np.ndarray with one axis per entry of `token_mapping`.
        token_mapping (dict): Maps each input name to its axis index; the
            ellipsis '...' maps to a list of axis indices. Unit axes are
            named 'singleton_<n>'. The dict is modified in place.
        output_mapping (dict): Output tokens in output order; only the key
            order is used. Keys may be groups such as '(a b)'.
    """

    def __init__(self, array, token_mapping, output_mapping):
        self.array = array
        self.token_mapping = token_mapping
        self.output_mapping = output_mapping

        self.output_order = list(self.output_mapping.keys())
        self.input_order = list(self.token_mapping.keys())

        singleton_values = [
            int(key.split("_")[1])  # numeric suffix of 'singleton_<n>'
            for key in self.token_mapping
            if key.startswith("singleton_")
        ]
        self.last_singleton_value = max(singleton_values, default=0)

        # Flatten first: an ellipsis maps to a list (possibly empty), which
        # cannot be compared with the integer indices of the other tokens.
        used_axes = []
        for val in self.token_mapping.values():
            if isinstance(val, list):
                used_axes.extend(val)
            else:
                used_axes.append(val)
        self.last_index_val = max(used_axes, default=-1)

    def remove_singleton(self):
        """
        If singleton is in input mapping, but is not needed in output mapping,
        then remove it from input mapping and array. The remaining axes are
        renumbered to close the gaps.
        """
        output_names = set(self.stripped_order(self.output_order))

        remove_arr = []
        for token in self.input_order:
            if token.startswith("singleton_") and token not in output_names:
                remove_arr.append(self.token_mapping.pop(token))

        if not remove_arr:
            return

        def shifted(axis):
            # Each removed axis in front of `axis` moves it one to the left.
            return axis - sum(1 for gone in remove_arr if gone < axis)

        for key, val in list(self.token_mapping.items()):
            # The ellipsis keeps its list form; overwriting it with a single
            # integer would break reorder_array.
            if isinstance(val, list):
                self.token_mapping[key] = [shifted(axis) for axis in val]
            else:
                self.token_mapping[key] = shifted(val)

        self.array = self.array.squeeze(axis=tuple(remove_arr))

    def stripped_order(self, d):
        """
        Removes all parentheses from a list of tokens and returns the flat
        list of names: ["a", "(b c)"] -> ["a", "b", "c"].
        """
        return " ".join(d).replace("(", "").replace(")", "").split()

    def add_singleton(self):
        """
        If singleton is in output mapping, but is not in input mapping, then
        add it to the array as a trailing axis of length 1 and register it in
        the token mapping under the name used by the output.
        """
        input_names = set(self.stripped_order(self.input_order))

        count = 0
        for token in self.stripped_order(self.output_order):
            if token.startswith("singleton_") and token not in input_names:
                # New axes go after the axes the array has right now, which
                # stays correct even if axes were squeezed out before.
                self.token_mapping[token] = self.array.ndim + count
                count += 1

        if count:
            self.array = self.array.reshape(tuple(self.array.shape) + (1,) * count)

    def reorder_array(self):
        """
        Based on the order of output identifiers (after removing parentheses),
        re-order the array. The result is stored in `self.reshaped_array`.
        """
        reshape_order = []
        for token in self.stripped_order(self.output_order):
            if '...' in token:
                reshape_order += self.token_mapping[token]
            else:
                reshape_order.append(self.token_mapping[token])

        self.reshaped_array = self.array.transpose(*reshape_order)

    def resum_array(self):
        """
        Merge the axes of every parenthesised output token into one axis.
        Must run after `reorder_array`; the result is stored in
        `self.resummed_array`.
        """
        # Sizes are looked up on the array before the transpose, because
        # token_mapping still refers to that axis order.
        s = self.array.shape

        new_order = []
        for token in self.output_order:
            if "(" in token:
                inner_tokens = token.strip('()').split()
                new_order.append(prod(s[self.token_mapping[inner_t]] for inner_t in inner_tokens))
            elif '...' in token:
                new_order += [s[axis] for axis in self.token_mapping[token]]
            else:
                new_order.append(s[self.token_mapping[token]])

        self.resummed_array = self.reshaped_array.reshape(tuple(new_order))

    def _requires_reshaping(self):
        """
        Checks to see if parentheses exist in output.
        """
        return any("(" in token for token in self.output_order)

    def transform(self):
        """
        1. Calls remove singleton function
        2. Calls add singleton function
        3. Calls reorder array function
        4. Calls resum array function if reshaping is required

        Returns:
            The transformed array.
        """
        self.remove_singleton()
        self.add_singleton()
        self.reorder_array()

        if self._requires_reshaping():
            self.resum_array()
            return self.resummed_array
        return self.reshaped_array


With all the pieces in place, the final `rearrange` function can now be defined

In [None]:
def rearrange(array, pattern, **kwargs):
    """
    Rearrange an array according to an einops-like pattern.

    The work is delegated, in this order, to:
    1. Validator - checks array and pattern, builds the token mappings.
    2. check_extra_arguments / get_additional_args - vet the keyword sizes
       and fill in the ones that can be inferred.
    3. input_based_transformation - splits grouped input axes.
    4. update_input_tokens_mapping - renumbers the tokens for the split array.
    5. Output_Transformations - applies the output side of the pattern.

    Parameters:
        array: The data to rearrange (anything Validator accepts).
        pattern (str): Pattern of the form 'input -> output'.
        kwargs: Sizes of axes named inside parentheses in the input.

    Returns:
        The array produced by Output_Transformations.transform().
    """
    validated = Validator(array, pattern, **kwargs).validate_and_return()
    source_array, axis_of_token, size_of_token, position_in_output = validated

    check_extra_arguments(axis_of_token, **kwargs)
    axis_sizes = get_additional_args(axis_of_token, size_of_token, **kwargs)

    expanded_array = input_based_transformation(
        source_array,
        axis_of_token,
        size_of_token,
        **axis_sizes,
    )
    expanded_axes = update_input_tokens_mapping(axis_of_token, **axis_sizes)

    output_stage = Output_Transformations(
        expanded_array,
        token_mapping=expanded_axes,
        output_mapping=position_in_output,
    )
    return output_stage.transform()


### Unit Tests

In [None]:
import unittest

class TestValidator(unittest.TestCase):
    """Checks Validator's tokenization, ellipsis mapping and error reporting."""

    def test_valid_pattern(self):
        # Construction alone must tokenize both sides of the pattern.
        validator = Validator(np.ones((2, 3, 4)), "a b c -> b a c")
        self.assertEqual(validator.input_tokens, ["a", "b", "c"])
        self.assertEqual(validator.output_tokens, ["b", "a", "c"])

    def test_ellipsis_handling(self):
        # With one explicit axis on a 4-d array, '...' must cover three axes.
        validator = Validator(np.ones((2, 3, 4, 5)), "... c -> c ...")
        validator.validate_and_return()
        mapping = validator.input_tokens_mapping
        self.assertIn("...", mapping)
        self.assertEqual(len(mapping["..."]), 3)

    def test_invalid_pattern(self):
        # 'c' appears only on the output side.
        validator = Validator(np.ones((2, 3)), "a b -> a b c")
        with self.assertRaises(ValueError):
            validator.validate_and_return()

In [None]:
class TestRearrange(unittest.TestCase):
    """
    End-to-end checks of rearrange. Every test compares element values, not
    just shapes, using arrays whose elements are all distinct, so that a
    wrong axis order cannot go unnoticed.
    """

    def test_basic_rearrange(self):
        array = np.arange(6).reshape(2, 3)  # Shape: (2, 3)
        pattern = "a b -> b a"
        result = rearrange(array, pattern)
        expected = np.array([[0, 3], [1, 4], [2, 5]])  # Transposed
        np.testing.assert_array_equal(result, expected)

    def test_singleton_dimension(self):
        array = np.arange(6).reshape(2, 1, 3)  # Shape: (2, 1, 3)
        pattern = "a 1 b -> b a"
        result = rearrange(array, pattern)
        # Dropping the unit axis and swapping the remaining two axes.
        expected = array[:, 0, :].T
        self.assertEqual(result.shape, (3, 2))
        np.testing.assert_array_equal(result, expected)

    def test_ellipsis(self):
        array = np.arange(120).reshape(2, 3, 4, 5)  # Shape: (2, 3, 4, 5)
        pattern = "... c -> c ..."
        result = rearrange(array, pattern)
        # Move the last dimension to the front
        expected = np.moveaxis(array, -1, 0)
        self.assertEqual(result.shape, (5, 2, 3, 4))
        np.testing.assert_array_equal(result, expected)


In [None]:
class TestOutputTransformations(unittest.TestCase):
    """Exercises the individual steps of Output_Transformations."""

    def test_remove_singleton(self):
        # The unit axis is absent from the output, so it must be squeezed out.
        transformer = Output_Transformations(
            np.ones((2, 1, 3)),
            {"a": 0, "singleton_1": 1, "b": 2},
            {"b": 0, "a": 1},
        )
        transformer.remove_singleton()
        self.assertEqual(transformer.array.shape, (2, 3))
        self.assertNotIn("singleton_1", transformer.token_mapping)

    def test_add_singleton(self):
        # The output asks for a leading unit axis the input does not have.
        transformer = Output_Transformations(
            np.ones((2, 3)),
            {"a": 0, "b": 1},
            {"singleton_1": 0, "a": 1, "b": 2},
        )
        transformed = transformer.transform()
        self.assertEqual(transformed.shape, (1, 2, 3))
        self.assertIn("singleton_1", transformer.token_mapping)

    def test_reshape_array(self):
        # Pure permutation: 'a b c' -> 'c a b'.
        transformer = Output_Transformations(
            np.ones((2, 3, 4)),
            {"a": 0, "b": 1, "c": 2},
            {"c": 0, "a": 1, "b": 2},
        )
        transformer.reorder_array()
        self.assertEqual(transformer.reshaped_array.shape, (4, 2, 3))

class TestRearrangeFunctions(unittest.TestCase):
    """
    Table-driven checks of rearrange. Every case runs inside its own subTest
    so that a failure names the offending pattern and the remaining cases
    still run.
    """

    def test_varied_cases(self):
        # Expected shapes are the ones einops produces for the same call.
        all_tests = [
            {'array_shape': (2, 3, 12, 6),
             'pattern': 'b h (h1 h2) c -> b c h1 h2 h',
             'args': {'h1': 4},
             'einops_ground_truth': (2, 6, 4, 3, 3)},
            {'array_shape': (2, 3, 12, 6),
             'pattern': 'b h w c -> b w h c',
             'args': {},
             'einops_ground_truth': (2, 12, 3, 6)},
            {'array_shape': (2, 3, 12, 6),
             'pattern': 'b h w c -> (b h) w c',
             'args': {},
             'einops_ground_truth': (6, 12, 6)},
            {'array_shape': (2, 3, 12, 6),
             'pattern': 'b h w c -> b (c h w)',
             'args': {},
             'einops_ground_truth': (2, 216)},
            {'array_shape': (2, 12, 18, 6),
             'pattern': 'b (h1 h) (w1 w) c -> (b h1 w1) h w c',
             'args': {'h1': 3, 'h': 4, 'w1': 3, 'w': 6},
             'einops_ground_truth': (18, 4, 6, 6)},
            {'array_shape': (2, 12, 18, 6),
             'pattern': 'b (h h1) (w w1) c -> b h w (c h1 w1)',
             'args': {'h1': 3, 'w': 6},
             'einops_ground_truth': (2, 4, 6, 54)},
            {'array_shape': (2, 12, 18, 6),
             'pattern': '... h w -> ... (h w)',
             'args': {},
             'einops_ground_truth': (2, 12, 108)},
            {'array_shape': (2, 12, 18, 6),
             'pattern': '... (h w) c -> ... (h w c)',
             'args': {'w': 6},
             'einops_ground_truth': (2, 12, 108)},
            {'array_shape': (2, 3),
             'pattern': '... h w -> ... w h 1',
             'args': {},
             'einops_ground_truth': (3, 2, 1)},
            {'array_shape': (2, 3, 4),
             'pattern': 'b c h -> b h c',
             'args': {},
             'einops_ground_truth': (2, 4, 3)},
            {'array_shape': (2, 3, 4),
             'pattern': 'b c h -> b (c h) 1',
             'args': {},
             'einops_ground_truth': (2, 12, 1)},
            {'array_shape': (2, 3, 4),
             'pattern': 'b c h -> b h c 1 1',
             'args': {},
             'einops_ground_truth': (2, 4, 3, 1, 1)},
            {'array_shape': (2, 3, 4, 1),
             'pattern': 'b c h 1 -> b c h',
             'args': {},
             'einops_ground_truth': (2, 3, 4)},
        ]

        for test in all_tests:
            with self.subTest(pattern=test['pattern'], args=test['args']):
                array = np.random.randn(*test['array_shape'])
                result = rearrange(array, test['pattern'], **test['args'])
                self.assertEqual(result.shape, test['einops_ground_truth'])

    def test_incompatible_patterns(self):
        # None of these patterns fits an array of shape (2, 3, 4).
        array = np.random.randn(2, 3, 4)
        patterns = [
            'a b c -> a c',
            'a b 1 -> a b',
            'a (b c) -> a b c',
            'a b (c1 c2) -> a b c1 c2',
        ]

        for pattern in patterns:
            with self.subTest(pattern=pattern):
                with self.assertRaises(ValueError):
                    rearrange(array, pattern)

    def test_incompatible_arguments(self):
        # First case: 11 does not divide 120. Second case: two of the three
        # sizes are unknown, so they cannot be inferred.
        cases = [
            ('b h (w1 w2) -> w1 h b w2', {'w1': 11}),
            ('b h (w1 w2 w3) -> w1 h b w2 w3', {'w1': 12}),
        ]

        for pattern, args in cases:
            with self.subTest(pattern=pattern, args=args):
                array = np.random.randn(32, 30, 120)
                with self.assertRaises(ValueError):
                    rearrange(array, pattern, **args)

In [None]:
# An explicit argv keeps unittest from reading sys.argv (which belongs to the
# notebook kernel), and exit=False makes main() return instead of calling
# sys.exit() once the tests have run.
unittest.main(argv=[''], verbosity=2, exit=False)

test_add_singleton (__main__.TestOutputTransformations) ... ok
test_remove_singleton (__main__.TestOutputTransformations) ... ok
test_reshape_array (__main__.TestOutputTransformations) ... ok
test_basic_rearrange (__main__.TestRearrange) ... ok
test_ellipsis (__main__.TestRearrange) ... ok
test_singleton_dimension (__main__.TestRearrange) ... ok
test_incompatible_arguments (__main__.TestRearrangeFunctions) ... ok
test_incompatible_patterns (__main__.TestRearrangeFunctions) ... ok
test_varied_cases (__main__.TestRearrangeFunctions) ... ok
test_ellipsis_handling (__main__.TestValidator) ... ok
test_invalid_pattern (__main__.TestValidator) ... ok
test_valid_pattern (__main__.TestValidator) ... ok

----------------------------------------------------------------------
Ran 12 tests in 0.126s

OK


<unittest.main.TestProgram at 0x79077739fdc0>

### Examples of Using Custom Rearrange Function

In [None]:
# Transpose: swap the two axes of a (3, 4) array
x = np.random.rand(3, 4)
result = rearrange(x, 'h w -> w h')
print(result.shape)

(4, 3)


In [None]:
# Split an axis: the leading axis of length 12 becomes h=3 and w, where w is
# inferred from the array shape
x = np.random.rand(12, 10)
result = rearrange(x, '(h w) c -> h w c', h=3)
result.shape

(3, 4, 10)

In [None]:
# Merge axes: a and b are combined into a single axis
x = np.random.rand(3, 4, 5)
result = rearrange(x, 'a b c -> (a b) c')
result.shape

(12, 5)

In [None]:
# Insert a new axis of length 1 (the literal '1' in the output adds a unit
# axis; no data is repeated)
x = np.random.rand(3, 1, 5)
result = rearrange(x, 'a b c -> a b 1 c')
result.shape

(3, 1, 1, 5)

In [None]:
# Handle batch dimensions: '...' stands for the leading axes, which keep
# their order while h and w are merged
x = np.random.rand(2, 3, 4, 5)
result = rearrange(x, '... h w -> ... (h w)')
result.shape

(2, 3, 20)