# PyTorch Broadcasting

In [None]:
#|hide
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(123)
torch.set_printoptions(precision=1, sci_mode=False, profile='short')

## Some simple tensors

Let's start with two simple tensors to experiment with:

In [None]:
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.tensor([2,4,5])

In [None]:
x.shape, y.shape

(torch.Size([2, 3]), torch.Size([3]))

x has 2 rows (first dimension) and 3 columns (second dimension); y is one-dimensional with 3 elements.

## Broadcasting - 2 dimensions

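PyTorch broadcasting compares the dimensions of two tensors starting from the rightmost (trailing) dimension and moving left. Two aligned dimensions are compatible when they are equal, when one of them is 1, or when one of them is missing. So: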

- x has dim: `2, 3`
- y has dim: `   3`

PyTorch aligns the rightmost dimensions (both 3). Because y has no dimension matching the first dimension of x (size 2), y is reused for each of the 2 rows of x. The comparison always starts from the innermost (rightmost) dimension.

In [None]:
x * y # element-wise multiplication

tensor([[ 2,  8, 15],
        [ 8, 20, 30]])
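To make the implicit repetition visible, here is a minimal sketch using `Tensor.expand_as`, which returns a view of y with the shape of x. Broadcasting behaves as if y had been expanded like this; no data is copied, which shows up as a stride of 0 along the new dimension.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.tensor([2, 4, 5])

# expand_as returns a view of y with the shape of x (2, 3);
# the single row of y is reused for every row of x, nothing is copied
y_expanded = y.expand_as(x)
print(y_expanded)
print(y_expanded.stride())  # (0, 1): stepping along dim 0 does not move in memory

# multiplying by the explicit view gives the same result as broadcasting
print(torch.equal(x * y, x * y_expanded))  # True
```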

What if y has two dimensions but the first dimension is 1?

- x has dim: `2, 3`
- y has dim: `1, 3`

In [None]:
y = torch.tensor([[2,4,5]])

In [None]:
x * y

tensor([[ 2,  8, 15],
        [ 8, 20, 30]])
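PyTorch can also report the broadcast shape without creating any tensors: `torch.broadcast_shapes` applies the same rules to shapes only. A quick check that both versions of y lead to the same result shape:

```python
import torch

# x is 2x3; y was first a 1-D tensor of length 3, then a 1x3 tensor
shape_1d = torch.broadcast_shapes((2, 3), (3,))
shape_2d = torch.broadcast_shapes((2, 3), (1, 3))
print(shape_1d)  # torch.Size([2, 3])
print(shape_2d)  # torch.Size([2, 3])
```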

This gives exactly the same result: the size-1 first dimension of y is broadcast (stretched) to size 2. What if y has exactly the same shape as x? In that case, no broadcasting is needed:

In [None]:
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.tensor([[2, 3, 4],
                  [5, 6, 7]])
x * y

tensor([[ 2,  6, 12],
        [20, 30, 42]])

What if x has 4 rows while y has 2? Will it still broadcast?

- x has dim: `4, 3`
- y has dim: `2, 3`

In [None]:
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 6, 5],
                  [4, 3, 2]])
y = torch.tensor([[2, 3, 4],
                  [5, 6, 7]])
try:
    x * y  # this does not work
except RuntimeError as e:
    # The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0
    print(e)

So this does not work. For broadcasting to succeed, every pair of aligned dimensions must either be equal, or one of them must be 1 or missing. Here the first dimensions are 4 and 2, which satisfies none of these conditions.
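The rule can be sketched in a few lines of plain Python. `broadcast_shape` below is a hypothetical helper written only for illustration (it is not part of PyTorch): it walks both shapes from the right, treats a missing dimension as size 1, and is compared against `torch.broadcast_shapes`, which, in the versions I am aware of, raises a `RuntimeError` for incompatible shapes.

```python
import itertools

import torch


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: compute the broadcast shape of two shapes by hand."""
    result = []
    # walk both shapes from the rightmost dimension to the left;
    # a missing dimension is filled in as size 1
    for a, b in itertools.zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            raise ValueError(f"incompatible sizes {a} and {b}")
    return tuple(reversed(result))


print(broadcast_shape((2, 3), (3,)))       # (2, 3)
print(broadcast_shape((4, 2, 3), (1, 3)))  # (4, 2, 3)

# the 4x3 vs 2x3 case from above fails in both the sketch and PyTorch
for check in (broadcast_shape, torch.broadcast_shapes):
    try:
        check((4, 3), (2, 3))
    except (ValueError, RuntimeError) as e:
        print(type(e).__name__, e)
```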

## Broadcasting - more dimensions

What if one of the tensors has more than 2 dimensions?

In [None]:
x = torch.tensor([[[1, 2, 3],
                  [4, 5, 6]],
                 [[7, 6, 5],
                  [4, 3, 2]],
                 [[2, 3, 4],
                  [5, 6, 7]],
                 [[6, 5, 4],
                  [3, 2, 1]]])
y = torch.tensor([2, 3, 4])

In [None]:
x.shape, y.shape

(torch.Size([4, 2, 3]), torch.Size([3]))

What happens if we multiply this $4 \times 2 \times 3$ tensor by a tensor of shape $3$?

- x has dim: `4, 2, 3`
- y has dim: `      3`

In [None]:
z = x * y
z

tensor([[[ 2,  6, 12],
         [ 8, 15, 24]],

        [[14, 18, 20],
         [ 8,  9,  8]],

        [[ 4,  9, 16],
         [10, 18, 28]],

        [[12, 15, 16],
         [ 6,  6,  4]]])

In [None]:
z.shape

torch.Size([4, 2, 3])

In [None]:
y = torch.tensor([[2,3,4]])
z2 = x * y
torch.allclose(z, z2)

True

This last result means that when:

- x has dim: `4, 2, 3`
- y has dim: `   1, 3`

the result is exactly the same.
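Put differently, a missing leading dimension behaves like a dimension of size 1. A small sketch that makes this explicit by indexing with `None`, which inserts a new axis of size 1 (like `unsqueeze`), so y goes from shape $3$ to $1 \times 3$ to $1 \times 1 \times 3$:

```python
import torch

x = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                  [[7, 6, 5], [4, 3, 2]],
                  [[2, 3, 4], [5, 6, 7]],
                  [[6, 5, 4], [3, 2, 1]]])  # shape (4, 2, 3)
y = torch.tensor([2, 3, 4])                 # shape (3,)

# None inserts a new axis of size 1 at that position
y_2d = y[None, :]        # shape (1, 3)
y_3d = y[None, None, :]  # shape (1, 1, 3)

results = [x * y, x * y_2d, x * y_3d]
print([tuple(r.shape) for r in results])  # [(4, 2, 3), (4, 2, 3), (4, 2, 3)]
print(all(torch.equal(results[0], r) for r in results))  # True
```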

Now let's try a case where both tensors need broadcasting: x has a dimension of size 1, and y is missing a dimension:

In [None]:
x = torch.tensor([
                 [[1, 2, 3]],
                 [[7, 6, 5]],
                 [[2, 3, 4]],
                 [[6, 5, 4]]
                ])
y = torch.tensor([[2, 1, 4],
                  [1, 3, 5],
                  [1, 2, 3]])

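Now:

- x has dim: `4, 1, 3`
- y has dim: `   3, 3`

Aligning from the right: the last dimensions match (3 and 3), the 1 in x is broadcast to 3, and the missing first dimension of y is broadcast to 4. So the result must be $4 \times 3 \times 3$.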
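In index notation, with x of shape $4 \times 1 \times 3$ and y of shape $3 \times 3$, each entry of the result should be $z_{ijk} = x_{i0k} \, y_{jk}$: the size-1 index of x is always 0, and y simply ignores $i$. A plain-loop sketch of that formula (slow, for understanding only), checked against `x * y`:

```python
import torch

x = torch.tensor([[[1, 2, 3]],
                  [[7, 6, 5]],
                  [[2, 3, 4]],
                  [[6, 5, 4]]])   # shape (4, 1, 3)
y = torch.tensor([[2, 1, 4],
                  [1, 3, 5],
                  [1, 2, 3]])     # shape (3, 3)

z = torch.empty(4, 3, 3, dtype=x.dtype)
for i in range(4):          # dim 0: y has no such dimension, so it is reused for every i
    for j in range(3):      # dim 1: x has size 1 here, so index 0 is reused for every j
        for k in range(3):  # dim 2: sizes match, ordinary element-wise product
            z[i, j, k] = x[i, 0, k] * y[j, k]

print(torch.equal(z, x * y))  # True
```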

In [None]:
x * y

tensor([[[ 2,  2, 12],
         [ 1,  6, 15],
         [ 1,  4,  9]],

        [[14,  6, 20],
         [ 7, 18, 25],
         [ 7, 12, 15]],

        [[ 4,  3, 16],
         [ 2,  9, 20],
         [ 2,  6, 12]],

        [[12,  5, 16],
         [ 6, 15, 20],
         [ 6, 10, 12]]])

How did we get here?

The last dimensions of x and y are equal (3), so the innermost elements are multiplied element-wise. Along the dimension with index 1, however, y has size 3 while x has size 1, so x is broadcast along that dimension: each row, such as `[1, 2, 3]`, is used 3 times to match up with y. Conceptually, the dimension of x with index 1 is first duplicated 3 times (PyTorch does not actually copy the data):

In [None]:
x2 = torch.tensor([
                 [[1, 2, 3],
                  [1, 2, 3],
                  [1, 2, 3]],
    
                 [[7, 6, 5],
                  [7, 6, 5],
                  [7, 6, 5]],
    
                 [[2, 3, 4],
                  [2, 3, 4],
                  [2, 3, 4]],
    
                 [[6, 5, 4],
                  [6, 5, 4],
                  [6, 5, 4]]
                ])
y2 = torch.tensor([[2, 1, 4],
                  [1, 3, 5],
                  [1, 2, 3]])

Now, they have dimensions:

- x has dim: `4, 3, 3`
- y has dim: `   3, 3`

To match completely, y is then repeated 4 times along a new first dimension:

In [None]:
y2 = torch.tensor([
                   [[2, 1, 4],
                    [1, 3, 5],
                    [1, 2, 3]],
    
                    [[2, 1, 4],
                     [1, 3, 5],
                     [1, 2, 3]],
    
                    [[2, 1, 4],
                     [1, 3, 5],
                     [1, 2, 3]],
    
                    [[2, 1, 4],
                     [1, 3, 5],
                     [1, 2, 3]],
                 ])

In [None]:
x2 * y2

tensor([[[ 2,  2, 12],
         [ 1,  6, 15],
         [ 1,  4,  9]],

        [[14,  6, 20],
         [ 7, 18, 25],
         [ 7, 12, 15]],

        [[ 4,  3, 16],
         [ 2,  9, 20],
         [ 2,  6, 12]],

        [[12,  5, 16],
         [ 6, 15, 20],
         [ 6, 10, 12]]])
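Writing x2 and y2 out by hand is only for illustration. A sketch of the same step with `torch.broadcast_tensors`, which returns both inputs expanded to their common shape as views (stride 0 along the broadcast dimensions, so nothing is duplicated in memory). For contrast, `Tensor.repeat` makes a real copy:

```python
import torch

x = torch.tensor([[[1, 2, 3]],
                  [[7, 6, 5]],
                  [[2, 3, 4]],
                  [[6, 5, 4]]])   # shape (4, 1, 3)
y = torch.tensor([[2, 1, 4],
                  [1, 3, 5],
                  [1, 2, 3]])     # shape (3, 3)

# both tensors are expanded (as views) to the common shape (4, 3, 3)
x_b, y_b = torch.broadcast_tensors(x, y)
print(x_b.shape, y_b.shape)  # torch.Size([4, 3, 3]) torch.Size([4, 3, 3])

# x_b has the same values as the hand-written x2; repeat builds an actual copy
print(torch.equal(x_b, x.repeat(1, 3, 1)))  # True

# the element-wise product of the expanded tensors equals the broadcast product
print(torch.equal(x_b * y_b, x * y))  # True
```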

## Matrix Multiplication

Instead of an element-wise operation, let's now try matrix multiplication with PyTorch. The shorthand syntax for this is the `@` operator, which calls `torch.matmul`.

Let's first multiply two one-dimensional tensors. In that case PyTorch computes the dot product, which is a zero-dimensional (scalar) tensor:

In [None]:
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = a @ b
c.shape, c

(torch.Size([]), tensor(32))

In [None]:
1*4 + 2*5 + 3*6

32
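The dot product is just an element-wise product followed by a sum, so it can be written with the operations from the first half of this post. A short sketch comparing that with `@` and `torch.dot`:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# the dot product is the sum of the element-wise products
print((a * b).sum())    # tensor(32)
print(torch.dot(a, b))  # tensor(32)
print(torch.equal(a @ b, (a * b).sum()))  # True
```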

If both arguments are two-dimensional, `@` performs ordinary matrix multiplication:

In [None]:
x = torch.tensor([[1, 2],
                  [3, 4]])
y = torch.tensor([[2, 3],
                  [4, 5]])
x @ y

tensor([[10, 13],
        [22, 29]])

In [None]:
1*2+2*4, 1*3+2*5, 3*2+4*4, 3*3+4*5

(10, 13, 22, 29)
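Matrix multiplication itself can be expressed with broadcasting. In this sketch, x gets shape $2 \times 2 \times 1$ and y gets shape $1 \times 2 \times 2$; their element-wise product broadcasts to $2 \times 2 \times 2$ with entry $(i, k, j)$ equal to $x_{ik} y_{kj}$, and summing over the middle axis $k$ gives $\sum_k x_{ik} y_{kj}$, which is exactly `x @ y`. This is for understanding only; `@` is the efficient way to do it.

```python
import torch

x = torch.tensor([[1, 2],
                  [3, 4]])
y = torch.tensor([[2, 3],
                  [4, 5]])

# x[:, :, None] has shape (2, 2, 1); y[None, :, :] has shape (1, 2, 2)
products = x[:, :, None] * y[None, :, :]  # broadcasts to (2, 2, 2)
result = products.sum(dim=1)              # sum over the shared index k

print(result)                      # same values as x @ y
print(torch.equal(result, x @ y))  # True
```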

If the first argument has only 1 dimension while the second has 2 dimensions, PyTorch prepends a dimension of size 1 to the first argument (so the shapes become $1 \times 2$ and $2 \times 2$, giving a $1 \times 2$ result), performs the matrix multiplication, and then removes the prepended dimension again:

In [None]:
x = torch.tensor([1, 2])
y = torch.tensor([[2, 3],
                  [4, 5]])
x @ y

tensor([10, 13])

In [None]:
x2 = torch.tensor([[1, 2]])
x2 @ y  # explicit 1x2 @ 2x2: the leading dimension is kept here; for a 1-D x, PyTorch removes it again

tensor([[10, 13]])
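The prepend-and-remove step can be written out explicitly with `unsqueeze` and `squeeze`. A short sketch showing it gives the same result as `x @ y` with a one-dimensional x:

```python
import torch

x = torch.tensor([1, 2])
y = torch.tensor([[2, 3],
                  [4, 5]])

x_row = x.unsqueeze(0)       # shape (1, 2): prepend a dimension of size 1
product = x_row @ y          # shape (1, 2): ordinary matrix multiplication
result = product.squeeze(0)  # shape (2,): remove the prepended dimension

print(result.shape)                # torch.Size([2])
print(torch.equal(result, x @ y))  # True
```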

Conversely, if the first tensor is two-dimensional and the second is one-dimensional, PyTorch computes a [matrix-vector product](https://mathinsight.org/matrix_vector_multiplication):

In [None]:
x = torch.tensor([[2, 3],
                  [4, 5]])
y = torch.tensor([1, 2])
x @ y

tensor([ 8, 14])

In [None]:
2*1+3*2, 4*1+5*2

(8, 14)
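The matrix-vector product can also be read as broadcasting: y (shape $2$) is broadcast across the rows of x (shape $2 \times 2$), the two are multiplied element-wise, and each row is summed. A minimal sketch:

```python
import torch

x = torch.tensor([[2, 3],
                  [4, 5]])
y = torch.tensor([1, 2])

# y broadcasts to [[1, 2], [1, 2]]; summing each row gives the dot product
# of that row of x with y
result = (x * y).sum(dim=1)

print(result)                      # same values as x @ y
print(torch.equal(result, x @ y))  # True
```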

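This is the same as treating y as a column vector (shape $2 \times 1$), doing a matrix multiplication, and then removing the appended dimension of size 1 again. Transposing with `y.T` would not achieve this: it leaves a one-dimensional tensor unchanged.

In [None]:
(x @ y.unsqueeze(1)).squeeze(1)

tensor([ 8, 14])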
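Broadcasting and matrix multiplication also meet in `torch.matmul` (and `@`) for tensors with more than two dimensions: according to the `torch.matmul` documentation linked below, the last two dimensions are treated as matrices and the leading (batch) dimensions are broadcast with the same rules as above. A sketch, with random shapes chosen only for illustration:

```python
import torch

torch.manual_seed(0)
a = torch.randn(4, 2, 3)  # a batch of 4 matrices, each 2x3
b = torch.randn(3, 5)     # a single 3x5 matrix, broadcast over the batch

c = a @ b
print(c.shape)  # torch.Size([4, 2, 5])

# same as multiplying each matrix in the batch by b separately
expected = torch.stack([a[i] @ b for i in range(4)])
print(torch.allclose(c, expected))  # True

# batch dimensions of sizes 4 and 1 broadcast to 4: (4, 2, 3) @ (1, 3, 5) -> (4, 2, 5)
print((a @ b.unsqueeze(0)).shape)  # torch.Size([4, 2, 5])
```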

## Resources

- [PyTorch tutorial on tensor broadcasting](https://pytorch.org/tutorials/beginner/introyt/tensors_deeper_tutorial.html#in-brief-tensor-broadcasting)
- [PyTorch broadcasting semantics](https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics)
- [PyTorch matrix multiplication](https://pytorch.org/docs/stable/generated/torch.matmul.html)