In [1]:
%matplotlib inline

# Setup

Before we begin, we need to install torch if it isn’t already available.
https://pytorch.org/get-started/locally/

`conda install pytorch -c pytorch`

&nbsp;&nbsp;&nbsp;&nbsp;or
 
`pip install torch`

In [2]:
import torch

### Misc

We'll start by defining several helper functions which we'll use later.

In [3]:
import collections
import os
import textwrap

from IPython.display import Markdown, display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


# Set the number of threads PyTorch uses to 4 for the rest of the notebook.
# We want to show certain threading effects, but 1 vs. several dozen
# is often too stark a contrast.
torch.set_num_threads(4)


def print_as_cpp(source: str):
    """Render `source` in the notebook as a fenced `c++` Markdown code block."""
    fenced_block = f"```c++\n{source}\n```"
    display(Markdown(fenced_block))


def load_extension(name: str, code: str, fn_name: str):
   """Build `code` as an inline C++ extension and return the loaded module.

   Ordinarily a change like this would be made inside ATen itself; building
   an inline module lets the example run without compiling PyTorch from
   source.

   Args:
      name: Name given to the generated extension module.
      code: C++ source to compile.
      fn_name: Name of the function in `code` to expose from the module.

   Returns:
      Whatever `torch.utils.cpp_extension.load_inline` returns for the
      compiled source.
   """
   from torch.utils import cpp_extension

   compiled_module = cpp_extension.load_inline(
      name=name,
      cpp_sources=code,
      functions=[fn_name],
      extra_cflags=["-O2", "-g"],
   )
   return compiled_module

def module_to_setup_str(m):
   """Handle importing `m` during Timer setup.

   This step is only necessary because we are using custom extensions for
   demonstration, rather than modifying and rebuilding PyTorch core.

   Args:
      m: An imported module whose `__file__` is the path of the file it was
         loaded from (for example the module returned by `load_extension`).

   Returns:
      Python source which appends the directory containing `m` to `sys.path`
      (if it is not already there) and imports the module under the alias
      `my_module`.
   """
   module_dir, file_name = os.path.split(m.__file__)

   # Keep everything before the first "." instead of dropping a fixed three
   # characters. ".so" and ".py" behave exactly as before, and suffixes of
   # other lengths (".pyd", ".cpython-38-x86_64-linux-gnu.so", ...) no longer
   # leave part of the suffix in the generated import statement.
   module_name = file_name.split(".", 1)[0]

   return textwrap.dedent(f"""
      import sys
      if not {repr(module_dir)} in sys.path:
         sys.path.append({repr(module_dir)})
      import {module_name} as my_module
      """)

# Case study: a specialized implementation of `x + 1`

In this tutorial, we are going to define the `shift` function, and show how to take a systematic approach towards optimizing it. For simplicity, we will only consider float Tensors on CPU.

In [4]:
# C++ source for the first `shift` implementation, kept as a Python string so
# that it can be both displayed and compiled from this cell.
# NOTE(review): the loop counter below is declared `int` while `n` is `auto`
# from `y.numel()`; if `numel()` is wider than `int` the counter could overflow
# for very large tensors -- confirm before reusing this kernel elsewhere.
shift_impl_v0_src = """
// First attempt at a specialized implementation of `x + 1`
at::Tensor shift(const at::Tensor & x) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "shift requires a float input");

    auto y = x.clone(at::MemoryFormat::Contiguous);
    auto y_ptr = y.data_ptr<float>();
    auto n = y.numel();
    for (int i = 0; i < n; i++) {
        *(y_ptr + i) += 1;
    }
    return y;
}
"""

# Show the source in the notebook, then compile it into an inline module that
# exposes `shift`; later cells refer to the module as `shift_impl_v0`.
print_as_cpp(shift_impl_v0_src)
shift_impl_v0 = load_extension("shift_impl_v0", shift_impl_v0_src, "shift")

```c++

// First attempt at a specialized implementation of `x + 1`
at::Tensor shift(const at::Tensor & x) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "shift requires a float input");

    auto y = x.clone(at::MemoryFormat::Contiguous);
    auto y_ptr = y.data_ptr<float>();
    auto n = y.numel();
    for (int i = 0; i < n; i++) {
        *(y_ptr + i) += 1;
    }
    return y;
}

```

## Naive benchmark: timeit.Timer

### Note: this is just here as a placeholder to help me organize my thoughts.

In [5]:
import timeit

repeats = 5
sizes = (1, 1024, 16384)


def _seconds_per_call(stmt, n, **timer_kwargs):
    """Time `stmt` on a ones-tensor of length `n` using `timeit` autorange.

    Extra keyword arguments are forwarded to `timeit.Timer`. The return value
    is the total time reported by `autorange()` divided by its loop count.
    """
    loop_count, elapsed = timeit.Timer(
        stmt,
        setup=f"import torch;x = torch.ones(({n},))",
        **timer_kwargs,
    ).autorange()
    return elapsed / loop_count


def measure_native(n):
    """Average time of one `x + 1` call for a tensor of length `n`."""
    return _seconds_per_call("x + 1", n)


def measure_cpp(n):
    """Average time of one `shift(x)` call (v0 extension) for length `n`."""
    return _seconds_per_call("shift(x)", n, globals={"shift": shift_impl_v0.shift})


# One table per implementation: a header row with the sizes, a rule, and then
# `repeats` rows of timings scaled by 1e6 and labelled "us".
benchmarks = (("Native", measure_native), ("\n\nC++ Extension", measure_cpp))
for title, measure_fn in benchmarks:
    header = "".join(f"n = {size}".rjust(13) for size in sizes)
    rule = "-" * 13 * len(sizes)
    print(f"{title}\n" + header + "\n" + rule)
    for _ in range(repeats):
        cells = []
        for n in sizes:
            cells.append(f"{measure_fn(n) * 1e6:10.1f} us")
        print("".join(cells))

Native
        n = 1     n = 1024    n = 16384
---------------------------------------
       7.7 us       8.9 us      14.7 us
       8.0 us       9.1 us      14.6 us
       7.9 us       8.6 us      14.4 us
       8.0 us       8.6 us      14.8 us
       8.0 us       9.2 us      14.4 us


C++ Extension
        n = 1     n = 1024    n = 16384
---------------------------------------
       4.1 us       5.0 us      21.2 us
       6.5 us       7.8 us      22.0 us
       4.3 us       5.2 us      17.8 us
       4.0 us       5.5 us      17.4 us
       4.2 us       5.0 us      17.4 us


# Runtime aware: torch.utils.benchmark.Timer

In [6]:
from torch.utils.benchmark import Timer

timer = Timer(
    stmt="x + 1",
    setup="x = torch.ones((1,))",
)

# The torch utils Timer returns a Measurement object, which contains
# metadata about the run as well as replicates, if applicable.
print(timer.timeit(100), "\n")


m = Timer(
    stmt="x + 1",
    # Like timeit.Timer, initialization can be done using `setup=...` or `globals=...` (or both).
    globals={"x": torch.ones((1,))},

    # torch.utils.benchmark.Timer takes several additional annotation arguments:
    #   label, sub_label, description, and env
    # These change the __repr__ of measurements, and are used when grouping and displaying
    # measurements. (Discussed later.)
    label="Add one",
    sub_label="Generic implementation.",
)

# Time the annotated Timer `m` defined just above. The original cell timed the
# first `timer` a second time here, so the label and sub_label never appeared
# in the printed Measurement.
labeled_measurement = m.timeit(100)
print(labeled_measurement)

<torch.utils.benchmark.utils.common.Measurement object at 0x7fb68cee64e0>
x + 1
  9.27 us
  1 measurement, 100 runs , 1 thread 

<torch.utils.benchmark.utils.common.Measurement object at 0x7fb68d100c50>
x + 1
  8.59 us
  1 measurement, 100 runs , 1 thread


## Timer.blocked_autorange
### A mixture of timeit.Timer.repeat and timeit.Timer.autorange

While `timeit.Timer.autorange` takes a single continuous measurement of at least 0.2 seconds, `torch.utils.benchmark.Timer.blocked_autorange` takes many measurements whose times total at least 0.2 seconds (which can be changed by the `min_run_time` parameter), subject to the constraint that timing overhead is a small fraction of the overall measurement. This is accomplished by first running with an increasing number of runs per loop until the run time is much larger than the measurement overhead (which also serves as a warm-up), and then taking measurements until the target time is reached. This has the useful properties that it wastes less data and allows us to compute statistics in order to assess the reliability of the measurements.

In [7]:
# Collect blocked measurements of `x + 1` on a one-element tensor; `m` is the
# object returned by `blocked_autorange()`.
m = Timer(
    stmt="x + 1",
    setup="x = torch.ones((1,))",
).blocked_autorange()

# Results summarized by __repr__
print(m, "\n")

# Helper methods for statistics
# (each value is scaled by 1e6 and labelled "us" for display)
print(f"Mean:   {m.mean * 1e6:6.1f} us")
print(f"Median: {m.median * 1e6:6.1f} us")
print(f"IQR:    {m.iqr * 1e6:6.1f} us")
# Show only the first two and last two entries of `m.times`: the closing
# bracket of the first repr and the opening bracket of the second are sliced
# off so that the two pieces read as one list with "..." in the middle.
print(f"Times:  {str(m.times[:2])[:-1]}, ..., {str(m.times[-2:])[1:]}")


<torch.utils.benchmark.utils.common.Measurement object at 0x7fb68d116a90>
x + 1
  Median: 8.17 us
  IQR:    0.40 us (7.96 to 8.36)
  25 measurements, 1000 runs per measurement, 1 thread 

Mean:      8.2 us
Median:    8.2 us
IQR:       0.4 us
Times:  [8.042722940444947e-06, 7.6196156442165375e-06, ..., 8.217006921768188e-06, 8.396882563829423e-06]


## Why runtime awareness matters
It's very easy to accidentally make an apples-to-oranges comparison, such as comparing measurements taken with different numbers of threads, or forgetting to synchronize CUDA.

In [8]:
# Input shared by all three measurements below.
x = torch.ones((1024, 1024))

# Time the same statement three ways: with the standard library `timeit.Timer`,
# with the torch Timer using its default arguments, and with the torch Timer
# given `num_threads=torch.get_num_threads()` explicitly.
num_runs, total_time = timeit.Timer("x + 1", globals={"x": x}).autorange()
m0 = Timer("x + 1", globals={"x": x}).blocked_autorange()
m1 = Timer("x + 1", globals={"x": x}, num_threads=torch.get_num_threads()).blocked_autorange()

# Report the mean time per call for each approach, scaled by 1e6 and labelled
# "us". For `timeit`, the mean is the total time divided by the loop count.
print(f"timeit.Timer:                   {total_time / num_runs * 1e6:6.0f} us")
print(f"torch Timer:                    {m0.mean * 1e6:6.0f} us")
print(f"torch Timer(num_threads=...):   {m1.mean * 1e6:6.0f} us")


timeit.Timer:                      118 us
torch Timer:                       379 us
torch Timer(num_threads=...):      115 us


# torch.utils.benchmark.Compare
Easy comparison of measurements.

In [9]:
from torch.utils.benchmark import Compare

results = []
for n in [1, 16, 256, 1024, 4096, 16384, 32768]:
    # Every Timer for this size builds its input with the same setup code, so
    # define it once per size rather than inside the thread-count loop.
    setup = f"x = torch.ones(({n},))"

    # Built-in `x + 1`, measured once per thread count.
    for num_threads in [1, 2, 4]:
        generic_timer = Timer(
            "x + 1",
            setup=setup,
            num_threads=num_threads,
            label="Shift operator",
            sub_label="Generic implementation.",
            description=str(n),
        )
        results.append(generic_timer.blocked_autorange())

    # The v0 extension: its setup imports the compiled module as `my_module`
    # before creating the input tensor.
    custom_timer = Timer(
        "my_module.shift(x)",
        setup=module_to_setup_str(shift_impl_v0) + setup,
        # Custom C++ operator does not support parallelism.
        num_threads=1,
        label="Shift operator",
        sub_label="Custom C++ operator",
        description=str(n),
    )
    results.append(custom_timer.blocked_autorange())

compare = Compare(results)
compare.print()

[------------------------------------- Shift operator ------------------------------------]
                               |   1   |   16  |  256  |  1024  |  4096  |  16384  |  32768
1 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  7.8  |  7.7  |  8.3  |  8.4   |  9.5   |   14.5  |   21.3
      Custom C++ operator      |  3.9  |  4.0  |  4.6  |  4.9   |  8.1   |   18.5  |   32.7
2 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  7.8  |  7.8  |  8.4  |  8.6   |  9.5   |   17.1  |   22.5
4 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  7.4  |  7.7  |  8.3  |  8.5   |  9.6   |   14.5  |   22.4

Times are in microseconds (us).



### With extra formatting

In [10]:
# Turn on the optional display settings of the existing `compare` object
# (built in an earlier cell), then print the table again.
compare.trim_significant_figures()
compare.colorize()
compare.print()

[------------------------------------- Shift operator ------------------------------------]
                               |   1   |   16  |  256  |  1024  |  4096  |  16384  |  32768
1 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  7.8  |  7.7  |  8.3  |  8.4   |   10   |  [92m[1m  14 [0m[0m  |  [92m[1m  21 [0m[0m
      Custom C++ operator      |  [92m[1m3.9[0m[0m  |  [92m[1m4.0[0m[0m  |  [92m[1m4.6[0m[0m  |  [92m[1m4.9 [0m[0m  |  [92m[1m 8  [0m[0m  |    19   |    33 
2 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  [92m[1m7.8[0m[0m  |  [92m[1m7.8[0m[0m  |  [92m[1m8.4[0m[0m  |  [92m[1m8.6 [0m[0m  |  [92m[1m 10 [0m[0m  |  [92m[1m  20 [0m[0m  |  [92m[1m  23 [0m[0m
4 threads: --------------------------------------------------------------------------------
      Generic implementati

# torch.utils.benchmark.Fuzzer
## More diverse inputs

We'll take a brief detour and use fuzzed inputs to discuss transpose and contiguous before returning to `shift`.

In [11]:
from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

example_fuzzer = Fuzzer(
    parameters=[
        # `k0`: drawn between 1 and 1024 ** 2 with a "loguniform" distribution.
        FuzzedParameter("k0", minval=1, maxval=1024 ** 2, distribution="loguniform"),
        # `k1`: either 1 or an alias of `k0`, each with weight 0.5.
        FuzzedParameter("k1", distribution={1: 0.5, ParameterAlias("k0"): 0.5}, strict=True),
    ],
    tensors=[
        FuzzedTensor(
            "x",
            size=("k0", "k1"),
            min_elements=128,
            max_elements=128 * 1024,
            probability_contiguous=0.6,
        ),
    ],
    seed=0,
)

results = []
for tensors, tensor_params, params in example_fuzzer.take(10):
    # Row label: the two sampled sizes plus a marker when `x` is not contiguous.
    layout_note = "" if tensor_params["x"]["is_contiguous"] else "(discontiguous)"
    sub_label = f"{params['k0']:<6} x {params['k1']:<4} {layout_note}"

    # One measurement (and therefore one table column) per statement.
    for stmt in ("x.contiguous()", "x.t().contiguous()"):
        timer = Timer(
            stmt,
            globals=tensors,
            label="2D transpose",
            description=stmt,
            sub_label=sub_label,
        )
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

[------------------------------- 2D transpose ------------------------------]
                                     |  x.contiguous()  |  x.t().contiguous()
1 threads: ------------------------------------------------------------------
      355    x 355  (discontiguous)  |      210000      |          2000      
      751    x 1                     |         200      |          2000      
      313    x 313                   |         200      |        160000      
      45851  x 1                     |         200      |          2000      
      146    x 146                   |         200      |         43000      
      15854  x 1                     |         200      |          2000      
      143    x 143  (discontiguous)  |       38000      |          2000      
      2709   x 1                     |         200      |          2000      
      312    x 312                   |         200      |        160000      
      5674   x 1                     |         200      |       

The fuzzed benchmarks reveal several noteworthy features:
* If a Tensor is already contiguous we do not need to construct a new Tensor and the operation is extremely cheap. O(100 ns)

* For N x 1 Tensors transpose requires that we create a new Tensor, but we can reuse the same buffer.

* For N x N tensors, either contiguous or transposed contiguous will be expensive depending on the underlying data layout.

## Canned fuzzers: back to our x+1 kernel
When benchmarking an op, there are a lot of things to consider: dimensionality, contiguity (both layout and strides from slicing), broadcasting, sizes, etc. While it's certainly possible to write your own fuzzer, it's nice if one doesn't have to. To that end, canned fuzzers are provided for unary and binary ops, and more will be added soon.

In [12]:
from torch.utils.benchmark.op_fuzzers import unary

results, descriptions = [], []
for i, (tensors, tensor_params, params) in enumerate(unary.UnaryOpFuzzer(seed=0).take(10)):
    x = tensors["x"]

    # Legend entry for this input: its shape, plus a marker when discontiguous.
    shape_text = str(list(x.shape))
    layout_note = "" if tensor_params["x"]["is_contiguous"] else " (discontiguous)"
    descriptions.append(f"{shape_text:<20}{layout_note}")

    # Both implementations are timed on the same fuzzed tensors and share the
    # column header `[i]`; only the extension needs setup code to import it.
    column = f"[{i}]"
    implementations = (
        ("x + 1", "Generic implementation.", {}),
        (
            "my_module.shift(x)",
            "Custom C++ operator",
            {"setup": module_to_setup_str(shift_impl_v0)},
        ),
    )
    for stmt, sub_label, extra_kwargs in implementations:
        timer = Timer(
            stmt,
            globals=tensors,
            label="Shift operator",
            sub_label=sub_label,
            description=column,
            **extra_kwargs,
        )
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.colorize()
compare.print()

# Legend mapping each column header back to the input it was measured on.
for i, d in enumerate(descriptions):
    print(f"[{i}] {d}")

[------------------------------------------------- Shift operator -------------------------------------------------]
                               |  [0]   |  [1]  |  [2]  |   [3]   |  [4]  |  [5]  |  [6]   |  [7]   |   [8]   |  [9]
1 threads: ---------------------------------------------------------------------------------------------------------
      Generic implementation.  |  [92m[1m4000[0m[0m  |  [92m[1m 15[0m[0m  |  [92m[1m160[0m[0m  |  [92m[1m40000[0m[0m  |  [92m[1m380[0m[0m  |  [92m[1m 94[0m[0m  |  [92m[1m4500[0m[0m  |  [92m[1m5600[0m[0m  |  80000  |  [92m[1m240[0m[0m
      Custom C++ operator      |  7800  |   17  |  [2m[91m450[0m[0m  |  49000  |  [2m[91m900[0m[0m  |  [2m[91m190[0m[0m  |  7000  |  7900  |  [92m[1m70000[0m[0m  |  [2m[91m570[0m[0m

Times are in microseconds (us).

[0] [16, 34, 8324]      
[1] [16384]             
[2] [157, 2695]          (discontiguous)
[3] [256, 2538, 20]     
[4] [960851]            
[5

### TODO: elaborate on why [2] and [7] are slower than expected.

In [13]:
# C++ source for the second `shift` implementation, kept as a Python string so
# that it can be both displayed and compiled from this cell. It exposes the
# same function name (`shift`) as the first version.
shift_impl_v1_src = """
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>

// Second attempt at a specialized implementation of x + 1
at::Tensor shift(const at::Tensor& x) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "shift requires a float input");

    auto result = at::empty_like(x);
    auto iter = at::TensorIterator::unary_op(result, x);
    at::native::cpu_kernel(iter, [](float xi) -> float { return xi + 1; });

    return result;
}
"""

# Show the source in the notebook, then compile it into a separate inline
# module; later cells refer to the module as `shift_impl_v1`.
print_as_cpp(shift_impl_v1_src)
shift_impl_v1 = load_extension("shift_impl_v1", shift_impl_v1_src, "shift")

```c++

#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>

// Second attempt at a specialized implementation of x + 1
at::Tensor shift(const at::Tensor& x) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "shift requires a float input");

    auto result = at::empty_like(x);
    auto iter = at::TensorIterator::unary_op(result, x);
    at::native::cpu_kernel(iter, [](float xi) -> float { return xi + 1; });

    return result;
}

```

In [14]:
# Op fuzzers are deterministic.
# Re-create the fuzzer with the same seed as the earlier cell and check, input
# by input, that it yields the same descriptions before timing the v1 kernel.
# NOTE(review): `results` is not re-created in this cell, so the v1
# measurements are appended to the list left behind by an earlier cell;
# re-running this cell on its own would append them a second time.
for i, (tensors, tensor_params, params) in enumerate(unary.UnaryOpFuzzer(seed=0).take(10)):
    x = tensors["x"]
    d = f"{str(list(x.shape)):<20}{'' if tensor_params['x']['is_contiguous'] else ' (discontiguous)'}"
    # Guard against the inputs drifting from those behind the existing results.
    assert d == descriptions[i]

    timer = Timer(
        "my_module.shift(x)",
        globals=tensors,
        setup=module_to_setup_str(shift_impl_v1),
        label="Shift operator",
        sub_label="Custom C++ operator (v1)",
        description=f"[{i}]",
    )
    results.append(timer.blocked_autorange())
    
# Rebuild the comparison table from the combined list of measurements.
compare = Compare(results)
compare.trim_significant_figures()
compare.colorize()
compare.print()

# Legend mapping each column header back to the input it was measured on.
for i, d in enumerate(descriptions):
    print(f"[{i}] {d}")

[-------------------------------------------------- Shift operator -------------------------------------------------]
                                |  [0]   |  [1]  |  [2]  |   [3]   |  [4]  |  [5]  |  [6]   |  [7]   |   [8]   |  [9]
1 threads: ----------------------------------------------------------------------------------------------------------
      Generic implementation.   |  [34m[1m4000[0m[0m  |  [34m[1m 15[0m[0m  |  [92m[1m160[0m[0m  |  [92m[1m40000[0m[0m  |  [92m[1m380[0m[0m  |  [92m[1m 94[0m[0m  |  4500  |  5600  |  80000  |  [92m[1m240[0m[0m
      Custom C++ operator       |  7800  |   17  |  [2m[91m450[0m[0m  |  49000  |  [2m[91m900[0m[0m  |  [2m[91m190[0m[0m  |  [2m[91m7000[0m[0m  |  7900  |  [92m[1m70000[0m[0m  |  [2m[91m570[0m[0m
      Custom C++ operator (v1)  |  [92m[1m4200[0m[0m  |  [92m[1m 14[0m[0m  |  240  |  44000  |  530  |  120  |  [92m[1m3300[0m[0m  |  [92m[1m5100[0m[0m  |  [34m[1m70000[0