In [None]:
# Record the GPU model, driver and CUDA version the benchmarks below were run on.
! nvidia-smi

Sun Jul 18 04:43:24 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Upgrade fastcore before fastai so fastai resolves against the newer fastcore.
! pip install fastcore --upgrade -qq
! pip install fastai --upgrade -qq
# mish-cuda is built from source (setup.py) — presumably needs a working CUDA toolkit; TODO confirm.
! pip install git+https://github.com/thomasbrandon/mish-cuda -qq

[K     |████████████████████████████████| 61kB 5.7MB/s 
[K     |████████████████████████████████| 194kB 7.6MB/s 
[?25h  Building wheel for mish-cuda (setup.py) ... [?25l[?25hdone


In [None]:
from fastai.vision.all import *
import fastai
from sys import exit
from operator import itemgetter
import re
import torch
from torch.nn import functional as F
import numpy as np
from mish_cuda import MishCudaFunction

In [None]:
def scale(val, spec="#0.4G"):
    """Format a number with an SI prefix (y .. Y) instead of a power of ten.

    The exponent is rounded down to a multiple of 3, based only on the
    magnitude of ``val``, so the scaled value is normally in [1, 1000).

    Args:
        val: Real number to format (Python or NumPy scalar).
        spec: Format spec applied to the scaled value.

    Returns:
        ``f"{scaled:{spec}}{prefix}"``. The prefix is a single space when no
        scaling is applied (values in [1, 1000), zero, NaN and infinities).
    """
    PREFIXES = np.array([c for c in u"yzafpnµm kMGTPEZY"])
    mag = np.abs(val)
    if mag == 0 or not np.isfinite(mag):
        # log10 is undefined/non-finite here; print the value without scaling.
        exp = 0
    else:
        # The exponent must not depend on the sign of val: -0.002 is "-2m", not "-2E-06k".
        exp = int(np.log10(mag) // 3 * 3)
        # Clamp to the available prefixes (yocto .. yotta) so indexing stays in range.
        exp = max(-24, min(24, exp))
    val = val / 10.**exp
    prefix = PREFIXES[exp//3 + len(PREFIXES)//2]
    return f"{val:{spec}}{prefix}"

def display_times(times):
    """Summarise an array of durations as "mean ± std, min, max".

    Each statistic is passed through ``scale`` and suffixed with "s", so the
    values are assumed to be in seconds.
    """
    stats = (times.mean(), times.std(), times.min(), times.max())
    mean_s, std_s, min_s, max_s = (f"{scale(v)}s" for v in stats)
    return f"{mean_s} ± {std_s}, {min_s}, {max_s}"

def profile_cuda(func, inp, n_repeat=100, warmup=10):
    """Time forward and backward passes of ``func`` using CUDA events.

    Args:
        func: Callable mapping a tensor to a tensor.
        inp: Input tensor (should live on a CUDA device for the event timings
            to be meaningful). It is never modified.
        n_repeat: Number of timed iterations.
        warmup: Number of untimed iterations run before timing starts.

    Returns:
        ``(fwd_times, bwd_times)``: NumPy arrays of length ``n_repeat`` in
        seconds. The forward time covers ``func`` on a tensor that does not
        require grad. The backward time covers only ``torch.autograd.grad`` of
        ``func(x).mean()`` with respect to ``x``.
    """
    # Detach once so every forward pass sees the same plain tensor. Previously
    # the input was rebound to a grad-requiring clone after the first
    # iteration, so later forward timings also recorded autograd graph building.
    base = inp.detach()
    fwd_times, bwd_times = [], []
    for i in range(n_repeat + warmup):
        timed = i >= warmup

        start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        start.record()
        func(base)
        end.record()
        torch.cuda.synchronize()
        if timed: fwd_times.append(start.elapsed_time(end))

        # Fresh leaf every iteration; its graph is built outside the timed region.
        x = base.clone().requires_grad_()
        loss = func(x).mean()
        start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        start.record()
        torch.autograd.grad(loss, x)
        end.record()
        torch.cuda.synchronize()
        if timed: bwd_times.append(start.elapsed_time(end))
    # Event.elapsed_time reports milliseconds; convert to seconds.
    return (np.array(fwd_times)/1000,
            np.array(bwd_times)/1000)

def mish_pt(x):
    """Reference Mish, x * tanh(softplus(x)), composed from plain PyTorch ops."""
    return x * torch.tanh(F.softplus(x))

def _profile_cpu(func, inp, n_repeat=100, warmup=10):
    """``time.perf_counter`` counterpart of ``profile_cuda`` for non-CUDA tensors.

    Returns ``(fwd_times, bwd_times)`` as NumPy arrays of length ``n_repeat``, in seconds.
    """
    import time
    base = inp.detach()
    fwd_times, bwd_times = [], []
    for i in range(n_repeat + warmup):
        t0 = time.perf_counter()
        func(base)
        t1 = time.perf_counter()
        # Build the graph outside the timed region, matching profile_cuda.
        x = base.clone().requires_grad_()
        loss = func(x).mean()
        t2 = time.perf_counter()
        torch.autograd.grad(loss, x)
        t3 = time.perf_counter()
        if i >= warmup:
            fwd_times.append(t1 - t0)
            bwd_times.append(t3 - t2)
    return np.array(fwd_times), np.array(bwd_times)

def profile(device='cuda', n_repeat=100, warmup=10, size='(16,10,256,256)', baseline=True, types='all'):
    """Benchmark several activation functions and print forward/backward timings.

    Args:
        device: Torch device type. 'cuda' is timed with CUDA events via
            ``profile_cuda``; any other device is timed with ``time.perf_counter``.
        n_repeat: Number of timed runs per function.
        warmup: Number of untimed runs per function before timing.
        size: Input shape written as a tuple/list string, e.g. "(16,10,256,256)".
        baseline: Unused; kept for backward compatibility.
        types: 'all' (float16, bfloat16, float32, float64) or the name of a
            single torch dtype such as 'float32'.

    Calls ``sys.exit`` (raising ``SystemExit``) on an invalid ``types`` or ``size``.
    """
    if types == 'all':
        dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
    else:
        # hasattr alone accepts any torch attribute (e.g. 'nn'); require an actual dtype.
        if not isinstance(getattr(torch, types, None), torch.dtype):
            exit(f"Invalid data type, expected torch type or 'all', got {types}")
        dtypes = [getattr(torch, types)]
    dev = torch.device(type=device)
    sz_str = size.replace(' ','')
    # fullmatch: reject trailing junk like "(1,2)x" instead of crashing in int() below.
    if not re.fullmatch(r"[\(\[]\d+(,\d+)*[\)\]]", sz_str):
        exit("Badly formatted size, should be a list or tuple such as \"(1,2,3)\".")
    sz = list(map(int, sz_str[1:-1].split(',')))
    on_cuda = dev.type == 'cuda'
    timer = profile_cuda if on_cuda else _profile_cpu
    print(f"Profiling over {n_repeat} runs after {warmup} warmup runs.")
    for dtype in dtypes:
        if len(dtypes) > 1:
            print(f"Testing on {dtype}:")
            ind = ' '
        else: ind = ''
        inp = torch.randn(*sz, dtype=dtype, device=dev)
        funcs = dict(relu = torch.nn.functional.relu,
                     leaky_relu = torch.nn.functional.leaky_relu,
                     softplus = torch.nn.functional.softplus,
                     silu_jit = fastai.layers.swish,
                     silu_native = torch.nn.functional.silu,
                     mish_naive = mish_pt,
                     mish_jit = fastai.layers.mish,
                     mish_cuda = MishCudaFunction.apply,
                     mish_native = torch.nn.functional.mish)
        if not on_cuda: funcs.pop('mish_cuda')
        max_name = max(map(len, funcs.keys())) + 6
        for (name,func) in funcs.items():
            if name == 'mish_cuda' and dtype == torch.bfloat16:
                # mish_cuda is deliberately not run on bfloat16 (presumably unsupported).
                # Previously the prior function's timings were printed under this name.
                print(ind+(name+':').ljust(max_name) + 'skipped (bfloat16)')
                continue
            fwd_times,bwd_times = timer(func, inp, n_repeat, warmup)
            print(ind+(name+'_fwd:').ljust(max_name) + display_times(fwd_times))
            print(ind+(name+'_bwd:').ljust(max_name) + display_times(bwd_times))
            if on_cuda: torch.cuda.empty_cache()

# 100 timed iterations (after 10 warmup runs)

In [None]:
# Default input size (16,10,256,256): 100 timed runs after 10 warmup runs per function.
profile(device='cuda')

Profiling over 100 runs after 10 warmup runs.
Testing on torch.float16:
 relu_fwd:        114.3µs ± 15.35µs, 106.2µs, 236.0µs
 relu_bwd:        401.8µs ± 502.5µs, 327.6µs, 5.170ms
 leaky_relu_fwd:  104.9µs ± 15.58µs, 101.1µs, 257.4µs
 leaky_relu_bwd:  316.7µs ± 15.29µs, 314.0µs, 468.8µs
 softplus_fwd:    225.3µs ± 61.01µs, 214.5µs, 831.2µs
 softplus_bwd:    449.3µs ± 65.98µs, 436.4µs, 1.091ms
 silu_jit_fwd:    181.8µs ± 21.29µs, 177.0µs, 363.1µs
 silu_jit_bwd:    364.7µs ± 769.8ns, 362.7µs, 366.6µs
 silu_native_fwd: 128.5µs ± 4.818µs, 124.5µs, 148.4µs
 silu_native_bwd: 333.0µs ± 183.8µs, 309.5µs, 2.113ms
 mish_naive_fwd:  434.0µs ± 25.70µs, 428.3µs, 686.8µs
 mish_naive_bwd:  932.5µs ± 213.1µs, 900.4µs, 2.707ms
 mish_jit_fwd:    378.0µs ± 7.118µs, 368.3µs, 400.0µs
 mish_jit_bwd:    625.6µs ± 307.2µs, 590.4µs, 3.681ms
 mish_cuda_fwd:   277.4µs ± 3.832µs, 273.4µs, 305.4µs
 mish_cuda_bwd:   478.6µs ± 467.8ns, 477.8µs, 480.2µs
 mish_native_fwd: 206.1µs ± 44.46µs, 197.5µs, 644.7µs
 mish_nati

In [None]:
# 4x larger batch, (64,10,256,256), with the default 100 runs after 10 warmup runs.
profile(device='cuda', size='(64,10,256,256)')

Profiling over 100 runs after 10 warmup runs.
Testing on torch.float16:
 relu_fwd:        345.1µs ± 5.984µs, 337.7µs, 367.6µs
 relu_bwd:        1.257ms ± 201.6µs, 1.205ms, 2.986ms
 leaky_relu_fwd:  343.0µs ± 30.66µs, 334.0µs, 643.7µs
 leaky_relu_bwd:  1.213ms ± 59.66µs, 1.204ms, 1.806ms
 softplus_fwd:    791.0µs ± 4.405µs, 787.4µs, 808.7µs
 softplus_bwd:    1.638ms ± 127.4µs, 1.624ms, 2.906ms
 silu_jit_fwd:    609.4µs ± 6.577µs, 603.5µs, 637.8µs
 silu_jit_bwd:    1.430ms ± 12.35µs, 1.422ms, 1.550ms
 silu_native_fwd: 434.2µs ± 23.30µs, 426.4µs, 659.0µs
 silu_native_bwd: 1.239ms ± 300.0µs, 1.206ms, 4.223ms
 mish_naive_fwd:  1.623ms ± 9.061µs, 1.616ms, 1.690ms
 mish_naive_bwd:  3.512ms ± 620.0µs, 3.439ms, 9.614ms
 mish_jit_fwd:    1.362ms ± 8.539µs, 1.347ms, 1.408ms
 mish_jit_bwd:    2.329ms ± 4.493µs, 2.319ms, 2.335ms
 mish_cuda_fwd:   1.011ms ± 11.01µs, 1.004ms, 1.110ms
 mish_cuda_bwd:   1.883ms ± 1.018µs, 1.882ms, 1.891ms
 mish_native_fwd: 731.8µs ± 45.41µs, 720.1µs, 1.135ms
 mish_nati

# 50 timed iterations (after 10 warmup runs)

In [None]:
# Default input size with 50 timed runs after the default 10 warmup runs.
profile(device='cuda', n_repeat=50)

Profiling over 50 runs after 10 warmup runs.
Testing on torch.float16:
 relu_fwd:        108.2µs ± 4.445µs, 103.5µs, 124.4µs
 relu_bwd:        346.9µs ± 73.09µs, 330.0µs, 817.3µs
 leaky_relu_fwd:  103.6µs ± 3.178µs, 99.84µs, 118.0µs
 leaky_relu_bwd:  331.5µs ± 952.9ns, 329.8µs, 334.0µs
 softplus_fwd:    236.2µs ± 3.344µs, 223.6µs, 246.4µs
 softplus_bwd:    676.6µs ± 1.384ms, 438.5µs, 10.36ms
 silu_jit_fwd:    183.6µs ± 26.65µs, 175.3µs, 353.8µs
 silu_jit_bwd:    429.8µs ± 283.7µs, 363.4µs, 2.012ms
 silu_native_fwd: 125.1µs ± 1.105µs, 124.0µs, 131.4µs
 silu_native_bwd: 310.7µs ± 650.3ns, 309.7µs, 312.7µs
 mish_naive_fwd:  438.2µs ± 38.26µs, 426.9µs, 698.3µs
 mish_naive_bwd:  943.1µs ± 268.8µs, 901.3µs, 2.823ms
 mish_jit_fwd:    376.1µs ± 6.112µs, 367.2µs, 391.9µs
 mish_jit_bwd:    600.3µs ± 44.59µs, 590.9µs, 912.1µs
 mish_cuda_fwd:   275.8µs ± 4.643µs, 271.8µs, 299.7µs
 mish_cuda_bwd:   478.4µs ± 455.1ns, 477.4µs, 479.5µs
 mish_native_fwd: 198.4µs ± 2.961µs, 194.5µs, 210.6µs
 mish_nativ

In [None]:
# Larger input, (64,10,256,256), with 50 timed runs after 10 warmup runs.
profile(device='cuda', size='(64,10,256,256)', n_repeat=50)

Profiling over 50 runs after 10 warmup runs.
Testing on torch.float16:
 relu_fwd:        345.7µs ± 22.92µs, 337.4µs, 500.6µs
 relu_bwd:        1.206ms ± 801.7ns, 1.205ms, 1.208ms
 leaky_relu_fwd:  339.2µs ± 6.653µs, 333.2µs, 367.8µs
 leaky_relu_bwd:  1.252ms ± 165.8µs, 1.205ms, 2.157ms
 softplus_fwd:    793.0µs ± 5.494µs, 786.9µs, 819.3µs
 softplus_bwd:    1.763ms ± 565.6µs, 1.624ms, 4.381ms
 silu_jit_fwd:    928.1µs ± 57.32µs, 911.5µs, 1.326ms
 silu_jit_bwd:    3.481ms ± 79.44µs, 3.465ms, 4.035ms
 silu_native_fwd: 435.2µs ± 18.88µs, 426.5µs, 563.2µs
 silu_native_bwd: 1.210ms ± 1.364µs, 1.207ms, 1.213ms
 mish_naive_fwd:  1.621ms ± 5.045µs, 1.614ms, 1.636ms
 mish_naive_bwd:  3.443ms ± 1.689µs, 3.440ms, 3.448ms
 mish_jit_fwd:    1.635ms ± 6.902µs, 1.628ms, 1.655ms
 mish_jit_bwd:    5.092ms ± 222.7µs, 5.057ms, 6.651ms
 mish_cuda_fwd:   1.007ms ± 4.318µs, 1.003ms, 1.023ms
 mish_cuda_bwd:   1.886ms ± 25.02µs, 1.882ms, 2.062ms
 mish_native_fwd: 725.3µs ± 5.760µs, 719.0µs, 745.0µs
 mish_nativ

# 25 timed iterations (after 5 warmup runs)

In [None]:
# Default input size with 25 timed runs after 5 warmup runs.
profile(device='cuda', n_repeat=25, warmup=5)

Profiling over 25 runs after 5 warmup runs.
Testing on torch.float16:
 relu_fwd:        110.2µs ± 3.949µs, 104.6µs, 127.6µs
 relu_bwd:        332.9µs ± 1.380µs, 330.8µs, 336.4µs
 leaky_relu_fwd:  105.9µs ± 4.125µs, 101.5µs, 118.3µs
 leaky_relu_bwd:  333.0µs ± 909.3ns, 331.5µs, 334.8µs
 softplus_fwd:    239.3µs ± 4.721µs, 234.6µs, 254.6µs
 softplus_bwd:    485.2µs ± 851.7ns, 483.6µs, 487.3µs
 silu_jit_fwd:    195.0µs ± 4.587µs, 190.4µs, 208.9µs
 silu_jit_bwd:    396.1µs ± 1.312µs, 394.1µs, 399.6µs
 silu_native_fwd: 137.4µs ± 4.876µs, 132.9µs, 154.9µs
 silu_native_bwd: 336.5µs ± 1.428µs, 334.3µs, 339.2µs
 mish_naive_fwd:  463.4µs ± 4.393µs, 457.8µs, 475.7µs
 mish_naive_bwd:  951.8µs ± 1.520µs, 948.6µs, 954.6µs
 mish_jit_fwd:    382.8µs ± 13.21µs, 366.1µs, 410.5µs
 mish_jit_bwd:    602.1µs ± 6.681µs, 594.0µs, 613.6µs
 mish_cuda_fwd:   279.7µs ± 2.713µs, 276.9µs, 288.6µs
 mish_cuda_bwd:   481.9µs ± 494.5ns, 480.9µs, 483.0µs
 mish_native_fwd: 202.7µs ± 6.174µs, 196.9µs, 219.7µs
 mish_native

In [None]:
# Larger input, (64,10,256,256), with 25 timed runs after 5 warmup runs.
profile(device='cuda', size='(64,10,256,256)', n_repeat=25, warmup=5)

Profiling over 25 runs after 5 warmup runs.
Testing on torch.float16:
 relu_fwd:        343.4µs ± 5.465µs, 337.7µs, 362.5µs
 relu_bwd:        1.237ms ± 128.3µs, 1.209ms, 1.865ms
 leaky_relu_fwd:  352.0µs ± 55.62µs, 336.8µs, 623.7µs
 leaky_relu_bwd:  1.284ms ± 351.7µs, 1.210ms, 3.007ms
 softplus_fwd:    793.2µs ± 6.308µs, 786.1µs, 813.8µs
 softplus_bwd:    1.631ms ± 1.120µs, 1.629ms, 1.634ms
 silu_jit_fwd:    923.6µs ± 7.414µs, 912.9µs, 936.7µs
 silu_jit_bwd:    3.513ms ± 187.5µs, 3.470ms, 4.431ms
 silu_native_fwd: 432.8µs ± 5.688µs, 425.4µs, 450.3µs
 silu_native_bwd: 1.214ms ± 2.596µs, 1.211ms, 1.225ms
 mish_naive_fwd:  1.628ms ± 10.02µs, 1.617ms, 1.654ms
 mish_naive_bwd:  3.459ms ± 68.51µs, 3.443ms, 3.795ms
 mish_jit_fwd:    1.637ms ± 7.579µs, 1.629ms, 1.659ms
 mish_jit_bwd:    5.069ms ± 22.44µs, 5.062ms, 5.178ms
 mish_cuda_fwd:   1.011ms ± 6.176µs, 1.005ms, 1.037ms
 mish_cuda_bwd:   1.887ms ± 937.0ns, 1.885ms, 1.889ms
 mish_native_fwd: 725.0µs ± 3.944µs, 722.3µs, 738.5µs
 mish_native