In [1]:
import torch

def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

In [2]:
print(x.reshape(-1, 1, x.shape[-1]).shape)
print(x.reshape(-1, x.shape[-1], 1).shape)

torch.Size([10000, 1, 64])
torch.Size([10000, 64, 1])


### Using `timeit`

In [3]:
import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

mul_sum(x, x):   38.7 us
bmm(x, x):       76.8 us


### Using PyTorch Benchmark

- benchmark.Timer.timeit() returns the time per run as opposed to the total runtime like timeit.Timer.timeit() does.
- PyTorch benchmark module also provides formatted string representations for printing the results.
- Another important difference, and the reason why the results diverge is that PyTorch benchmark module runs in a `single thread by default`. We can change the number of threads with the num_threads argument.

In [4]:
import torch.utils.benchmark as benchmark

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(t0.timeit(100))
print(t1.timeit(100))

<torch.utils.benchmark.utils.common.Measurement object at 0x7f23183f33a0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  106.82 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f23183f2e30>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  364.75 us
  1 measurement, 100 runs , 1 thread


In [5]:
num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum')

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using bmm')

print(t0.timeit(100))
print(t1.timeit(100))

Benchmarking on 8 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f23183af7c0>
Multithreaded batch dot: Implemented using mul and sum
setup: from __main__ import batched_dot_mul_sum
  23.19 us
  1 measurement, 100 runs , 8 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7f1fba1262f0>
Multithreaded batch dot: Implemented using bmm
setup: from __main__ import batched_dot_bmm
  60.78 us
  1 measurement, 100 runs , 8 threads


### Benchmarking on CUDA

In [6]:
x = torch.randn(10000, 1024, device='cuda')

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Ran each twice to show difference before/after warm-up
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

mul_sum(x, x):  176.0 us
mul_sum(x, x):    7.8 us
bmm(x, x):      2546.3 us
bmm(x, x):       11.1 us


In [7]:
t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run only once since benchmark module does warm-up for us
print(t0.timeit(100))
print(t1.timeit(100))

<torch.utils.benchmark.utils.common.Measurement object at 0x7f1febd66fb0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  111.67 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f23183f3610>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  845.03 us
  1 measurement, 100 runs , 1 thread


**Analysis**

The first run of the bmm version using the timeit module takes much longer than the second run.
This is because bmm calls into `cuBLAS` which needs to be loaded the first time it’s called which takes some time. 
This is why it’s important to do a `warm-up` run before benchmarking, luckily for us, PyTorch’s benchmark module takes care of that.

The difference in the results between timeit and benchmark modules is because 
the timeit module is not `synchronizing CUDA` and is thus only timing the time to launch the kernel. 
PyTorch’s benchmark module does the synchronization for us.

### Using `Blocked Autotrange`

In [9]:
m0 = t0.blocked_autorange()
m1 = t1.blocked_autorange()

print(m0)
print(m1)

<torch.utils.benchmark.utils.common.Measurement object at 0x7f23183f36a0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  Median: 66.76 us
  IQR:    0.37 us (66.56 to 66.94)
  4 measurements, 1000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f23183f2ce0>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  Median: 774.04 us
  3 measurements, 100 runs per measurement, 1 thread


In [10]:
print(f"Mean:   {m0.mean * 1e6:6.2f} us")
print(f"Median: {m0.median * 1e6:6.2f} us")

Mean:    66.74 us
Median:  66.76 us


### Comparing Benchmark Results

Over different inputs & parameters

In [12]:
from itertools import product

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024, 10000]
for b, n in product(sizes, sizes):
    # label and sub_label are the rows
    # description is the column
    label = 'Batched dot'
    sub_label = f'[{b}, {n}]'
    x = torch.ones((b, n))
    for num_threads in [1, 2, 4, 8]:
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='mul/sum',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='bmm',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

[--------------- Batched dot ---------------]
                      |  mul/sum   |    bmm  
1 threads: ----------------------------------
      [1, 1]          |       2.4  |      3.9
      [1, 64]         |       2.4  |      3.9
      [1, 1024]       |       2.5  |      4.1
      [1, 10000]      |       3.3  |      4.8
      [64, 1]         |       2.5  |      4.1
      [64, 64]        |       3.1  |      6.0
      [64, 1024]      |       8.3  |     59.0
      [64, 10000]     |      65.3  |    542.4
      [1024, 1]       |       2.8  |      7.1
      [1024, 64]      |      10.7  |     37.2
      [1024, 1024]    |     108.7  |    888.1
      [1024, 10000]   |   12489.0  |   8594.7
      [10000, 1]      |       6.0  |     34.2
      [10000, 64]     |      88.1  |    332.2
      [10000, 1024]   |   12395.3  |   8600.6
      [10000, 10000]  |  140360.1  |  84263.4
2 threads: ----------------------------------
      [1, 1]          |       2.3  |      3.9
      [1, 64]         |       2.4 

In [13]:
compare.trim_significant_figures()
compare.colorize()
compare.print()

[------------- Batched dot --------------]
                      |  mul/sum  |   bmm 
1 threads: -------------------------------
      [1, 1]          |  [92m[1m      2[0m[0m  |  [92m[1m    4[0m[0m
      [1, 64]         |  [92m[1m      2[0m[0m  |  [92m[1m    4[0m[0m
      [1, 1024]       |  [92m[1m      2[0m[0m  |  [34m[1m    4[0m[0m
      [1, 10000]      |        3  |      5
      [64, 1]         |  [34m[1m      2[0m[0m  |  [34m[1m    4[0m[0m
      [64, 64]        |        3  |      6
      [64, 1024]      |  [2m[91m      8[0m[0m  |  [31m[1m   59[0m[0m
      [64, 10000]     |  [31m[1m     70[0m[0m  |  [31m[1m  542[0m[0m
      [1024, 1]       |        3  |      7
      [1024, 64]      |  [2m[91m     11[0m[0m  |  [31m[1m   37[0m[0m
      [1024, 1024]    |  [31m[1m    100[0m[0m  |  [31m[1m  888[0m[0m
      [1024, 10000]   |  [31m[1m  12500[0m[0m  |  [31m[1m 8590[0m[0m
      [10000, 1]      |  [2m[91m      6[0m[0m

In [14]:
import pickle

ab_test_results = []
for env in ('environment A: mul/sum', 'environment B: bmm'):
    for b, n in ((1, 1), (1024, 10000), (10000, 1)):
        x = torch.ones((b, n))
        dot_fn = (batched_dot_mul_sum if env == 'environment A: mul/sum' else batched_dot_bmm)
        m = benchmark.Timer(
            stmt='batched_dot(x, x)',
            globals={'x': x, 'batched_dot': dot_fn},
            num_threads=1,
            label='Batched dot',
            description=f'[{b}, {n}]',
            env=env,
        ).blocked_autorange(min_run_time=1)
        ab_test_results.append(pickle.dumps(m))

ab_results = [pickle.loads(i) for i in ab_test_results]
compare = benchmark.Compare(ab_results)
compare.trim_significant_figures()
compare.colorize()
compare.print()

[------------------------------------- Batched dot -------------------------------------]
                                               |  [1, 1]  |  [1024, 10000]  |  [10000, 1]
1 threads: ------------------------------------------------------------------------------
  (environment A: mul/sum)  batched_dot(x, x)  |  [92m[1m  2   [0m[0m  |      13000      |  [92m[1m    6.0   [0m[0m
  (environment B: bmm)      batched_dot(x, x)  |    4     |  [92m[1m     8410    [0m[0m  |  [31m[1m   34.1   [0m[0m

Times are in microseconds (us).



In [16]:
# And just to show that we can round trip all of the results from earlier:
round_tripped_results = pickle.loads(pickle.dumps(results))
assert(str(benchmark.Compare(results)) == str(benchmark.Compare(round_tripped_results)))

### Fuzzed Parameters ( Generating inputs - automating input generation)

In [19]:
from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

# Generates random tensors with 128 to 10000000 elements and sizes k0 and k1 chosen from a
# ``loguniform`` distribution in [1, 10000], 40% of which will be discontiguous on average.
example_fuzzer = Fuzzer(
    parameters = [
        FuzzedParameter('k0', minval=1, maxval=10000, distribution='loguniform'),
        FuzzedParameter('k1', minval=1, maxval=10000, distribution='loguniform'),
    ],
    tensors = [
        FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=10000000, probability_contiguous=0.6)
    ],
    seed=0,
)

results = []
for tensors, tensor_params, params in example_fuzzer.take(10):
    # description is the column label
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()

[--------------------- Batched dot ---------------------]
                                     |  mul/sum  |   bmm 
1 threads: ----------------------------------------------
      725    x 257                   |     30    |    116
      49     x 383                   |      4    |     15
      34     x 1468                  |      7    |     46
      187    x 5039                  |    100    |    793
      2140   x 1296 (discontiguous)  |    300    |  10900
      78     x 1598                  |     16    |    108
      519    x 763                   |     44    |    338
      141    x 1082                  |     20    |    130
      78     x 5    (discontiguous)  |      3    |      6
      187    x 1                     |      3    |      4

Times are in microseconds (us).



In [20]:
# Using Built-in Fuzzers

from torch.utils.benchmark.op_fuzzers import binary

results = []
for tensors, tensor_params, params in binary.BinaryOpFuzzer(seed=0).take(10):
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()

[----------------------- Batched dot ------------------------]
                                         |  mul/sum  |   bmm  
1 threads: ---------------------------------------------------
      64     x 473  (discontiguous)      |  [92m[1m  4000 [0m[0m  |  [31m[1m 24000[0m[0m
      16384  x 12642115 (discontiguous)  |  [92m[1m     9 [0m[0m  |  [2m[91m    34[0m[0m
      8192   x 892                       |  [92m[1m   760 [0m[0m  |  [31m[1m  6100[0m[0m
      512    x 64   (discontiguous)      |  [92m[1m 33000 [0m[0m  |  [2m[91m123000[0m[0m
      493    x 27   (discontiguous)      |  [92m[1m   357 [0m[0m  |  [2m[91m   873[0m[0m
      118    x 32   (discontiguous)      |  [92m[1m   211 [0m[0m  |  [31m[1m  1060[0m[0m
      16     x 495  (discontiguous)      |  [92m[1m  1900 [0m[0m  |  [2m[91m  8890[0m[0m
      488    x 62374                     |   40300   |  [92m[1m 25300[0m[0m
      240372 x 69                        |  [2m[91

### Collecting Instruction Counts with `Callgrind`

One of the challenges of optimizing code is the variation and opacity of wall time. Furthermore, end-to-end time gives no insight into where time is being spent, which is really what we’re interested in when optimizing code.

A complementary approach is to also collect instruction counts. These counts are a proxy metric and do not capture all aspects of performance (e.g. memory or I/O bound tasks), however they do have several useful properties. Instruction counts are reproducible, insensitive to environmental variation, and offer fine grained insight into where a program is spending cycles.

In [22]:
# Implementing using both `reference` & `value`  

batched_dot_src = """\
/* ---- Python ---- */
// def batched_dot_mul_sum(a, b):
//     return a.mul(b).sum(-1)

torch::Tensor batched_dot_mul_sum_v0(
    const torch::Tensor a,
    const torch::Tensor b) {
  return a.mul(b).sum(-1);
}

torch::Tensor batched_dot_mul_sum_v1(
    const torch::Tensor& a,
    const torch::Tensor& b) {
  return a.mul(b).sum(-1);
}
"""


# PyTorch makes it easy to test our C++ implementations by providing a utility
# to JIT compile C++ source into Python extensions:
import os
from torch.utils import cpp_extension
cpp_lib = cpp_extension.load_inline(
    name='cpp_lib',
    cpp_sources=batched_dot_src,
    extra_cflags=['-O3'],
    extra_include_paths=[
        # `load_inline` needs to know where to find ``pybind11`` headers.
        # os.path.join(os.getenv('CONDA_PREFIX'), 'include')
    ],
    functions=['batched_dot_mul_sum_v0', 'batched_dot_mul_sum_v1']
)

# `load_inline` will create a shared object that is loaded into Python. When we collect
# instruction counts Timer will create a subprocess, so we need to re-import it. The
# import process is slightly more complicated for C extensions, but that's all we're
# doing here.
module_import_str = f"""\
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
import importlib.util
spec = importlib.util.spec_from_file_location("cpp_lib", {repr(cpp_lib.__file__)})
cpp_lib = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cpp_lib)"""

import textwrap
def pretty_print(result):
    """Import machinery for ``cpp_lib.so`` can get repetitive to look at."""
    print(repr(result).replace(textwrap.indent(module_import_str, "  "), "  import cpp_lib"))


t_baseline = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='''\
from __main__ import batched_dot_mul_sum
x = torch.randn(2, 2)''')

t0 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v0(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

t1 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v1(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

# Moving to C++ did indeed reduce overhead, but it's hard to tell which
# calling convention is more efficient. v1 (call with references) seems to
# be a bit faster, but it's within measurement error.
pretty_print(t_baseline.blocked_autorange())
pretty_print(t0.blocked_autorange())
pretty_print(t1.blocked_autorange())

<torch.utils.benchmark.utils.common.Measurement object at 0x7f1fba127eb0>
batched_dot_mul_sum(x, x)
setup:
  from __main__ import batched_dot_mul_sum
  x = torch.randn(2, 2)

  2.22 us
  1 measurement, 100000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f1fba127280>
cpp_lib.batched_dot_mul_sum_v0(x, x)
setup:
  import cpp_lib
  x = torch.randn(2, 2)

  Median: 1.87 us
  2 measurements, 100000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f1fba127bb0>
cpp_lib.batched_dot_mul_sum_v1(x, x)
setup:
  import cpp_lib
  x = torch.randn(2, 2)

  Median: 1.81 us
  2 measurements, 100000 runs per measurement, 1 thread


In [31]:
# Let's use ``Callgrind`` to determine which is better.
stats_v0 = t0.collect_callgrind()
stats_v1 = t1.collect_callgrind()

pretty_print(stats_v0)
pretty_print(stats_v1)

# `.as_standardized` removes file names and some path prefixes, and makes
# it easier to read the function symbols.
stats_v0 = stats_v0.as_standardized()
stats_v1 = stats_v1.as_standardized()

# `.delta` diffs the instruction counts, and `.denoise` removes several
# functions in the Python interpreter that are known to have significant
# jitter.
delta = stats_v1.delta(stats_v0).denoise()

# `.transform` is a convenience API for transforming function names. It is
# useful for increasing cancelation when ``diff-ing`` instructions, as well as
# just generally improving readability.
replacements = (
    ("???:void pybind11", "pybind11"),
    ("batched_dot_mul_sum_v0", "batched_dot_mul_sum_v1"),
    ("at::Tensor, at::Tensor", "..."),
    ("at::Tensor const&, at::Tensor const&", "..."),
    ("auto torch::detail::wrap_pybind_function_impl_", "wrap_pybind_function_impl_"),
)
for before, after in replacements:
    delta = delta.transform(lambda l: l.replace(before, after))

# We can use print options to control how much of the function to display.
torch.set_printoptions(linewidth=160)

# Once parsed, the instruction counts make clear that passing `a` and `b`
# by reference is more efficient as it skips some ``c10::TensorImpl`` bookkeeping
# for the intermediate Tensors, and is also works better with ``pybind11``. This
# is consistent with our noisy wall time observations.
print(delta)

OSError: Failed to collect callgrind profile:
Unknown error.
valgrind: python: command not found


In [None]:
## Facing issue with `Valgrind`, tried it using script but still didn't work.
## TODO:: Figure out `callgrind` workings

## By default `load_inline` stores the build in a tmp location. You can set it to be a custom dir ( checkout the .py file )