In [1]:
import torch
import torch.utils.benchmark as benchmark

In [2]:
# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Additional info when using CUDA.
# The memory query lives inside this branch on purpose: torch.cuda.mem_get_info()
# needs a working CUDA runtime and raises on CPU-only machines, so calling it
# before the device check would crash the cell before the CPU fallback applies.
if device.type == 'cuda':
    torch_mem_info = torch.cuda.mem_get_info()  # (free_bytes, total_bytes) of the device
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    # Dividing by 1024**3 converts bytes to GiB.
    print('Globally available:', round(torch_mem_info[0]/1024**3,1), 'GB')
    print('Total:   ', round(torch_mem_info[1]/1024**3,1), 'GB')

Using device: cuda

NVIDIA GeForce RTX 3080
Memory Usage:
Globally available: 8.8 GB
Total:    9.8 GB


In [3]:
def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing.

    The tensors are multiplied elementwise and the products are reduced over
    the last dimension, yielding one dot product per leading index.
    '''
    products = a * b
    return products.sum(dim=-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``.

    Every leading dimension of ``a`` is treated as a batch dimension and the
    dot product is taken over the last dimension. The result therefore has
    shape ``a.shape[:-1]``, the same shape ``batched_dot_mul_sum`` returns.

    Args:
        a: tensor whose last dimension is the vector dimension.
        b: tensor expected to match ``a``'s shape (``bmm`` needs the same
            flattened batch size and vector length for both operands).

    Returns:
        Tensor of shape ``a.shape[:-1]`` holding the per-vector dot products.
    '''
    batch_shape = a.shape[:-1]
    # Use the explicit batch size rather than -1 so an empty vector
    # dimension (last dim == 0) still reshapes unambiguously.
    batch_size = batch_shape.numel()
    # (..., K) -> (B, 1, K) row vectors and (B, K, 1) column vectors, so bmm
    # produces (B, 1, 1): one dot product per batch element.
    a = a.reshape(batch_size, 1, a.shape[-1])
    b = b.reshape(batch_size, b.shape[-1], 1)
    # Restore the caller's batch dimensions instead of flattening them all.
    return torch.bmm(a, b).reshape(batch_shape)

In [4]:
# Input for benchmarking
x = torch.randn(100000, 1024, device='cpu')
# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print('Benchmarking WSL single threaded CPU performance:')
print(t0.timeit(100))
print(t1.timeit(100))

Benchmarking WSL single threaded CPU performance:
<torch.utils.benchmark.utils.common.Measurement object at 0x7f4524e70340>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  112.08 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f45205a4be0>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  193.17 ms
  1 measurement, 100 runs , 1 thread


In [5]:
# Benchmark with as many intra-op threads as PyTorch is currently configured to use.
num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

# Fresh CPU input for the multi-threaded runs.
x = torch.randn(100000, 1024, device='cpu')
# Both implementations must agree before it is worth timing them.
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

# Both timers share the same label so the results group together.
t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals=dict(x=x),
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum',
    num_threads=num_threads,
)
t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals=dict(x=x),
    label='Multithreaded batch dot',
    sub_label='Implemented using bmm',
    num_threads=num_threads,
)

print('Benchmarking WSL multi-threaded CPU performance:')
print(t0.timeit(number=100))
print(t1.timeit(number=100))

Benchmarking on 12 threads
Benchmarking WSL multi-threaded CPU performance:
<torch.utils.benchmark.utils.common.Measurement object at 0x7f4524ef5d30>
Multithreaded batch dot: Implemented using mul and sum
setup: from __main__ import batched_dot_mul_sum
  52.23 ms
  1 measurement, 100 runs , 12 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f445b524c70>
Multithreaded batch dot: Implemented using bmm
setup: from __main__ import batched_dot_bmm
  29.07 ms
  1 measurement, 100 runs , 12 threads


In [6]:
# Input for benchmarking, allocated directly on the selected device.
x = torch.randn(100000, 1024, device=device)
# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# One timeit() call per implementation is enough: torch.utils.benchmark does
# its own warm-up (PyTorch initialises some state lazily) and synchronizes
# asynchronous CUDA work, so no manual warm-up run or sync is needed here.
if device.type != 'cuda':
    # The banner below assumes CUDA; make a CPU fallback visible instead of
    # silently reporting CPU timings under a GPU heading.
    print(f'Note: CUDA is not available, timings below were measured on {device}.')
print('Benchmarking WSL GPU (CUDA) performance:')
print(t0.timeit(100))
print(t1.timeit(100))

Benchmarking WSL GPU (CUDA) performance:
<torch.utils.benchmark.utils.common.Measurement object at 0x7f445b53adf0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  1.96 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f445b53a760>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  643.97 us
  1 measurement, 100 runs , 1 thread
