<a href="https://colab.research.google.com/github/shlomi1993/deep-learning-notebooks/blob/main/2_broadcast_tensors.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Broadcast Tensors

## Imports and Utilities

In [None]:
from typing import Tuple
from torch import Tensor, tensor, zeros, cat, broadcast_tensors as reference_broadcast_tensors

In [None]:
EXPANSION_ERROR = "Cannot expand tensor of shape {} to shape {}"

class ExpansionError(Exception):
    pass

class BroadcastError(Exception):
    pass

## Expand As

The `expand_as` function expands a given tensor `A` to match the shape of another tensor `B`. If `A` has fewer dimensions than `B`, it prepends singleton dimensions (size 1) to `A`'s shape and reshapes `A` accordingly. Then, for each dimension where `A` has size 1 and `B` has a larger size, it replicates `A` along that dimension with `torch.cat` until it matches `B`'s size. If `A` has more dimensions than `B`, or if the final shape doesn't match `B`'s shape, the function raises an `ExpansionError`. This is useful for preparing tensors for broadcasting in element-wise operations.
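
The two steps described above can be traced on a small example. This is a minimal sketch that assumes only `torch` itself; the names `a` and `target_shape` are illustrative and not part of the notebook.

```python
import torch

# Example inputs; `a` and `target_shape` are names chosen for this sketch only.
a = torch.tensor([1, 2, 3])   # shape (3,)
target_shape = (2, 3)

# Step 1: prepend singleton dimensions so both shapes have the same rank.
padded_shape = [1] * (len(target_shape) - a.dim()) + list(a.shape)
expanded = a.reshape(padded_shape)              # shape (1, 3)

# Step 2: replicate along every size-1 dimension that the target makes larger.
for dim, size in enumerate(target_shape):
    if expanded.shape[dim] == 1 and size != 1:
        expanded = torch.cat([expanded] * size, dim=dim)

print(expanded.shape)  # torch.Size([2, 3])

# torch's built-in expand yields the same values, but as a view that shares
# memory with `a` rather than as a copy.
assert expanded.equal(a.expand(target_shape))
```

Replicating with `cat` allocates a new tensor on every step, which keeps the logic easy to follow; `Tensor.expand` avoids the copy by returning a view.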

In [None]:
def expand_as(A: Tensor, B: Tensor) -> Tensor:
    """
    Expands tensor A to match the shape of tensor B by adding leading singleton dimensions and replicating values along
    dimensions where necessary.

    Args:
        A (Tensor): The tensor to be expanded.
        B (Tensor): The target tensor whose shape A should match.

    Returns:
        Tensor: A new tensor with the same shape as B, obtained by expanding A.

    Raises:
        ExpansionError: If A has more dimensions than B or if expansion fails.

    Example:
        >>> A = tensor([1, 2, 3])        # Shape: (3,)
        >>> B = zeros((2, 3))            # Shape: (2, 3)
        >>> expanded_A = expand_as(A, B) # Shape: (2, 3)
    """
    # If A has more dimensions than B, raise an error because it's not possible to expand
    if len(A.shape) > len(B.shape):
        raise ExpansionError(EXPANSION_ERROR.format(A.shape, B.shape))

    # Add leading 1s to A's shape to match the number of dimensions in B
    new_shape_list = [1] * (len(B.shape) - len(A.shape)) + list(A.shape)

    # Clone A to avoid modifying the original tensor, then reshape it to the new shape with leading 1s.
    # Passing the shape as a list also covers the case where both A and B are 0-dimensional (empty shape).
    expanded_A = A.clone().reshape(new_shape_list)

    # For each dimension of B, expand A if necessary by creating multiple copies along the dimension using cat
    for i in range(len(B.shape)):
        if expanded_A.shape[i] == 1 and B.shape[i] != 1:
            expanded_A = cat([expanded_A] * B.shape[i], dim=i)

    # Raise an error if the resulting tensor shape doesn't match B's shape
    if expanded_A.shape != B.shape:
        raise ExpansionError(EXPANSION_ERROR.format(A.shape, B.shape))

    # Return the expanded tensor.
    return expanded_A

### Testing `expand_as`

In [None]:
A = tensor([[1], [2], [3]])  # Shape (3, 1)
B = zeros((3, 4))            # Shape (3, 4)
assert A.shape != B.shape, "Test case 1 pre-condition failed"
expanded_A = expand_as(A, B)
assert expanded_A.shape == B.shape, "Test case 1 failed"

A = tensor([[1], [2], [3]])  # Shape (3, 1)
B = zeros((3, 4))            # Shape (3, 4)
A_BACKUP = A.clone()
B_BACKUP = B.clone()
_ = expand_as(A, B)
assert A.equal(A_BACKUP) and B.equal(B_BACKUP), "Test case 2 failed"

A = tensor([1])              # Shape (1, )
B = zeros((2, 3, 4))         # Shape (2, 3, 4)
expanded_A = expand_as(A, B)
assert expanded_A.shape == B.shape, "Test case 3 failed"

A = tensor([1, 2, 3])        # Shape (3, )
B = zeros((3, 4))            # Shape (3, 4)
try:
    _ = expand_as(A, B)
    assert False, "Test case 4 failed: Expected ExpansionError"
except ExpansionError:
    pass

A = tensor([1, 2])       # Shape (2, )
B = zeros((3, 4))        # Shape (3, 4)
try:
    _ = expand_as(A, B)
    assert False, "Test case 5 failed: Expected ExpansionError"
except ExpansionError:
    pass

print("All tests passed!")

All tests passed!


## Are Broadcastable

The `is_tensor_empty` function checks whether a tensor `t` is either `None` or contains zero elements.  
The `are_broadcastable` function determines whether two tensors `A` and `B` can be broadcast together according to standard broadcasting rules. It first uses `is_tensor_empty` to reject empty inputs, returning `False` and `None` if either tensor is empty. It then sorts the two shapes by length and compares them from the trailing dimension backwards, aligning the shorter shape with the longer one. Each pair of dimensions must either be equal or contain a `1` (which allows broadcasting). If every pair passes, the function returns `True` along with the resulting broadcast shape; otherwise, it returns `False` and `None`.
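
The trailing-dimension rule can also be traced on bare shape tuples, without any tensors. The sketch below is a hypothetical helper (`broadcast_shape` is not part of the notebook) and deliberately skips the empty-tensor check.

```python
from itertools import zip_longest
from typing import Optional, Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
    """Hypothetical helper: return the broadcast shape of a and b, or None if incompatible."""
    result = []
    # Walk both shapes from the trailing dimension; a missing dimension counts as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            return None  # Sizes differ and neither is 1.
    return tuple(reversed(result))


print(broadcast_shape((3, 1), (3, 4)))     # (3, 4)
print(broadcast_shape((2, 1), (3, 1, 2)))  # (3, 2, 2)
print(broadcast_shape((2,), (3, 4)))       # None
```

Padding the shorter shape with 1s via `fillvalue=1` is equivalent to starting from the longer shape and only visiting the trailing dimensions the two shapes share.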

In [None]:
def is_tensor_empty(t: Tensor) -> bool:
    """
    Checks whether a given tensor is None or empty.

    Args:
        t (Tensor): The tensor to check.

    Returns:
        bool: True if the tensor is None or contains zero elements, False otherwise.
    """
    return t is None or t.numel() == 0

def are_broadcastable(A: Tensor, B: Tensor) -> Tuple[bool, Tuple[int, ...]]:
    """
    Determines whether two tensors can be broadcast together according to broadcasting rules.

    Args:
        A (Tensor): The first tensor.
        B (Tensor): The second tensor.

    Returns:
        Tuple[bool, Tuple[int, ...]]: A tuple containing:
            - A boolean indicating whether broadcasting is possible.
            - The resulting broadcast shape if broadcasting is possible, otherwise None.
    """
    # Ensure no empty tensor was provided
    if is_tensor_empty(A) or is_tensor_empty(B):
        return False, None

    # Sort the shapes by length so that A_shape is the shorter one (ties keep the original order)
    A_shape, B_shape = sorted([A.shape, B.shape], key=len)

    # Start with the larger tensor's shape.
    broadcast_shape = list(B_shape)

    # Iterate over dimensions from the last (rightmost) to the first (leftmost).
    for i in range(1, len(A_shape) + 1):

        # If dimensions match or A_shape has 1, use B_shape's dimension.
        if A_shape[-i] == B_shape[-i] or A_shape[-i] == 1:
            pass  # Same as broadcast_shape[-i] = B_shape[-i]

        # If B_shape has 1, use A_shape's dimension.
        elif B_shape[-i] == 1:
            broadcast_shape[-i] = A_shape[-i]

        # If dimensions are incompatible, broadcasting is not possible.
        else:
            return False, None

    # Return success and the resulting broadcast shape.
    return True, tuple(broadcast_shape)

### Testing `are_broadcastable`

In [None]:
assert are_broadcastable(tensor([1]), zeros((2, 3, 4))) == (True, (2, 3, 4)), "Test case 1 failed"
assert are_broadcastable(tensor([[1], [2], [3]]), zeros((3, 4))) == (True, (3, 4)), "Test case 2 failed"
assert are_broadcastable(tensor([1, 2]), zeros((3, 4))) == (False, None), "Test case 3 failed"
assert are_broadcastable(tensor([[1, 2, 3]]), zeros((3,))) == (True, (1, 3)), "Test case 4 failed"
assert are_broadcastable(tensor([[1], [2]]), tensor([[[1, 2]], [[3, 4]], [[5, 6]]])) == (True, (3, 2, 2)), "Test case 5 failed"
print("All tests passed!")

All tests passed!


## Broadcast Tensors

The `broadcast_tensors` function broadcasts two tensors `A` and `B` to a common shape. It first checks whether broadcasting is possible using `are_broadcastable`, and raises a `BroadcastError` if it is not. Otherwise, it creates a zero tensor with the resulting broadcast shape and uses `expand_as` to expand both `A` and `B` to that shape. The function then returns the two expanded tensors as a tuple.
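
To see what the returned pair should look like, the sketch below calls the reference `torch.broadcast_tensors` directly on two illustrative tensors (`a` and `b` are names chosen for this example) and checks that adding the explicitly broadcast results matches PyTorch's implicit broadcasting.

```python
import torch

a = torch.tensor([[1], [2], [3]])   # shape (3, 1)
b = torch.tensor([10, 20])          # shape (2,)

# The reference implementation expands both inputs to the common shape (3, 2).
a_b, b_b = torch.broadcast_tensors(a, b)
print(a_b.shape, b_b.shape)  # torch.Size([3, 2]) torch.Size([3, 2])

# Adding the explicitly broadcast tensors matches implicit broadcasting in `a + b`.
assert (a_b + b_b).equal(a + b)
print((a_b + b_b).tolist())  # [[11, 21], [12, 22], [13, 23]]
```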

In [None]:
def broadcast_tensors(A: Tensor, B: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Broadcasts two tensors to a common shape if possible.

    Args:
        A (Tensor): The first tensor.
        B (Tensor): The second tensor.

    Returns:
        Tuple[Tensor, Tensor]: A tuple containing the two tensors expanded to the broadcasted shape.

    Raises:
        BroadcastError: If the tensors cannot be broadcasted together.
    """
    broadcastable, result_shape = are_broadcastable(A, B)
    if not broadcastable:
        raise BroadcastError("Tensors cannot be broadcasted together")

    # The tensors are known to be broadcastable here, so use expand_as to expand each one to the result shape
    expanded_A = expand_as(A, zeros(result_shape))
    expanded_B = expand_as(B, zeros(result_shape))
    return expanded_A, expanded_B

### Testing `broadcast_tensors`

In [None]:
A = tensor([1])              # Shape (1,)
B = zeros((2, 3, 4))         # Shape (2,3,4)
expanded_A, expanded_B = broadcast_tensors(A, B)
assert expanded_A.shape == expanded_B.shape == (2, 3, 4), "Test case 1 failed"

A = tensor([[1], [2], [3]])  # Shape (3,1)
B = zeros((3, 4))            # Shape (3,4)
expanded_A, expanded_B = broadcast_tensors(A, B)
assert expanded_A.shape == expanded_B.shape == (3, 4), "Test case 2 failed"

try:
    A = tensor([1, 2])       # Shape (2,)
    B = zeros((3, 4))        # Shape (3,4)
    broadcast_tensors(A, B)  # Should raise error
    assert False, "Test case 3 failed: Expected BroadcastError"
except BroadcastError:
    pass

print("All tests passed!")

All tests passed!


## Compare to a Reference

The `test_cases` list contains a variety of tensor pairs (`A`, `B`) along with descriptions, designed to test different broadcasting scenarios such as scalar expansion, column/row vector alignment, and multi-dimensional expansion.


In [None]:
test_cases = [
    (tensor([1]), zeros((2, 3, 4)), "Broadcasting scalar to (2,3,4)"),
    (tensor([[1], [2], [3]]), zeros((3, 4)), "Column vector to (3,4)"),
    (tensor([1, 2, 3]), zeros((3, 1)), "(3,) with (3,1), both expanded to (3,3)"),
    (tensor([[1, 2, 3]]), zeros((1, 3)), "1-row matrix to (1,3)"),
    (tensor([1, 1, 1]), zeros((3, 3)), "1D vector broadcast to (3,3)"),
    (tensor([[1], [2]]), zeros((2, 3)), "Column vector to (2,3)"),
    (tensor([[1, 2, 3]]), zeros((2, 3)), "Row vector to (2,3)"),
    (tensor([[[1]], [[2]]]), zeros((2, 3, 4)), "(2,1,1) to (2,3,4)"),
    (tensor([1, 2, 3, 4]), zeros((1, 4)), "Vector to (1,4)"),
    (tensor([[1, 2], [3, 4]]), zeros((2, 2)), "No broadcasting needed"),
    (tensor([[1], [2], [3]]), zeros((1, 3, 4)), "Expanding from (3,1) to (1,3,4)"),
    (tensor([1, 2, 3]).view(3, 1, 1), zeros((3, 1, 2)), "Broadcasting (3,1,1) to (3,1,2)"),
    (tensor([[1, 2, 3], [4, 5, 6]]).view(2, 3, 1), zeros((2, 3, 4)), "Expanding to third dim"),
    (tensor([1, 2, 3, 4]), zeros((3, 1, 4)), "Matching last dimension"),
    (tensor([[[1, 2, 3]]]), zeros((2, 3, 3)), "Expanding singleton dimension"),
]

In the testing loop, each pair is passed to the `broadcast_tensors` function, and its output is compared against that of a reference implementation (`reference_broadcast_tensors`, the alias under which `torch.broadcast_tensors` was imported). Assertions ensure that the expanded tensors match the expected results. If a mismatch or exception occurs, the test is marked as failed and reported. At the end, a summary is printed indicating whether all tests passed or how many failed.

In [None]:
n_fail = 0
for A, B, description in test_cases:
    try:
        expanded_A, expanded_B = broadcast_tensors(A, B)
        expected_A, expected_B = reference_broadcast_tensors(A, B)
        assert expanded_A.equal(expected_A), "Mismatch in broadcasted A"
        assert expanded_B.equal(expected_B), "Mismatch in broadcasted B"
    except Exception as error:  # Avoid a bare except, which would also swallow KeyboardInterrupt
        print(f"Test '{description}' failed: {error}")
        n_fail += 1
print("All tests passed!" if n_fail == 0 else f"{n_fail} tests failed!")

All tests passed!
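
The comparison loop exercises only compatible pairs. As a complement, the sketch below shows how the reference behaves on an incompatible pair, assuming only `torch`; the `raised` flag is a name introduced for this example.

```python
import torch

a = torch.tensor([1, 2])    # shape (2,)
b = torch.zeros((3, 4))     # shape (3, 4)

# Trailing dimensions 2 and 4 differ and neither is 1, so the reference refuses.
try:
    torch.broadcast_tensors(a, b)
    raised = False
except RuntimeError:
    raised = True

print(raised)  # True
```

The notebook's `broadcast_tensors` signals the same condition with its own `BroadcastError` rather than `RuntimeError`, so a test that covers failure cases would need to expect a different exception type from each implementation.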
