### Video Link

TODO: add link

### Imports

In [136]:
from typing import Optional, Tuple

from torch import Tensor, equal, ones, tensor, zeros
from torch import broadcast_tensors as torch_broadcast_tensors

### Question A

In [137]:
BROADCAST_ERROR = "Cannot broadcast tensor A to tensor B's shape"

def broadcast_expand(A: Tensor, B: Tensor) -> Tensor:
    """Return a copy of ``A`` broadcast (expanded) to the shape of ``B``.

    One-directional broadcasting: ``A``'s shape is left-padded with size-1
    dimensions up to ``B``'s rank, after which every dimension of ``A`` must be
    either 1 or equal to the matching dimension of ``B``. Only ``B.shape`` is
    read; ``B``'s values, dtype and device are ignored.

    Args:
        A: Tensor to expand.
        B: Tensor whose shape is the broadcast target.

    Returns:
        A new tensor of shape ``B.shape`` with ``A``'s dtype and device. It
        does not share memory with ``A``.

    Raises:
        ValueError: If ``A`` has more dimensions than ``B``, or a dimension of
            ``A`` is neither 1 nor equal to the matching dimension of ``B``.
    """
    A_shape = list(A.shape)
    B_shape = list(B.shape)

    if len(A_shape) > len(B_shape):
        raise ValueError(BROADCAST_ERROR)

    # Left-pad A's shape with 1s so both shapes have the same rank.
    A_shape = [1] * (len(B_shape) - len(A_shape)) + A_shape

    if not all(a_dim in (1, b_dim) for a_dim, b_dim in zip(A_shape, B_shape)):
        raise ValueError(BROADCAST_ERROR)

    # reshape inserts the leading size-1 dims and expand stretches every size-1
    # dim to the target size. No arithmetic is involved, and nothing is
    # allocated on the default device (a ones tensor used to be, which broke
    # non-CPU inputs). clone turns the zero-stride view into an independent copy.
    return A.reshape(A_shape).expand(B_shape).clone()

In [138]:
A = tensor([[1], [2], [3]])  # Shape (3,1)
B = zeros((3, 4))            # Shape (3,4)
expanded_A = broadcast_expand(A, B)
assert expanded_A.shape == B.shape, "Test case 1 failed"
assert expanded_A.dtype == A.dtype, "Test case 1 failed: dtype changed"
assert expanded_A.tolist() == [[1] * 4, [2] * 4, [3] * 4], "Test case 1 failed: wrong values"

A = tensor([1, 2, 3])        # Shape (3,), viewed as (3,1) before expanding
B = zeros((3, 4))            # Shape (3,4)
expanded_A = broadcast_expand(A.view(3, 1), B)
assert expanded_A.shape == B.shape, "Test case 2 failed"
assert expanded_A.tolist() == [[1] * 4, [2] * 4, [3] * 4], "Test case 2 failed: wrong values"

A = tensor([1])              # Shape (1,)
B = zeros((2, 3, 4))         # Shape (2,3,4)
expanded_A = broadcast_expand(A, B)
assert expanded_A.shape == B.shape, "Test case 3 failed"
assert (expanded_A == 1).all(), "Test case 3 failed: wrong values"

# Incompatible trailing dimension: (2,) -> (1,2) vs (3,4).
A = tensor([1, 2])           # Shape (2,)
B = zeros((3, 4))            # Shape (3,4)
try:
    broadcast_expand(A, B)   # Should raise error
except ValueError:
    pass
else:
    # An explicit raise, unlike `assert False`, survives `python -O`.
    raise AssertionError("Test case 4 failed: Expected ValueError")

# A has more dimensions than the target: one-directional expansion is impossible.
A = zeros((2, 3))            # Shape (2,3)
B = zeros((3,))              # Shape (3,)
try:
    broadcast_expand(A, B)   # Should raise error
except ValueError:
    pass
else:
    raise AssertionError("Test case 5 failed: Expected ValueError")

print("All tests passed!")

All tests passed!


### Question B

In [139]:
def can_broadcast(A: Tensor, B: Tensor) -> Tuple[bool, Optional[Tuple[int, ...]]]:
    """Check whether ``A`` and ``B`` can be broadcast together.

    Both shapes are left-padded with size-1 dimensions to a common rank. Each
    pair of dimensions is compatible when the sizes are equal or either one
    is 1.

    Args:
        A: First tensor (only its shape is read).
        B: Second tensor (only its shape is read).

    Returns:
        ``(True, result_shape)`` with the broadcast shape as a tuple of ints,
        or ``(False, None)`` if the shapes are incompatible.
    """
    A_shape = list(A.shape)
    B_shape = list(B.shape)

    # Pad the shorter shape with leading 1s so the dimensions line up from the right.
    ndim = max(len(A_shape), len(B_shape))
    A_shape = [1] * (ndim - len(A_shape)) + A_shape
    B_shape = [1] * (ndim - len(B_shape)) + B_shape

    result_shape = []
    for a_dim, b_dim in zip(A_shape, B_shape):
        if a_dim != b_dim and a_dim != 1 and b_dim != 1:
            return False, None
        # A size-1 dim takes the other size, including 0. Using max() would
        # wrongly turn a (0, 1) pair into 1 instead of 0.
        result_shape.append(b_dim if a_dim == 1 else a_dim)

    return True, tuple(result_shape)

In [140]:
# Each entry: (first tensor, second tensor, expected can_broadcast result).
for case_number, (lhs, rhs, expected) in enumerate(
    [
        (tensor([1]), zeros((2, 3, 4)), (True, (2, 3, 4))),
        (tensor([[1], [2], [3]]), zeros((3, 4)), (True, (3, 4))),
        (tensor([1, 2]), zeros((3, 4)), (False, None)),
        (tensor([[1, 2, 3]]), zeros((3,)), (True, (1, 3))),
    ],
    start=1,
):
    assert can_broadcast(lhs, rhs) == expected, f"Test case {case_number} failed"
print("All tests passed!")

All tests passed!


### Question C

In [141]:
def broadcast_tensors(A: Tensor, B: Tensor) -> Tuple[Tensor, Tensor]:
    """Broadcast ``A`` and ``B`` against each other.

    Args:
        A: First tensor.
        B: Second tensor.

    Returns:
        ``(expanded_A, expanded_B)``, both with the common broadcast shape
        computed by ``can_broadcast``. Each keeps its own dtype.

    Raises:
        ValueError: If the two shapes are not broadcast-compatible.
    """
    broadcastable, result_shape = can_broadcast(A, B)
    if not broadcastable:
        raise ValueError("Tensors cannot be broadcasted together")

    # broadcast_expand only reads the target's shape. A zero-stride view of a
    # single scalar therefore works as the shape template without allocating a
    # result-sized tensor (previously two such tensors were built per call).
    shape_template = zeros(()).expand(result_shape)
    expanded_A = broadcast_expand(A, shape_template)
    expanded_B = broadcast_expand(B, shape_template)
    return expanded_A, expanded_B

In [142]:
A = tensor([1])              # Shape (1,)
B = zeros((2, 3, 4))         # Shape (2,3,4)
expanded_A, expanded_B = broadcast_tensors(A, B)
assert expanded_A.shape == expanded_B.shape == (2, 3, 4), "Test case 1 failed"
assert (expanded_A == 1).all() and (expanded_B == 0).all(), "Test case 1 failed: wrong values"

A = tensor([[1], [2], [3]])  # Shape (3,1)
B = zeros((3, 4))            # Shape (3,4)
expanded_A, expanded_B = broadcast_tensors(A, B)
assert expanded_A.shape == expanded_B.shape == (3, 4), "Test case 2 failed"
assert expanded_A.tolist() == [[1] * 4, [2] * 4, [3] * 4], "Test case 2 failed: wrong A values"
assert (expanded_B == 0).all(), "Test case 2 failed: wrong B values"

# Both inputs must expand: (3,1) and (1,4) -> (3,4).
A = tensor([[1], [2], [3]])        # Shape (3,1)
B = tensor([[10, 20, 30, 40]])     # Shape (1,4)
expanded_A, expanded_B = broadcast_tensors(A, B)
assert expanded_A.shape == expanded_B.shape == (3, 4), "Test case 3 failed"
assert expanded_A.tolist() == [[1] * 4, [2] * 4, [3] * 4], "Test case 3 failed: wrong A values"
assert expanded_B.tolist() == [[10, 20, 30, 40]] * 3, "Test case 3 failed: wrong B values"

A = tensor([1, 2])       # Shape (2,)
B = zeros((3, 4))        # Shape (3,4)
try:
    broadcast_tensors(A, B)        # Should raise error
except ValueError:
    pass
else:
    # An explicit raise, unlike `assert False`, survives `python -O`.
    raise AssertionError("Test case 4 failed: Expected ValueError")

print("All tests passed!")

All tests passed!


### Question D

In [143]:
# Compare the hand-written broadcast_tensors against torch.broadcast_tensors
# (imported at the top as torch_broadcast_tensors).
test_cases = [
    (tensor([1]), zeros((2, 3, 4)), "Broadcasting (1,) to (2,3,4)"),
    (tensor([[1], [2], [3]]), zeros((3, 4)), "Column vector to (3,4)"),
    (tensor([1, 2, 3]), zeros((3, 1)), "(3,) with column (3,1) -> (3,3)"),
    (tensor([[1, 2, 3]]), zeros((1, 3)), "1-row matrix to (1,3)"),
    (tensor([1, 1, 1]), zeros((3, 3)), "1D vector broadcast to (3,3)"),
    (tensor([[1], [2]]), zeros((2, 3)), "Column vector to (2,3)"),
    (tensor([[1, 2, 3]]), zeros((2, 3)), "Row vector to (2,3)"),
    (tensor([[[1]], [[2]]]), zeros((2, 3, 4)), "(2,1,1) to (2,3,4)"),
    (tensor([1, 2, 3, 4]), zeros((1, 4)), "Vector to (1,4)"),
    (tensor([[1, 2], [3, 4]]), zeros((2, 2)), "No broadcasting needed"),
    (tensor([[1], [2], [3]]), zeros((1, 3, 4)), "Expanding from (3,1) to (1,3,4)"),
    (tensor([1, 2, 3]).view(3, 1, 1), zeros((3, 1, 2)), "Broadcasting (3,1,1) to (3,1,2)"),
    (tensor([[1, 2, 3], [4, 5, 6]]).view(2, 3, 1), zeros((2, 3, 4)), "Expanding to third dim"),
    (tensor([1, 2, 3, 4]), zeros((3, 1, 4)), "Matching last dimension"),
    (tensor([[[1, 2, 3]]]), zeros((2, 3, 3)), "Expanding singleton dimension"),
]

failures = []
for A, B, description in test_cases:
    try:
        expanded_A, expanded_B = broadcast_tensors(A, B)
        expected_A, expected_B = torch_broadcast_tensors(A, B)
        # torch.equal requires identical shapes and values. allclose would
        # broadcast its arguments and accept an un-expanded (wrong-shape) result.
        assert equal(expanded_A, expected_A), "Mismatch in broadcasted A"
        assert equal(expanded_B, expected_B), "Mismatch in broadcasted B"
    except Exception as exc:  # record and continue so every failing case is reported
        failures.append(f"Test '{description}' failed: {exc!r}")

if failures:
    raise AssertionError("\n".join(failures))
print("All tests passed!")

All tests passed!
