# 22961, Mmn11

Author: Tal Glanzman

In [84]:
import torch
from typing import Sequence, Tuple

In [85]:
def expand(tensor: torch.Tensor, size: torch.Size) -> torch.Tensor:
    """
    Expand a given `tensor`, to a newly allocated tensor, with the shape given by `size`.
    
    The expansion is done by the broadcast rules:
    - As long as the rank of `tensor` is less than `len(size)`
      - Add degenerative axis to the tensor
    - Iterating over every dimesnion `dim`, right to left:
      - If the `tensor.shape[dim] != size[dim]`:
        - If `tensor.shape[dim] == 1`:
          - Concatenate the tensor with itself, `shape[dim]` times, along the `dim`th axis
        - Else: An error is raised indicating that the tensor is not expandable to size `size`

    Arguments
      tensor {torch.Tensor} The tensor to expand
      size {torch.Size} The size to expand the tensor to

    Returns
      {torch.Tensor} A newly allcoated tensor
    """
    ans = tensor.clone()
    rank = len(size)
    
    # Iterate on every (to) dimension
    for i in range(rank):

        # A cursor to the current axis in `size`
        j = rank - 1 - i

        # A cursor to the current dimension in `tensor.shape`.
        # Note that it can be negative at first, but at most -1.
        k = len(ans.shape) - 1 - i

        if k < 0:
            # It is sufficient to only add one degenerate dimension because we
            # add them one by one - so we know that the index will be in bounds.
            ans = ans.unsqueeze(0)
            k = 0

        ans_dimsize = ans.shape[k]
        dimsize = size[j]

        if ans_dimsize != dimsize:
            if ans_dimsize != 1:
                raise Exception(f"Cannot expand tensor with shape {tuple(ans.shape)} to shape {size}")
            ans = torch.cat([ans] * dimsize, dim=k)
        
    return ans

In [86]:
def test_expand():
    """Tester for `are_broadcastable_together`"""
    
    testcases = [
        ((1, ), (3, )),
        ((1, ), (2, 4, 5)),
        ((3, 2), (1, 1, 3, 2)),
        ((1, 3, 2, 1, 1), (1, 3, 2)),
        ((2,), (3,))
    ]
    
    for i, (tensor_shape, size) in enumerate(testcases):
        print(f"====== Test {i} =======")
        tensor = torch.empty(tensor_shape)

        print("tensor_shape", tensor_shape)
        print("size", size)

        torch_error = False
        custom_error = False
        
        try:
            torch_expanded = tensor.expand(size)
        except:
            print("Torch failed to expand")
            torch_error = True

        try:
            custom_expanded = expand(tensor, size)
        except:
            print("Custom implementation failed to expand")
            custom_error = True

        assert torch_error == custom_error, "Only one of Torch and Custom implementations raised an error"

        if not torch_error:
            assert torch.equal(torch_expanded, custom_expanded), "Torch and Custom implementation yielded different tensors"
        
        print("")

test_expand()

tensor_shape (1,)
size (3,)

tensor_shape (1,)
size (2, 4, 5)

tensor_shape (3, 2)
size (1, 1, 3, 2)

tensor_shape (1, 3, 2, 1, 1)
size (1, 3, 2)
Torch failed to expand
Custom implementation failed to expand

tensor_shape (2,)
size (3,)
Torch failed to expand
Custom implementation failed to expand



In [87]:
def are_broadcastable_together(a: torch.Tensor, b: torch.Tensor) -> Tuple[bool, torch.Size]:
    """
    Check whether two {torch.Tensor}s are broadcastable together.
    If they do, the shared size is returned.

    To acquire the shared size we (logically, not algorithmically) following this:
    - While the ranks of `a` and `b` are not the same, add a degenerative dimension
      as dimension 0 for the lower ranked tensor. 
    - For every dimension of `a` and `b`, `adim` and `bdim` respectively
      - If `adim != bdim`
        - If `adim == 1`: Set the corresponding dimension of the result size to `bdim`
        - If `bdim == 1`: Set the corresponding dimension of the result size to `adim`
        - Else: The tensors are not broadcastable together!

    Returns
        A tuple:
            {boolean} - Indicates whether the tensors can be broadcastable together
            {torch.Size} - If the first element is True, this size contains the shared
                           size the tensors can be expanded to.
    """
    shape_a = list(a.shape)
    shape_b = list(b.shape)
    
    rank_a = len(shape_a)
    rank_b = len(shape_b)

    for i in range(max(rank_a, rank_b)):
        axis_a = rank_a - 1 - i
        axis_b = rank_b - 1 - i

        if axis_a < 0:
            shape_a.insert(0, 1)
            axis_a = 0

        if axis_b < 0:
            shape_b.insert(0, 1)
            axis_b = 0

        size_a = shape_a[axis_a]
        size_b = shape_b[axis_b]

        if size_a != size_b:
            if size_a == 1:
                shape_a[axis_a] = size_b
            elif size_b == 1:
                shape_b[axis_b] = size_a
            else:
                return (False, None)

    assert shape_a == shape_b, "Should never happen - It's a bug if it did"
    return (True, tuple(shape_a))

In [88]:
def test_are_broadcastable_together():
    """Tester for `are_broadcastable_together`"""
    
    testcases = [
        ((1,), (1, 1)),

        ((2, 1, 2), (2, 1, 1, 3)),
        ((2, 1, 2), (5, 1)),
        ((2, 1, 2), (1, 5)),
        ((2, 1, 2), (1,)),
        ((2, 1, 2), (5, 3)),

        ((2, 1, 1, 3), (5, 1)),
        ((2, 1, 1, 3), (1, 5)),
        ((2, 1, 1, 3), (1,)),
        ((2, 1, 1, 3), (5, 3)),

        ((5, 1), (1, 5)),
        ((5, 1), (1, )),
        ((5, 1), (5, 3)),

        ((1, 5), (1, )),
        ((1, 5), (5, 3)),
        ((1, ), (5, 3)),
    ]

    for i, (a_shape, b_shape) in enumerate(testcases):
        print(f"====== Test {i} ======")
        print("a_shape", a_shape)
        print("b_shape", b_shape)
        
        a = torch.empty(a_shape)
        b = torch.empty(b_shape)
        
        actual_broadcastable, actual_shape = are_broadcastable_together(a, b)

        if not actual_broadcastable:
            print("Custom implementation couldn't broadcast")
            
        try:
            torch_broadcast = torch.broadcast_tensors(a, b)
        except:
            print("Torch couldn't broadcast\n")
            assert not actual_broadcastable, "Torch couldn't broadcast but custom implemntation did"
            continue

        torch_shape = tuple(torch_broadcast[0].shape)
        print("Actual shape: ", actual_shape)
        print("Torch shape: ", torch_shape)
        
        assert actual_shape == torch_shape, "Torch and Custom implementation broadcasted differently"
        
        print("")

test_are_broadcastable_together()

a_shape (1,)
b_shape (1, 1)
Actual shape:  (1, 1)
Torch shape:  (1, 1)

a_shape (2, 1, 2)
b_shape (2, 1, 1, 3)
Custom implementation couldn't broadcast
Torch couldn't broadcast

a_shape (2, 1, 2)
b_shape (5, 1)
Actual shape:  (2, 5, 2)
Torch shape:  (2, 5, 2)

a_shape (2, 1, 2)
b_shape (1, 5)
Custom implementation couldn't broadcast
Torch couldn't broadcast

a_shape (2, 1, 2)
b_shape (1,)
Actual shape:  (2, 1, 2)
Torch shape:  (2, 1, 2)

a_shape (2, 1, 2)
b_shape (5, 3)
Custom implementation couldn't broadcast
Torch couldn't broadcast

a_shape (2, 1, 1, 3)
b_shape (5, 1)
Actual shape:  (2, 1, 5, 3)
Torch shape:  (2, 1, 5, 3)

a_shape (2, 1, 1, 3)
b_shape (1, 5)
Custom implementation couldn't broadcast
Torch couldn't broadcast

a_shape (2, 1, 1, 3)
b_shape (1,)
Actual shape:  (2, 1, 1, 3)
Torch shape:  (2, 1, 1, 3)

a_shape (2, 1, 1, 3)
b_shape (5, 3)
Actual shape:  (2, 1, 5, 3)
Torch shape:  (2, 1, 5, 3)

a_shape (5, 1)
b_shape (1, 5)
Actual shape:  (5, 5)
Torch shape:  (5, 5)

a_shape

In [89]:
def broadcast_tensors(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Basically check and get the shared size using `are_broadcastable_together` and then
    invoke `expand` on `a` and `b`.

    Returns
        Tuple of:
            {torch.Tensor} the expanded tensor emerged from `a`
            {torch.Tensor} the expanded tensor emerged from `b`
    """
    broadcastable, size = are_broadcastable_together(a, b)
    if not broadcastable:
        raise Exception(
            f'Unable to broadcast tensors of shapes "{a.shape}" and "{b.shape}"')

    return (
        expand(a, size),
        expand(b, size))

In [90]:
def test_broadcast_tensors():
    testcases = [
        ((1,), (1, 1)),

        ((2, 1, 2), (2, 1, 1, 3)),
        ((2, 1, 2), (5, 1)),
        ((2, 1, 2), (1, 5)),
        ((2, 1, 2), (1,)),
        ((2, 1, 2), (5, 3)),

        ((2, 1, 1, 3), (5, 1)),
        ((2, 1, 1, 3), (1, 5)),
        ((2, 1, 1, 3), (1,)),
        ((2, 1, 1, 3), (5, 3)),

        ((5, 1), (1, 5)),
        ((5, 1), (1, )),
        ((5, 1), (5, 3)),

        ((1, 5), (1, )),
        ((1, 5), (5, 3)),
        ((1, ), (5, 3)),
    ]

    for i, (a_shape, b_shape) in enumerate(testcases):
        a = torch.rand(a_shape)
        b = torch.rand(b_shape)

        error_custom = False
        error_torch = False

        try:
            torch_broadcast = torch.broadcast_tensors(a, b)
        except:
            error_torch = True

        try:
            custom_broadcast = broadcast_tensors(a, b)
        except:
            error_custom = True

        assert error_torch == error_custom

        if not error_torch:
            assert torch.equal(torch_broadcast[0], custom_broadcast[0]), "First tensors of Torch and Custom implementation are not the same"
            assert torch.equal(torch_broadcast[1], custom_broadcast[1]), "Second tensors of Torch and Custom implementation are not the same"

test_broadcast_tensors()