In [1]:
import torch
import math
import numpy as np

# Arithmetic operations

## I. Matrix product and inner product

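### I.1 Matrix product
1. **torch.mm()**: the most basic matrix multiplication of two 2-D tensors; it does not broadcast
   - torch.mv() does matrix-vector multiplication
2. **torch.mul()**: elementwise multiplication with broadcasting
   - x.mul(x) is equivalent to x * x
3. **torch.matmul()**: tensor multiplication with broadcasting, also written with the `@` operator
   - If both arguments are 1-D, it computes the dot product
   - If both arguments are 2-D, it computes the plain matrix product, equivalent to torch.mm()
   - If the first argument is a matrix and the second a vector, it computes the matrix-vector product, equivalent to torch.mv()
   - If the first argument is a vector and the second a matrix, a 1 is prepended to the vector's shape for the matrix multiply and removed afterwards: $(m)\times(m\times p)=(p)$
   - If at least one argument has more than 2 dims and the other has at least 1 dim, it computes batched matrix products
     - If the other argument is 2-D or higher, its last 2 dims are the matrix dims; the remaining (batch) dims are broadcast, and a batched matrix product is taken with the higher-dimensional tensor.
       - E.g. t1 has shape (j,1,n,m), so (n,m) are its matrix dims; t2 has shape (k,m,p), so (m,p) are its matrix dims. Broadcasting is applied to the batch dims (j,1) and (k), giving
       $$(j\times k\times n\times m)\times(j\times k\times m\times p)=(j\times k\times n\times p)$$
     - If the other argument is 1-D, it is treated as a vector: a 1 is appended to its shape for the multiply and removed afterwards.
       - E.g. t1 has shape (j,1,n,m), so (n,m) are its matrix dims; t2 has shape (m), so (m) is the vector dim. The batch dims (j,1) are kept, giving
       $$(j\times 1\times n\times m)\times(m)=(j\times 1\times n)$$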
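A minimal shape check of the matmul dispatch rules listed above. The sizes `j, k, n, m, p` are arbitrary illustrative values, and float64 is used only so that the `torch.allclose` comparisons are not disturbed by rounding; this is a sketch, not an exhaustive test of `torch.matmul`.

```python
import torch

# Illustrative sizes (assumption: any positive values would do)
j, k, n, m, p = 2, 3, 4, 5, 6
dt = torch.float64

# 1-D @ 1-D -> dot product, a 0-dim result
v1, v2 = torch.randn(m, dtype=dt), torch.randn(m, dtype=dt)
r = torch.matmul(v1, v2)
assert torch.allclose(r, torch.dot(v1, v2))

# 2-D @ 2-D -> plain matrix product, same as torch.mm
A, B = torch.randn(n, m, dtype=dt), torch.randn(m, p, dtype=dt)
assert torch.allclose(A @ B, torch.mm(A, B))

# 2-D @ 1-D -> matrix-vector product, same as torch.mv
x = torch.randn(m, dtype=dt)
assert torch.allclose(A @ x, torch.mv(A, x))

# 1-D @ 2-D -> vector becomes (1, m), multiply, then the prepended 1 is removed
y = torch.matmul(x, B)
assert torch.allclose(y, (x.unsqueeze(0) @ B).squeeze(0))

# Batched: batch dims (j, 1) and (k,) broadcast to (j, k)
t1 = torch.randn(j, 1, n, m, dtype=dt)
t2 = torch.randn(k, m, p, dtype=dt)
out = t1 @ t2
# Each output matrix is an ordinary 2-D product of the broadcast operands
assert torch.allclose(out[1, 2], t1[1, 0] @ t2[2])

# Batched with a 1-D second operand: x acts as (m, 1), then the 1 is removed
out_v = t1 @ x
print(r.shape, y.shape, out.shape, out_v.shape)
```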

In [30]:
a = torch.ones(16, dtype=torch.int).view(2, 1, 2, 4)
b = torch.arange(12, dtype=torch.int).view(4, 3)
a, b

(tensor([[[[1, 1, 1, 1],
           [1, 1, 1, 1]]],
 
 
         [[[1, 1, 1, 1],
           [1, 1, 1, 1]]]], dtype=torch.int32),
 tensor([[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]], dtype=torch.int32))

In [31]:
torch.matmul(a, b).shape, torch.matmul(a, b)

(torch.Size([2, 1, 2, 3]),
 tensor([[[[18, 22, 26],
           [18, 22, 26]]],
 
 
         [[[18, 22, 26],
           [18, 22, 26]]]], dtype=torch.int32))

In [32]:
c = torch.arange(4, dtype=torch.int)
torch.matmul(a, c).shape, torch.matmul(a, c)

(torch.Size([2, 1, 2]),
 tensor([[[6, 6]],
 
         [[6, 6]]], dtype=torch.int32))

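### I.2 Broadcasting
Rule: compare the size of each dim in <font color=blue>**last to first**</font> order; two tensors can be broadcast together if, for every dim,
 - the sizes are equal; or
 - the sizes differ, but one of the operands has size 1 in that dim; or
 - the dim does not exist in one of the tensors.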
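The last-to-first rule can be written out as a small pure-Python function. `broadcast_shape` is a hypothetical helper for illustration only; it treats a missing dim as size 1, which is how the "dim does not exist" case behaves, and its result can be compared against PyTorch's own `torch.broadcast_shapes`.

```python
import torch
from itertools import zip_longest


def broadcast_shape(s1, s2):
    """Hypothetical helper: apply the last-to-first broadcasting rule to two shapes."""
    result = []
    # Walk both shapes from the last dim to the first; a missing dim acts like size 1
    for d1, d2 in zip_longest(reversed(s1), reversed(s2), fillvalue=1):
        if d1 == d2 or d2 == 1:
            result.append(d1)
        elif d1 == 1:
            result.append(d2)
        else:
            raise ValueError(f"not broadcastable: {tuple(s1)} vs {tuple(s2)}")
    return tuple(reversed(result))


print(broadcast_shape((2, 3, 2), (3, 1)))              # (2, 3, 2)
print(tuple(torch.broadcast_shapes((2, 3, 2), (3, 1))))  # (2, 3, 2)

try:
    broadcast_shape((4, 3, 2), (4, 3))
except ValueError as e:
    print(e)  # not broadcastable: (4, 3, 2) vs (4, 3)
```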

In [6]:
a =     torch.ones(2, 3, 2)
b = a * torch.rand(   3, 2) # 3rd & 2nd dims identical to a, dim 1 absent
print(b)

c = a * torch.rand(   3, 1) # 3rd dim = 1, 2nd dim identical to a
print(c)

d = a * torch.rand(   1, 2) # 3rd dim identical to a, 2nd dim = 1
print(d)

tensor([[[0.8827, 0.0272],
         [0.2219, 0.5566],
         [0.1124, 0.2677]],

        [[0.8827, 0.0272],
         [0.2219, 0.5566],
         [0.1124, 0.2677]]])
tensor([[[0.1467, 0.1467],
         [0.0363, 0.0363],
         [0.1443, 0.1443]],

        [[0.1467, 0.1467],
         [0.0363, 0.0363],
         [0.1443, 0.1443]]])
tensor([[[0.0023, 0.0992],
         [0.0023, 0.0992],
         [0.0023, 0.0992]],

        [[0.0023, 0.0992],
         [0.0023, 0.0992],
         [0.0023, 0.0992]]])


In [7]:
# Typical errors
a =     torch.ones(4, 3, 2)
# b = a * torch.rand(4, 3)    # dimensions must match last-to-first
# c = a * torch.rand(   2, 3) # both 3rd & 2nd dims different
# d = a * torch.rand((0, ))   # last dim is 2 vs 0, and neither is 1
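Uncommenting any of the three lines above raises a `RuntimeError`. A minimal sketch that runs each case and collects the first line of the error message, so the failures can be seen without stopping the notebook:

```python
import torch

a = torch.ones(4, 3, 2)
bad_operands = {
    "(4, 3)": torch.rand(4, 3),  # last dims 2 vs 3 differ
    "(2, 3)": torch.rand(2, 3),  # 2 vs 3 and 3 vs 2 differ
    "(0,)": torch.rand((0,)),    # last dims 2 vs 0 differ
}

errors = {}
for name, t in bad_operands.items():
    try:
        a * t
    except RuntimeError as e:
        # Keep only the first line of PyTorch's message
        errors[name] = str(e).splitlines()[0]

for name, msg in errors.items():
    print(name, "->", msg)
```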

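### I.3 Inner product
#### How inner product and matrix product differ in their multiplication rules:
 - For inner, the two matrices must have the same length in their second (last) dim: $$(n,k)\times(m,k)=(n,m)$$
 - For matrix multiplication, the length of the first matrix's dim 2 must equal the length of the second matrix's dim 1: $$(n,k)\times(k,m)=(n,m)$$
 - In addition, matrix multiplication broadcasts, while inner product does not involve broadcasting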
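For 2-D inputs the contrast above means `torch.inner(A, B)` equals `A @ B.T`: inner contracts the last dim of both operands, while matmul contracts the first operand's last dim with the second operand's first dim. A small sketch with integer-valued float tensors (chosen so the results are exact):

```python
import torch

A = torch.arange(6.).view(2, 3)   # (n, k) = (2, 3)
B = torch.arange(12.).view(4, 3)  # (m, k) = (4, 3)

# inner: (2, 3) x (4, 3) -> (2, 4), contracting the shared last dim
inner_out = torch.inner(A, B)

# matmul needs (n, k) x (k, m), so B has to be transposed first
mm_out = A @ B.T

# Without the transpose the inner dims 3 vs 4 do not match
try:
    A @ B
    matmul_failed = False
except RuntimeError:
    matmul_failed = True

print(inner_out.shape, torch.equal(inner_out, mm_out), matmul_failed)
```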

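1. torch.dot(): only supports the inner product of two 1-D tensors
2. torch.inner():
   - If either argument is a scalar, it multiplies, equivalent to torch.mul()
   - If neither argument is a scalar, the last dims of the two arguments must have the same length.
     - For two 1-D tensors, it is the same as dot
     - If at least one argument has 2 or more dims, the inner product is taken over the last dim, and <font color=blue>**the leading dims are not broadcast**</font>; they are concatenated instead, so the output shape is `input.shape[:-1] + other.shape[:-1]`: $$
     (j,n,k)\times(l,m,k)=(j,n,l,m) \\
     (j,n,k)\times(l,k)=(j,n,l)
     $$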
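The shape rule above can be checked directly, and the non-scalar case can be reproduced with `torch.tensordot` by contracting the last dim of each operand. `inner_via_tensordot` is a hypothetical helper written for illustration, and the sizes `j, n, l, m, k` are arbitrary; this is a sketch of the semantics, not how PyTorch implements `torch.inner`.

```python
import torch


def inner_via_tensordot(x, y):
    """Hypothetical helper: contract the last dim of x with the last dim of y, no broadcasting."""
    if x.dim() == 0 or y.dim() == 0:
        return x * y  # scalar case: plain multiplication
    return torch.tensordot(x, y, dims=([x.dim() - 1], [y.dim() - 1]))


j, n, l, m, k = 2, 3, 4, 5, 6  # illustrative sizes
x = torch.randn(j, n, k, dtype=torch.float64)
y = torch.randn(l, m, k, dtype=torch.float64)
y2 = torch.randn(l, k, dtype=torch.float64)

out = torch.inner(x, y)
print(out.shape)               # torch.Size([2, 3, 4, 5]), i.e. x.shape[:-1] + y.shape[:-1]
print(torch.inner(x, y2).shape)  # torch.Size([2, 3, 4])

# Same values as contracting the last dims with tensordot
print(torch.allclose(out, inner_via_tensordot(x, y)))

# Scalar operand: equivalent to elementwise multiplication
two = torch.tensor(2.0, dtype=torch.float64)
print(torch.allclose(torch.inner(x, two), x * 2))
```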

In [37]:
a = torch.ones(2, 3, dtype=torch.int)
b = torch.arange(24, dtype=torch.int).view(2, 4, 3)
b

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],

        [[12, 13, 14],
         [15, 16, 17],
         [18, 19, 20],
         [21, 22, 23]]], dtype=torch.int32)

In [38]:
torch.inner(a, b).shape, torch.inner(a, b)

(torch.Size([2, 2, 4]),
 tensor([[[ 3, 12, 21, 30],
          [39, 48, 57, 66]],
 
         [[ 3, 12, 21, 30],
          [39, 48, 57, 66]]], dtype=torch.int32))

In [None]:
torch.inner(a, torch.tensor(2))  # scalar operand: equivalent to torch.mul(a, 2)