In [61]:
import torch

In [62]:
# Seed the RNG so that Restart & Run All reproduces the same operands (and outputs).
SEED = 42
torch.manual_seed(SEED)
# a is 5x784 and b is 784x10, so every a @ b below is a 5x10 result.
a = torch.randn(5, 784)
b = torch.randn(784, 10)

In [63]:
# Peek at both operands; PyTorch elides the middle of large tensors in the repr.
(a, b)

(tensor([[-0.7907,  1.0112, -0.8297,  ..., -0.3511,  0.4576,  0.1291],
         [ 0.1918, -0.7470, -0.3927,  ...,  0.0692,  0.4960,  0.5979],
         [-0.4534,  1.0865,  0.4031,  ..., -0.3946,  1.5603,  2.0023],
         [-0.1760, -2.7591,  1.0126,  ...,  0.7163,  0.5890,  0.1119],
         [-0.6466, -0.9701, -0.8949,  ..., -1.4830, -0.2630,  0.5843]]),
 tensor([[-0.4702, -1.0919, -1.0979,  ..., -0.2159,  0.2466, -1.2393],
         [-0.8525,  0.4899, -0.5190,  ..., -0.3861, -0.0492, -0.4300],
         [ 1.5356, -0.0636,  0.2483,  ...,  1.1509,  0.3055,  0.5325],
         ...,
         [-0.8914,  1.2675, -0.4003,  ..., -0.7347, -0.4699,  0.0584],
         [ 2.1745,  0.9923,  1.6917,  ..., -0.8474, -0.4459, -1.7510],
         [-0.6510, -0.5975, -0.6276,  ...,  0.9200,  0.6085, -0.3840]]))

In [160]:
%timeit -n 10 a@b

29.6 µs ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Straightforward loops

In [69]:
def slow_matmul(a, b):
    """Matrix product with three explicit Python loops (the slow reference version).

    Every output cell is accumulated one scalar product at a time, so the work
    is rows_a * cols_b * common_dim individual tensor-indexing operations.

    Args:
        a: tensor of shape (rows_a, common_dim); presumably 2-D.
        b: tensor of shape (common_dim, cols_b).

    Returns:
        Tensor of shape (rows_a, cols_b) on ``a``'s device, whose dtype is the
        promotion of ``a.dtype`` and ``b.dtype``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    # The inner dimensions must agree, otherwise a and b can't be multiplied.
    assert a.shape[1] == b.shape[0], (
        f"inner dimensions differ: a is {tuple(a.shape)}, b is {tuple(b.shape)}"
    )
    rows_a = a.shape[0]
    common_dim = a.shape[1]
    cols_b = b.shape[1]
    # Allocate with the inputs' dtype/device: a bare torch.zeros(...) uses the
    # global default dtype (float32 unless changed) on the CPU, which silently
    # truncates float64 results and turns integer products into floats.
    out_dtype = torch.promote_types(a.dtype, b.dtype)
    output = torch.zeros(rows_a, cols_b, dtype=out_dtype, device=a.device)
    for i in range(rows_a):
        for j in range(cols_b):
            for k in range(common_dim):
                output[i, j] += a[i, k] * b[k, j]
    return output

In [70]:
# Floats can't be compared with ==: a@b may accumulate the products in a
# different order, so allow a small tolerance. Assert so that a mismatch stops
# "Run All", and still display the boolean result.
slow_ok = torch.allclose(slow_matmul(a, b), a @ b, rtol=1e-5, atol=1e-5)
assert slow_ok, "slow_matmul(a, b) does not match a @ b"
slow_ok

True

In [71]:
%time slow_matmul(a,b)

CPU times: user 902 ms, sys: 72 µs, total: 902 ms
Wall time: 902 ms


tensor([[-25.3802, -39.0457,  -2.8763,   8.0384,  17.4018, -23.1719,  22.7924,
         -30.0939,  34.7306,  -8.4432],
        [ 14.3616,  -1.3186,   9.9665, -15.9873, -20.0919,  18.0664,  13.3102,
          19.4809,  21.0659,   7.5077],
        [-41.6030, -30.6106,   9.2916, -12.5558,  22.2696,  -5.0116,  22.1513,
          -4.4115,  39.1963, -63.4013],
        [ 18.7230,  15.5820,  25.9895, -51.9905, -10.6715, -19.1673, -15.5803,
           0.2417,  36.0474,  27.3397],
        [-15.7743, -14.9407,   5.5782,  39.4856,  21.0562,  29.1491,  16.8582,
          -3.1518, -29.5479,   3.2722]])

## With elementwise multiplication 

In [138]:
def elementwise_matmul(a, b):
    """Matrix product where the innermost loop is replaced by a vectorised dot product.

    Each output cell (i, j) is the sum of the elementwise product of row i of
    ``a`` and column j of ``b``, so only two Python loops remain.

    Args:
        a: tensor of shape (rows_a, common_dim); presumably 2-D.
        b: tensor of shape (common_dim, cols_b).

    Returns:
        Tensor of shape (rows_a, cols_b) on ``a``'s device, whose dtype is the
        promotion of ``a.dtype`` and ``b.dtype``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    # The inner dimensions must agree, otherwise a and b can't be multiplied.
    assert a.shape[1] == b.shape[0], (
        f"inner dimensions differ: a is {tuple(a.shape)}, b is {tuple(b.shape)}"
    )
    rows_a = a.shape[0]
    cols_b = b.shape[1]
    # Allocate with the inputs' dtype/device: a bare torch.zeros(...) uses the
    # global default dtype (float32 unless changed) on the CPU, which silently
    # truncates float64 results and turns integer products into floats.
    out_dtype = torch.promote_types(a.dtype, b.dtype)
    output = torch.zeros(rows_a, cols_b, dtype=out_dtype, device=a.device)
    for i in range(rows_a):
        for j in range(cols_b):
            output[i, j] = (a[i, :] * b[:, j]).sum()
    return output

In [139]:
# Compare within a tolerance (float sums depend on accumulation order); assert
# so that a mismatch stops "Run All", and still display the boolean result.
elementwise_ok = torch.allclose(elementwise_matmul(a, b), a @ b, rtol=1e-5, atol=1e-5)
assert elementwise_ok, "elementwise_matmul(a, b) does not match a @ b"
elementwise_ok

True

In [140]:
%timeit -n 10 elementwise_matmul(a,b)

1.59 ms ± 662 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## With broadcasting

In [141]:
def broadcasting_matmul(a, b):
    """Matrix product computed one output row at a time via broadcasting.

    Row i of ``a`` is reshaped into a (common_dim, 1) column and multiplied
    against all of ``b`` at once; summing over the first axis yields row i of
    the result, so only one Python loop remains.

    Args:
        a: tensor of shape (rows_a, common_dim); presumably 2-D.
        b: tensor of shape (common_dim, cols_b).

    Returns:
        Tensor of shape (rows_a, cols_b) on ``a``'s device, whose dtype is the
        promotion of ``a.dtype`` and ``b.dtype``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    # The inner dimensions must agree, otherwise a and b can't be multiplied.
    assert a.shape[1] == b.shape[0], (
        f"inner dimensions differ: a is {tuple(a.shape)}, b is {tuple(b.shape)}"
    )
    rows_a = a.shape[0]
    cols_b = b.shape[1]
    # Allocate with the inputs' dtype/device: a bare torch.zeros(...) uses the
    # global default dtype (float32 unless changed) on the CPU, which silently
    # truncates float64 results and turns integer products into floats.
    out_dtype = torch.promote_types(a.dtype, b.dtype)
    output = torch.zeros(rows_a, cols_b, dtype=out_dtype, device=a.device)
    for i in range(rows_a):
        # (common_dim, 1) * (common_dim, cols_b) -> (common_dim, cols_b);
        # summing over dim 0 leaves the (cols_b,) row of the product.
        output[i] = (a[i, :].unsqueeze(-1) * b).sum(dim=0)
    return output

In [142]:
# Compare within a tolerance (float sums depend on accumulation order); assert
# so that a mismatch stops "Run All", and still display the boolean result.
broadcasting_ok = torch.allclose(broadcasting_matmul(a, b), a @ b, rtol=1e-5, atol=1e-5)
assert broadcasting_ok, "broadcasting_matmul(a, b) does not match a @ b"
broadcasting_ok

True

In [143]:
%timeit -n 10 broadcasting_matmul(a,b)

456 µs ± 113 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## With einsum

In [76]:
def einsum_matmul(a, b):
    """Matrix product written as an Einstein-summation contraction.

    The shared index ``n`` appears in both operands but not in the output, so
    ``torch.einsum`` sums over it; the result is indexed by (row of ``a``,
    column of ``b``). No Python loops are involved.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    # Inner dimensions must match, otherwise a and b can't be multiplied.
    assert a.size(1) == b.size(0)
    return torch.einsum("mn,np->mp", a, b)

In [77]:
# Compare within a tolerance (float sums depend on accumulation order); assert
# so that a mismatch stops "Run All", and still display the boolean result.
einsum_ok = torch.allclose(einsum_matmul(a, b), a @ b, rtol=1e-5, atol=1e-5)
assert einsum_ok, "einsum_matmul(a, b) does not match a @ b"
einsum_ok

True

In [147]:
%timeit -n 10 einsum_matmul(a,b)

115 µs ± 37.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
