# PyTorch Broadcasting

In [None]:
#|hide
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(123)
torch.set_printoptions(precision=1, sci_mode=False, profile='short')

## Some simple tensors

Let's start with some simple tensors we can use to learn with:

In [None]:
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.tensor([2,4,5])

In [None]:
x.shape, y.shape

(torch.Size([2, 3]), torch.Size([3]))

x has 2 rows (first dimension) and 3 columns (second dimension). y has a single dimension with 3 elements.

## Broadcasting - 2 dimensions

PyTorch broadcasting takes two tensors and compares their dimensions from right to left. So:

- x has dim: `2, 3`
- y has dim: `   3`

PyTorch aligns the rightmost dimensions and uses y two times, once for each row of x, because y has no dimension corresponding to the first dimension of x (size 2). The comparison always starts from the innermost (rightmost) dimension.

In [None]:
x * y # element-wise multiplication

tensor([[ 2,  8, 15],
        [ 8, 20, 30]])
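
One way to see what broadcasting does here is to build the repeated version of y by hand. The sketch below uses `Tensor.expand` to view y as a 2 x 3 tensor and checks that multiplying by it matches the broadcast product. The names `y_expanded` and `same` are only illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.tensor([2, 4, 5])

# View y as if it were repeated along a new leading dimension of size 2.
# expand does not copy data; it only changes how the tensor is indexed.
y_expanded = y.expand(2, 3)

# The explicit product and the broadcast product should agree.
same = torch.equal(x * y_expanded, x * y)
print(y_expanded)
print(same)
```

Broadcasting gives the same result without having to write out the expanded tensor.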

What if y has two dimensions but the first dimension is 1?

- x has dim: `2, 3`
- y has dim: `1, 3`

In [None]:
y = torch.tensor([[2,4,5]])

In [None]:
x * y

tensor([[ 2,  8, 15],
        [ 8, 20, 30]])

This gives exactly the same result: y is broadcast along the first dimension. Does the multiplication also work when y has the same shape as x? In that case, no broadcasting is needed:

In [None]:
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.tensor([[2, 3, 4],
                  [5, 6, 7]])
x * y

tensor([[ 2,  6, 12],
        [20, 30, 42]])

What if x has 4 rows while y has 2? Will it still broadcast?

- x has dim: `4, 3`
- y has dim: `2, 3`

In [None]:
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 6, 5],
                  [4, 3, 2]])
y = torch.tensor([[2, 3, 4],
                  [5, 6, 7]])
# THIS DOES NOT WORK: x * y
# ERROR: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0

This does not work. For broadcasting to work, each pair of aligned dimensions must be equal, or the dimension must be missing or have size 1 in one of the tensors. Here the first dimensions are 4 and 2, so neither condition holds.
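
The compatibility rule can be written down as a small check on shapes. The following is a minimal sketch in plain Python; `can_broadcast` is a hypothetical helper written for illustration, not a PyTorch function.

```python
def can_broadcast(shape_a, shape_b):
    """Return True if two shapes are compatible for broadcasting."""
    # Walk both shapes from the rightmost dimension to the left.
    # zip stops at the shorter shape, so missing dimensions are accepted.
    for a, b in zip(reversed(shape_a), reversed(shape_b)):
        # Each aligned pair must be equal, or one of them must be 1.
        if a != b and a != 1 and b != 1:
            return False
    return True

print(can_broadcast((2, 3), (3,)))    # True: dimension missing in y
print(can_broadcast((2, 3), (1, 3)))  # True: dimension of size 1 in y
print(can_broadcast((4, 3), (2, 3)))  # False: 4 vs 2
```

The three calls mirror the three cases tried so far in this section.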

## Broadcasting - more dimensions

What if one of the tensors has more than 2 dimensions?

In [None]:
x = torch.tensor([[[1, 2, 3],
                  [4, 5, 6]],
                 [[7, 6, 5],
                  [4, 3, 2]],
                 [[2, 3, 4],
                  [5, 6, 7]],
                 [[6, 5, 4],
                  [3, 2, 1]]])
y = torch.tensor([2, 3, 4])

In [None]:
x.shape, y.shape

(torch.Size([4, 2, 3]), torch.Size([3]))

What happens if we multiply a tensor of shape $4 \times 2 \times 3$ by one of shape $3$?

- x has dim: `4, 2, 3`
- y has dim: `      3`

In [None]:
z = x * y
z

tensor([[[ 2,  6, 12],
         [ 8, 15, 24]],

        [[14, 18, 20],
         [ 8,  9,  8]],

        [[ 4,  9, 16],
         [10, 18, 28]],

        [[12, 15, 16],
         [ 6,  6,  4]]])

In [None]:
z.shape

torch.Size([4, 2, 3])
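
Put differently, the missing leading dimensions of y behave as if they had size 1. The sketch below, which assumes a PyTorch version that provides `torch.broadcast_shapes`, reshapes y to `1, 1, 3` explicitly and compares the two products. The tensor built with `torch.arange` is only a stand-in for x with the same shape, and `y3d` is an illustrative name.

```python
import torch

# Stand-in for x: any tensor of shape (4, 2, 3) behaves the same way.
x = torch.arange(24).reshape(4, 2, 3)
y = torch.tensor([2, 3, 4])

# Give y two explicit leading dimensions of size 1: shape (1, 1, 3).
y3d = y.reshape(1, 1, 3)

# The shape that broadcasting produces for these two inputs.
result_shape = torch.broadcast_shapes(x.shape, y.shape)
print(result_shape)

# Broadcasting the 1-d y and the explicit (1, 1, 3) version should agree.
print(torch.equal(x * y, x * y3d))
```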