# Broadcasting

Broadcasting is a mechanism that lets PyTorch perform arithmetic on tensors of different shapes. A common case is a smaller tensor and a larger tensor, where the smaller tensor is reused multiple times to operate on the larger one.

In [7]:
import torch

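Broadcasting can combine a $1 \times n$ tensor with an $m \times n$ tensor: the $1 \times n$ tensor is broadcast to $m \times n$, and the operation is then applied elementwise.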

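Likewise, it can combine an $m \times 1$ tensor with an $m \times n$ tensor: the $m \times 1$ tensor is broadcast to $m \times n$, and the operation is then applied elementwise.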

In [11]:
x = torch.arange(15).reshape(5, 3)
print(x)

y = torch.tensor([1, 0, 1])
print(y)

a = x + y
print(a)

z = torch.tensor([1, 1, 1, 1, 1]).reshape(5, -1)
print(z)

b = x * z
print(b)

tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]])
tensor([1, 0, 1])
tensor([[ 1,  1,  3],
        [ 4,  4,  6],
        [ 7,  7,  9],
        [10, 10, 12],
        [13, 13, 15]])
tensor([[1],
        [1],
        [1],
        [1],
        [1]])
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]])
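
To make the implicit step visible, the sketch below repeats the two operations above but expands the smaller tensors explicitly with `Tensor.expand` before the arithmetic. The names `y_full`, `z_full`, `same_add`, and `same_mul` are introduced here for illustration only; the point is that the broadcast result and the explicitly expanded result should agree.

```python
import torch

x = torch.arange(15).reshape(5, 3)
y = torch.tensor([1, 0, 1])                      # shape (3,)
z = torch.tensor([1, 1, 1, 1, 1]).reshape(5, 1)  # shape (5, 1)

# Explicitly expand the smaller tensors to x's shape.
# expand returns a view, so no data is copied.
y_full = y.expand(5, 3)
z_full = z.expand(5, 3)

# Compare the implicit (broadcast) results with the explicit ones.
same_add = torch.equal(x + y, x + y_full)
same_mul = torch.equal(x * z, x * z_full)
print(y_full.shape, z_full.shape, same_add, same_mul)
```

In practice the explicit expansion is unnecessary; it is shown only to illustrate what broadcasting does conceptually.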


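A $1 \times n$ tensor and an $m \times 1$ tensor can also be broadcast against each other. Both are expanded, $1 \times n \rightarrow m \times n$ and $m \times 1 \rightarrow m \times n$, and the operation is then applied elementwise.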

In [10]:
x = torch.arange(5).reshape(1, -1)
y = torch.arange(6).reshape(-1, 1)
z = x + y
print(z)

tensor([[0, 1, 2, 3, 4],
        [1, 2, 3, 4, 5],
        [2, 3, 4, 5, 6],
        [3, 4, 5, 6, 7],
        [4, 5, 6, 7, 8],
        [5, 6, 7, 8, 9]])
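
As a quick check of the rule above, the following sketch uses `torch.broadcast_shapes` (assumed to be available; it was added in PyTorch 1.8) to predict the result shape without computing the sum. The variable name `predicted` is introduced here for illustration.

```python
import torch

x = torch.arange(5).reshape(1, -1)   # shape (1, 5)
y = torch.arange(6).reshape(-1, 1)   # shape (6, 1)

# Predict the broadcast shape from the two input shapes alone.
predicted = torch.broadcast_shapes(x.shape, y.shape)

# Each entry combines one column value and one row value:
# z[i, j] = x[0, j] + y[i, 0]
z = x + y
print(predicted, z.shape)
```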


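If the two tensors have different numbers of dimensions, broadcasting first pads the shape of the lower-dimensional tensor with leading dimensions of size 1 and then expands it to the higher-dimensional shape: $(m,) \rightarrow (1, m) \rightarrow (\ldots, m)$ and $(m, n) \rightarrow (1, m, n) \rightarrow (\ldots, m, n)$.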

In [17]:
# x has shape (2, 3)
x = torch.linspace(1, 6, 6, dtype=torch.int32).reshape(2, 3)
print(x)

# y has shape (4, 1, 1)
y = torch.tensor([1, 10, 1000, 10000]).reshape(-1, 1, 1)
print(y)

# x is broadcast as: (2, 3) -> (1, 2, 3) -> (4, 2, 3)
# y is broadcast as: (4, 1, 1) -> (4, 2, 3)
z = y * x
print(z)

tensor([[1, 2, 3],
        [4, 5, 6]], dtype=torch.int32)
tensor([[[    1]],

        [[   10]],

        [[ 1000]],

        [[10000]]])
tensor([[[    1,     2,     3],
         [    4,     5,     6]],

        [[   10,    20,    30],
         [   40,    50,    60]],

        [[ 1000,  2000,  3000],
         [ 4000,  5000,  6000]],

        [[10000, 20000, 30000],
         [40000, 50000, 60000]]])
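
The shape alignment in the comments above can be reproduced by hand. This is a minimal sketch, assuming `unsqueeze` and `expand` are an acceptable stand-in for what broadcasting does internally; `x3`, `x_full`, `y_full`, and `manual` are illustrative names, and `x` is built with `torch.arange` to hold the same values as in the cell above.

```python
import torch

# Same values as x above, built with arange: shape (2, 3)
x = torch.arange(1, 7, dtype=torch.int32).reshape(2, 3)
y = torch.tensor([1, 10, 1000, 10000]).reshape(-1, 1, 1)  # shape (4, 1, 1)

# Step 1: prepend a size-1 dimension so both tensors have three dimensions.
x3 = x.unsqueeze(0)            # (2, 3) -> (1, 2, 3)

# Step 2: stretch every size-1 dimension to the matching size.
x_full = x3.expand(4, 2, 3)    # (1, 2, 3) -> (4, 2, 3)
y_full = y.expand(4, 2, 3)     # (4, 1, 1) -> (4, 2, 3)

# The manual product should match the broadcast product y * x.
manual = y_full * x_full
print(manual.shape, torch.equal(manual, y * x))
```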
