In [6]:
%matplotlib inline

# Setup

Before we begin, we need to install torch if it isn’t already available.
https://pytorch.org/get-started/locally/

`conda install pytorch -c pytorch`

&nbsp;&nbsp;&nbsp;&nbsp;or
 
`pip install torch`

In [7]:
import torch

### Misc

We'll start by defining several helper functions which we'll use later.

In [8]:
import collections
import os
import textwrap

from IPython.display import Markdown, display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


# We want to show certain threading effects, but 1 vs. several dozen
# is often too stark a contrast.
torch.set_num_threads(4)


def print_as_cpp(source: str):
    display(Markdown(f"```c++\n{source}\n```"))


def load_extension(name: str, code: str, fn_name: str):
   """Compile our implementation into an inline module.

   Normally we would modify ATen instead, however this allows us
   to show an example without having to build PyTorch from source.
   """
   from torch.utils.cpp_extension import load_inline
   return load_inline(
      name,
      code,
      extra_cflags=["-O2", "-g", "-mavx2", "-mfma"],
      functions=[fn_name])

def module_to_setup_str(m):
   """Handle importing `m` during Timer setup.

   This step is only necessary because we are using custom extensions for
   demonstration, rather than modifying and rebuilding PyTorch core.
   """
   module_dir, module_name = os.path.split(m.__file__)
   return textwrap.dedent(f"""
      import sys
      if not {repr(module_dir)} in sys.path:
         sys.path.append({repr(module_dir)})
      import {module_name[:-3]} as my_module
      """)

# Case study: a specialized implementation of `x + 1`

In this tutorial, we are going to define the `shift` function, and show how to take a systematic approach towards optimizing it. For simplicity, we will only consider float Tensors on CPU.

In [9]:
shift_impl_v0_src = """
// First attempt at a specialized implementation of `x + 1`
at::Tensor shift(const at::Tensor & x) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "shift requires a float input");

    auto y = x.clone(at::MemoryFormat::Contiguous);
    auto y_ptr = x.data_ptr<float>();
    auto n = y.numel();
    for (int i = 0; i < n; i++) {
        *(y_ptr + i) += 1;
    }
    return y;
}
"""

print_as_cpp(shift_impl_v0_src)
shift_impl_v0 = load_extension("shift_impl_v0", shift_impl_v0_src, "shift")

```c++

// First attempt at a specialized implementation of `x + 1`
at::Tensor shift(const at::Tensor & x) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "shift requires a float input");

    auto y = x.clone(at::MemoryFormat::Contiguous);
    auto y_ptr = x.data_ptr<float>();
    auto n = y.numel();
    for (int i = 0; i < n; i++) {
        *(y_ptr + i) += 1;
    }
    return y;
}

```

## Naive benchmark: timeit.Timer

### Note: this is just here as a placeholder to help me organize my thoughts.

In [10]:
import timeit

repeats = 5
sizes = (1, 1024, 16384)


def measure_native(n):
    num_runs, total_time = timeit.Timer(
        "x + 1", 
        setup=f"import torch;x = torch.ones(({n},))",
    ).autorange()
    return total_time / num_runs


def measure_cpp(n):
    num_runs, total_time = timeit.Timer(
        "shift(x)", 
        setup=f"import torch;x = torch.ones(({n},))",
        globals={"shift": shift_impl_v0.shift},
    ).autorange()
    return total_time / num_runs


for title, measure_fn in (("Native", measure_native), ("\n\nC++ Extension", measure_cpp)):
    print(f"{title}\n" + "".join([f"n = {i}".rjust(13) for i in sizes]) + "\n" + "-" * 13 * len(sizes))
    for _ in range(repeats):
        result_line = ""
        for n in sizes:
            result_line += f"{measure_fn(n) * 1e6:10.1f} us"
        print(result_line)

Native
        n = 1     n = 1024    n = 16384
---------------------------------------
       8.3 us       8.9 us      14.5 us
       8.2 us       8.8 us      14.4 us
       8.4 us       8.7 us      13.8 us
       8.3 us       8.5 us      14.0 us
       8.1 us       8.6 us      14.1 us


C++ Extension
        n = 1     n = 1024    n = 16384
---------------------------------------
       4.1 us       5.1 us      17.8 us
       4.2 us       5.1 us      21.7 us
       4.1 us       5.0 us      17.9 us
       4.1 us       5.0 us      18.4 us
       4.1 us       5.1 us      20.7 us


# Runtime aware: torch.utils.benchmark.Timer

In [29]:
from torch.utils.benchmark import Timer

timer = Timer(
    stmt="x + 1",
    setup="x = torch.ones((1,))",
)

# The torch utils Timer returns a Measurement object, which contains
# metadata about the run as well as replicates, if applicable.
print(timer.timeit(100), "\n")


m = Timer(
    stmt="x + 1",
    # Like timeit.Timer, initialization can be done using `setup=...` or `globals=...` (or both).
    globals={"x": torch.ones((1,))},
    
    # torch.utils.benchmark.Timer takes several additional annotation argument:
    #   label, sub_label, description, and env
    # These change the __repr__ measurements, and are used when grouping and displaying
    # measurements. (Discussed later.)
    label="Add one",
    sub_label="Generic implementation.",
)

print(timer.timeit(100))

<torch.utils.benchmark.utils.common.Measurement object at 0x7f46e18717f0>
x + 1
  11.19 us
  1 measurement, 100 runs , 1 thread 

<torch.utils.benchmark.utils.common.Measurement object at 0x7f46e579e588>
x + 1
  8.30 us
  1 measurement, 100 runs , 1 thread


## Timer.blocked_autorange
### A mixture of timeit.Timer.repeat and timeit.Timer.autorange

While `timeit.Timer.autorange` takes a single continuous measurement of at least 0.2 seconds, `torch.utils.benchmark.blocked_autorange` takes many measurements whose times total at least 0.2 seconds (which can be changed by the `min_run_time` parameter) subject to the constraint that timing overhead is a small fraction of the overall measurement. This is acomplished by first running with an increasing number of runs per loop until the run time is much larger than measurement overhead (which also serves as a warm up), and then taking measurements until the target time is reached. This has the useful properties that it wastes less data, and allows us to take statistics in order to assess the reliability of measurements.

In [36]:
m = Timer(
    stmt="x + 1",
    setup="x = torch.ones((1,))",
).blocked_autorange()

# Results summarized by __repr__
print(m, "\n")

# Helper methods for statistics
print(f"Mean:   {m.mean * 1e6:6.1f} us")
print(f"Median: {m.median * 1e6:6.1f} us")
print(f"IQR:    {m.iqr * 1e6:6.1f} us")
print(f"Times:  {str(m.times[:2])[:-1]}, ..., {str(m.times[-2:])[1:]}")


<torch.utils.benchmark.utils.common.Measurement object at 0x7f46e182b9b0>
x + 1
  Median: 8.34 us
  IQR:    0.28 us (8.23 to 8.51)
  24 measurements, 1000 runs per measurement, 1 thread 

Mean:      8.6 us
Median:    8.3 us
IQR:       0.3 us
Times:  [8.493732661008834e-06, 8.615698665380478e-06, ..., 8.232343941926956e-06, 9.091291576623917e-06]


## Why runtime awareness matters
It's very easy to accidentally make an apples-to-oranges comparizon, such as comparing measurements with different numbers of threads, or forgetting to CUDA synchronize.

In [52]:
x = torch.ones((1024, 1024))

num_runs, total_time = timeit.Timer("x + 1", globals={"x": x}).autorange()
m0 = Timer("x + 1", globals={"x": x}).blocked_autorange()
m1 = Timer("x + 1", globals={"x": x}, num_threads=torch.get_num_threads()).blocked_autorange()

print(f"timeit.Timer:                   {total_time / num_runs * 1e6:6.0f} us")
print(f"torch Timer:                    {m0.mean * 1e6:6.0f} us")
print(f"torch Timer(num_threads=...):   {m1.mean * 1e6:6.0f} us")


timeit.Timer:                      111 us
torch Timer:                       370 us
torch Timer(num_threads=...):      106 us


# torch.utils.benchmark.Compare
Easy comparison of measurements.

In [56]:
from torch.utils.benchmark import Compare

results = []
for n in [1, 16, 256, 1024, 4096, 16384, 32768]:
    for num_threads in [1, 2, 4]:
        setup=f"x = torch.ones(({n},))"
        results.append(Timer(
            "x + 1",
            setup=setup,
            num_threads=num_threads,
            label="Shift operator",
            sub_label="Generic implementation.",
            description=str(n),
        ).blocked_autorange())


    results.append(Timer(
        "my_module.shift(x)",
        setup=(
            module_to_setup_str(shift_impl_v0) +
            setup
        ),
        # Custom C++ operator does not support parallelism.
        num_threads=1,
        label="Shift operator",
        sub_label="Custom C++ operator",
        description=str(n),
    ).blocked_autorange())

compare = Compare(results)
compare.print()

[------------------------------------- Shift operator ------------------------------------]
                               |   1   |   16  |  256  |  1024  |  4096  |  16384  |  32768
1 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  8.2  |  8.6  |  8.7  |  9.3   |  10.3  |   14.7  |   22.3
      Custom C++ operator      |  4.1  |  4.2  |  4.5  |  5.3   |   8.2  |   18.6  |   35.3
2 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  8.0  |  8.4  |  8.7  |  9.3   |  10.4  |   14.6  |   23.4
4 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  7.9  |  8.7  |  8.9  |  9.3   |  10.5  |   14.6  |   23.8

Times are in microseconds (us).



### With extra formatting

In [57]:
compare.trim_significant_figures()
compare.colorize()
compare.print()

[------------------------------------- Shift operator ------------------------------------]
                               |   1   |   16  |  256  |  1024  |  4096  |  16384  |  32768
1 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  8.2  |  [2m[91m8.6[0m[0m  |  8.7  |  9.3   |   10   |  [92m[1m  15 [0m[0m  |  [92m[1m  22 [0m[0m
      Custom C++ operator      |  [92m[1m4.1[0m[0m  |  [92m[1m4.2[0m[0m  |  [92m[1m4.5[0m[0m  |  [92m[1m5.3 [0m[0m  |  [92m[1m  8 [0m[0m  |    19   |    35 
2 threads: --------------------------------------------------------------------------------
      Generic implementation.  |  [92m[1m8.0[0m[0m  |  [92m[1m8.4[0m[0m  |  [92m[1m8.7[0m[0m  |  [92m[1m9.3 [0m[0m  |  [92m[1m 10 [0m[0m  |  [92m[1m  15 [0m[0m  |  [92m[1m  23 [0m[0m
4 threads: --------------------------------------------------------------------------------
      Gen

## Note: this was me just playing with TensorAccessor so I could test some stuff. I'm sure you'll have critiques.

In [53]:
shift_impl_v1_src = """
// First attempt at a specialized implementation of `x + 1`
at::Tensor shift(const at::Tensor & x) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "shift requires a float input");
    
    // Use TensorAccessor to handle stride calculation.
    // This lets us skip copying `x`.
    auto flat_x = x.flatten();
    auto x_accessor = flat_x.accessor<float, 1>();

    auto y = at::empty(x.sizes(), x.options());
    auto y_ptr = y.data_ptr<float>();
    
    auto n = y.numel();
    for (int64_t i = 0; i < n; i++) {
        *(y_ptr + i) = x_accessor[i] + 1;
    }
    return y;
}
"""

print_as_cpp(shift_impl_v1_src)
shift_impl_v1 = load_extension("shift_impl_v1", shift_impl_v1_src, "shift")

```c++

// First attempt at a specialized implementation of `x + 1`
at::Tensor shift(const at::Tensor & x) {
    TORCH_CHECK(x.scalar_type() == at::kFloat, "shift requires a float input");
    
    // Use TensorAccessor to handle stride calculation.
    // This lets us skip copying `x`.
    auto flat_x = x.flatten();
    auto x_accessor = flat_x.accessor<float, 1>();

    auto y = at::empty(x.sizes(), x.options());
    auto y_ptr = y.data_ptr<float>();
    
    auto n = y.numel();
    for (int64_t i = 0; i < n; i++) {
        *(y_ptr + i) = x_accessor[i] + 1;
    }
    return y;
}

```