In [4]:
import torch
import timeit
import torch.utils.benchmark as benchmark

In [20]:

def batched_dot_mul_sum(a, b):
    '''Batched dot product via an elementwise product followed by a reduction.

    ``a`` and ``b`` are multiplied elementwise (with the usual tensor
    broadcasting) and the products are summed over the last dimension, so each
    output element is the dot product of the corresponding trailing vectors.
    '''
    elementwise = a * b
    return elementwise.sum(dim=-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to bmm.

    Every leading dimension is treated as a batch dimension and the last
    dimension as the vector dimension. The result has shape ``a.shape[:-1]``,
    the same shape ``batched_dot_mul_sum`` returns for equal-shaped inputs.

    Args:
        a, b: tensors expected to share the shape ``(*batch, d)``.

    Returns:
        Tensor of shape ``(*batch,)`` holding the dot product of each pair of
        trailing vectors.
    '''
    batch_shape = a.shape[:-1]
    # Explicit batch count instead of -1: reshape cannot infer -1 when d == 0.
    # torch.Size([]).numel() == 1, so 1-D inputs become a batch of one.
    n = batch_shape.numel()
    # (n, 1, d) @ (n, d, 1) -> (n, 1, 1): one dot product per batch entry.
    a3 = a.reshape(n, 1, a.shape[-1])
    b3 = b.reshape(n, b.shape[-1], 1)
    # Restore the original batch dimensions rather than returning a flat (n,) vector.
    return torch.bmm(a3, b3).reshape(batch_shape)


# Input for benchmarking: 100000 vectors of length 64.
# Seed the RNG so that re-running the notebook benchmarks the same data.
torch.manual_seed(0)
x = torch.randn(100000, 64)

# Compute each result once and reuse it for the printout and the checks.
mul_sum_out = batched_dot_mul_sum(x, x)
bmm_out = batched_dot_bmm(x, x)
print(mul_sum_out.size())

# Ensure that both functions compute the same output. The shapes are compared
# explicitly because allclose broadcasts and could hide a shape mismatch.
assert mul_sum_out.shape == bmm_out.shape, (mul_sum_out.shape, bmm_out.shape)
assert mul_sum_out.allclose(bmm_out), 'mul_sum and bmm implementations disagree'



torch.Size([100000])


In [19]:
# Baseline: time both implementations with the standard-library timeit module.
# The functions are passed in through ``globals`` (instead of
# ``setup='from __main__ import ...'``) so this cell also works when the code
# is not running as the ``__main__`` module.
N_RUNS = 1000  # executions per implementation; the report is the per-run average

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    globals={'x': x, 'batched_dot_mul_sum': batched_dot_mul_sum})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    globals={'x': x, 'batched_dot_bmm': batched_dot_bmm})

# Timer.timeit(n) returns the total seconds for n executions; divide by n to get
# the average per run, then convert to microseconds.
print(f'mul_sum(x, x):  {t0.timeit(N_RUNS) / N_RUNS * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(N_RUNS) / N_RUNS * 1e6:>5.1f} us')

mul_sum(x, x):  6589.8 us
bmm(x, x):      1606.7 us


In [26]:

# Same comparison with torch.utils.benchmark.Timer. It takes the same
# stmt/setup/globals arguments as timeit.Timer, but timeit(n) returns a
# Measurement whose printout reports the time per run (here on the default
# single thread).
t0, t1 = (
    benchmark.Timer(stmt=stmt, setup=setup, globals={'x': x})
    for stmt, setup in (
        ('batched_dot_mul_sum(x, x)', 'from __main__ import batched_dot_mul_sum'),
        ('batched_dot_bmm(x, x)', 'from __main__ import batched_dot_bmm'),
    )
)

print(t0.timeit(100))
print(t1.timeit(100))

<torch.utils.benchmark.utils.common.Measurement object at 0x7ffdafc01ba8>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  6.35 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7ffdafc01b70>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  1.76 ms
  1 measurement, 100 runs , 1 thread


In [27]:
# Re-run the comparison using as many threads as torch.get_num_threads()
# reports, rather than benchmark.Timer's single-thread default.
num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

# Settings shared by both timers; only stmt/setup/sub_label differ per implementation.
common_kwargs = dict(num_threads=num_threads, label='Multithreaded batch dot')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    sub_label='Implemented using mul and sum',
    **common_kwargs)

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x},
    sub_label='Implemented using bmm',
    **common_kwargs)

print(t0.timeit(100))
print(t1.timeit(100))

Benchmarking on 6 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7ffdafc01c18>
Multithreaded batch dot: Implemented using mul and sum
setup: from __main__ import batched_dot_mul_sum
  6.32 ms
  1 measurement, 100 runs , 6 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7ffdafc01b38>
Multithreaded batch dot: Implemented using bmm
setup: from __main__ import batched_dot_bmm
  1.72 ms
  1 measurement, 100 runs , 6 threads
