In [None]:
! nvidia-smi

Fri Jul  9 23:01:03 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    26W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
! pip install fastcore --upgrade -qq
! pip install fastai --upgrade -qq
! pip install git+https://github.com/thomasbrandon/mish-cuda -qq

[K     |████████████████████████████████| 61kB 6.7MB/s 
[K     |████████████████████████████████| 194kB 15.4MB/s 
[?25h  Building wheel for mish-cuda (setup.py) ... [?25l[?25hdone


In [None]:
from fastai.vision.all import *
import fastai
from sys import exit
from operator import itemgetter
import re
import torch
from torch.nn import functional as F
import numpy as np
from mish_cuda import MishCudaFunction

In [None]:
def scale(val, spec="#0.4G"):
    """Format a number with an SI prefix chosen from its magnitude.

    Examples: ``scale(1.004e-4) -> '100.4µ'``, ``scale(2.5) -> '2.500 '``.

    Args:
        val: real scalar to format (Python or NumPy number).
        spec: format spec applied to the scaled mantissa.

    Returns:
        The formatted mantissa followed by a one-character SI prefix; the
        prefix is a space for exponent 0. Zero and non-finite values are
        printed unscaled with the space prefix.
    """
    prefixes = "yzafpnµm kMGTPEZY"
    unit_idx = len(prefixes) // 2  # position of the ' ' (10**0) prefix
    val = float(val)
    if val == 0 or not np.isfinite(val):
        # log10 is undefined/infinite here, so no prefix can be derived.
        return f"{val:{spec}}{prefixes[unit_idx]}"
    # The exponent is a multiple of 3 derived from the magnitude only; the
    # sign of ``val`` must not influence which prefix is picked.
    exp = int(np.log10(abs(val)) // 3) * 3
    # Clamp to the available prefixes (yocto .. yotta) instead of raising IndexError.
    exp = max(-3 * unit_idx, min(3 * unit_idx, exp))
    return f"{val / 10.0 ** exp:{spec}}{prefixes[exp // 3 + unit_idx]}"

def display_times(times):
    """Summarise timings as "<mean>s ± <std>s, <min>s, <max>s" with SI prefixes.

    ``times`` is expected to be a NumPy array of durations in seconds, as
    returned by ``profile_cuda``.
    """
    stats = (times.mean(), times.std(), times.min(), times.max())
    mean_s, std_s, min_s, max_s = (f"{scale(stat)}s" for stat in stats)
    return f"{mean_s} ± {std_s}, {min_s}, {max_s}"

def profile_cuda(func, inp, n_repeat=100, warmup=10):
    """Time forward and backward passes of ``func`` using CUDA events.

    Each of the ``warmup + n_repeat`` iterations does two things:

    * It times one forward call ``func(inp)`` on the caller's tensor.
    * It times one ``torch.autograd.grad`` call through ``func(x).mean()``,
      where ``x`` is a fresh leaf copy of ``inp`` with ``requires_grad=True``.
      Graph construction happens outside the timed region.

    The first ``warmup`` iterations are discarded.

    Args:
        func: callable mapping a tensor to a tensor.
        inp: input tensor (presumably on a CUDA device, since timing uses
            ``torch.cuda.Event``). It is never modified or rebound.
        n_repeat: number of timed iterations.
        warmup: number of untimed iterations run first.

    Returns:
        ``(fwd_times, bwd_times)``: NumPy arrays of length ``n_repeat``, in seconds.
    """
    fwd_times, bwd_times = [], []
    for i in range(n_repeat + warmup):
        # Forward: always on the caller's original tensor so that every
        # iteration measures the same thing. Rebinding ``inp`` to a
        # grad-requiring clone would make later forwards record autograd
        # graphs and chain every clone onto the previous one.
        start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        start.record()
        out = func(inp)
        end.record()
        torch.cuda.synchronize()
        if i >= warmup:
            fwd_times.append(start.elapsed_time(end))
        del out  # release the forward output before the backward measurement

        # Backward: build the graph from a detached leaf copy, then time only
        # the gradient computation.
        x = inp.detach().clone().requires_grad_()
        loss = func(x).mean()
        start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        start.record()
        torch.autograd.grad(loss, x)
        end.record()
        torch.cuda.synchronize()
        if i >= warmup:
            bwd_times.append(start.elapsed_time(end))
    # CUDA event elapsed times are in milliseconds; convert to seconds.
    return np.array(fwd_times) / 1000, np.array(bwd_times) / 1000

def mish_pt(x):
    """Reference Mish built from plain PyTorch ops: ``x * tanh(softplus(x))``."""
    return x * torch.tanh(F.softplus(x))

def _parse_dtypes(types):
    """Map ``types`` ('all' or a torch dtype name such as 'float16') to a list of dtypes."""
    if types == 'all':
        return [torch.float16, torch.bfloat16, torch.float32, torch.float64]
    dtype = getattr(torch, types, None)
    # hasattr alone would accept any torch attribute (e.g. 'nn'); require a real dtype.
    if not isinstance(dtype, torch.dtype):
        exit(f"Invalid data type, expected torch type or 'all', got {types}")
    return [dtype]


def _parse_size(size):
    """Parse a shape literal such as "(16, 10, 256, 256)" or "[2,3]" into a list of ints."""
    sz_str = size.replace(' ', '')
    # fullmatch (not match) so trailing junk like "(1,2)x" is rejected here
    # instead of crashing later inside int().
    if not re.fullmatch(r"[\(\[]\d+(,\d+)*[\)\]]", sz_str):
        exit("Badly formatted size, should be a list or tuple such as \"(1,2,3)\".")
    return [int(d) for d in sz_str[1:-1].split(',')]


def _profile_cpu(func, inp, n_repeat=100, warmup=10):
    """CPU counterpart of ``profile_cuda``: wall-clock forward/backward times in seconds."""
    import time
    fwd_times, bwd_times = [], []
    for i in range(n_repeat + warmup):
        t0 = time.perf_counter()
        out = func(inp)
        fwd = time.perf_counter() - t0
        del out
        # Build the backward graph outside the timed region, as profile_cuda does.
        x = inp.detach().clone().requires_grad_()
        loss = func(x).mean()
        t0 = time.perf_counter()
        torch.autograd.grad(loss, x)
        bwd = time.perf_counter() - t0
        if i >= warmup:
            fwd_times.append(fwd)
            bwd_times.append(bwd)
    return np.array(fwd_times), np.array(bwd_times)


def profile(device='cuda', n_repeat=100, warmup=10, size='(16,10,256,256)', baseline=True, types='all'):
    """Print forward/backward timings of several activation functions.

    Args:
        device: device type for ``torch.device`` ('cuda' or 'cpu').
        n_repeat: number of timed runs per function.
        warmup: number of untimed runs before timing starts.
        size: input shape as a tuple/list literal string, e.g. "(16,10,256,256)".
        baseline: currently unused; kept for call compatibility.
        types: a torch dtype name (e.g. 'float16') or 'all' for
            float16/bfloat16/float32/float64.

    Invalid ``types``, ``size`` or ``device`` values terminate via
    ``sys.exit`` with a message.
    """
    dtypes = _parse_dtypes(types)
    dev = torch.device(type=device)
    if dev.type not in ('cuda', 'cpu'):
        exit(f"Unsupported device {device!r}, expected 'cuda' or 'cpu'.")
    sz = _parse_size(size)
    time_fn = profile_cuda if dev.type == 'cuda' else _profile_cpu
    print(f"Profiling over {n_repeat} runs after {warmup} warmup runs.")
    for dtype in dtypes:
        if len(dtypes) > 1:
            print(f"Testing on {dtype}:")
            ind = ' '
        else:
            ind = ''
        inp = torch.randn(*sz, dtype=dtype, device=dev)
        funcs = dict(
            relu=torch.nn.functional.relu,
            leaky_relu=torch.nn.functional.leaky_relu,
            softplus=torch.nn.functional.softplus,
            silu_jit=fastai.layers.swish,
            silu_native=torch.nn.functional.silu,
            mish_naive=mish_pt,
            mish_jit=fastai.layers.mish,
            mish_cuda=MishCudaFunction.apply,
            mish_native=torch.nn.functional.mish,
        )
        if dev.type == 'cpu':
            funcs.pop('mish_cuda')  # the mish-cuda extension is CUDA-only
        max_name = max(map(len, funcs)) + 6
        for name, func in funcs.items():
            if name == 'mish_cuda' and dtype == torch.bfloat16:
                # mish_cuda is deliberately not run for bfloat16 (presumably
                # unsupported). Say so explicitly instead of re-printing the
                # previous function's timings under this name.
                print(ind + (name + ':').ljust(max_name) + f"skipped for {dtype}")
                continue
            fwd_times, bwd_times = time_fn(func, inp, n_repeat, warmup)
            print(ind + (name + '_fwd:').ljust(max_name) + display_times(fwd_times))
            print(ind + (name + '_bwd:').ljust(max_name) + display_times(bwd_times))
            if dev.type == 'cuda':
                torch.cuda.empty_cache()

In [None]:
# Benchmark every activation on the default input shape (16,10,256,256).
profile(device='cuda')

Profiling over 100 runs after 10 warmup runs.
Testing on torch.float16:
 relu_fwd:        100.4µs ± 4.316µs, 96.26µs, 116.7µs
 relu_bwd:        179.6µs ± 16.79µs, 161.8µs, 301.1µs
 leaky_relu_fwd:  95.32µs ± 3.966µs, 92.16µs, 120.8µs
 leaky_relu_bwd:  178.8µs ± 28.97µs, 158.7µs, 323.6µs
 softplus_fwd:    105.4µs ± 7.069µs, 100.4µs, 151.6µs
 softplus_bwd:    237.3µs ± 378.3µs, 171.0µs, 3.813ms
 silu_jit_fwd:    114.3µs ± 62.76µs, 99.33µs, 638.0µs
 silu_jit_bwd:    249.7µs ± 288.3µs, 169.0µs, 2.765ms
 silu_native_fwd: 86.78µs ± 2.632µs, 83.97µs, 100.4µs
 silu_native_bwd: 158.8µs ± 12.15µs, 147.5µs, 196.6µs
 mish_naive_fwd:  239.4µs ± 7.391µs, 232.4µs, 272.4µs
 mish_naive_bwd:  502.3µs ± 341.5µs, 463.9µs, 3.900ms
 mish_jit_fwd:    187.0µs ± 4.837µs, 182.3µs, 204.8µs
 mish_jit_bwd:    254.5µs ± 9.229µs, 248.8µs, 328.7µs
 mish_cuda_fwd:   116.1µs ± 5.289µs, 107.5µs, 136.2µs
 mish_cuda_bwd:   185.1µs ± 21.07µs, 169.0µs, 291.8µs
 mish_native_fwd: 84.57µs ± 3.283µs, 81.92µs, 101.4µs
 mish_nati

In [None]:
# Same benchmark on an input 4x larger along the first dimension (16 -> 64).
profile(device='cuda', size='(64,10,256,256)')

Profiling over 100 runs after 10 warmup runs.
Testing on torch.float16:
 relu_fwd:        301.5µs ± 5.669µs, 295.9µs, 327.7µs
 relu_bwd:        607.6µs ± 291.4µs, 576.5µs, 3.507ms
 leaky_relu_fwd:  269.9µs ± 4.267µs, 259.1µs, 282.6µs
 leaky_relu_bwd:  575.4µs ± 237.3µs, 540.7µs, 2.936ms
 softplus_fwd:    294.8µs ± 3.801µs, 290.8µs, 314.4µs
 softplus_bwd:    549.6µs ± 1.638µs, 546.8µs, 555.0µs
 silu_jit_fwd:    296.4µs ± 8.637µs, 288.8µs, 337.9µs
 silu_jit_bwd:    593.3µs ± 200.4µs, 558.1µs, 2.018ms
 silu_native_fwd: 263.1µs ± 17.87µs, 256.0µs, 433.2µs
 silu_native_bwd: 542.7µs ± 927.0ns, 540.7µs, 544.8µs
 mish_naive_fwd:  834.5µs ± 4.775µs, 829.4µs, 858.1µs
 mish_naive_bwd:  1.753ms ± 7.297µs, 1.749ms, 1.824ms
 mish_jit_fwd:    630.8µs ± 8.030µs, 620.5µs, 661.5µs
 mish_jit_bwd:    950.3µs ± 14.23µs, 930.8µs, 1.042ms
 mish_cuda_fwd:   371.4µs ± 6.599µs, 365.6µs, 403.5µs
 mish_cuda_bwd:   652.3µs ± 30.78µs, 642.0µs, 913.4µs
 mish_native_fwd: 261.6µs ± 4.253µs, 257.0µs, 283.6µs
 mish_nati