In [None]:
# Show the GPU model, driver / CUDA versions and current memory use of this runtime.
! nvidia-smi

Mon Jul 12 16:59:40 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Install PyTorch 1.8.1 built for CUDA 10.2 plus matching torchvision/torchaudio from the LTS wheel index.
# NOTE(review): pip reports that the preinstalled torchtext 0.10.0 requires torch 1.9.0; torchtext is not
# imported in the cells shown, but pin a torchtext release matching torch 1.8.1 if it is ever needed.
! pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio===0.8.1 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html

Looking in links: https://download.pytorch.org/whl/lts/1.8/torch_lts.html
Collecting torch==1.8.1+cu102
[?25l  Downloading https://download.pytorch.org/whl/lts/1.8/cu102/torch-1.8.1%2Bcu102-cp37-cp37m-linux_x86_64.whl (804.1MB)
[K     |████████████████████████████████| 804.1MB 23kB/s 
[?25hCollecting torchvision==0.9.1+cu102
[?25l  Downloading https://download.pytorch.org/whl/lts/1.8/cu102/torchvision-0.9.1%2Bcu102-cp37-cp37m-linux_x86_64.whl (17.3MB)
[K     |████████████████████████████████| 17.3MB 5.1MB/s 
[?25hCollecting torchaudio===0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/aa/55/01ad9244bcd595e39cea5ce30726a7fe02fd963d07daeb136bfe7e23f0a5/torchaudio-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 15.3MB/s 
[31mERROR: torchtext 0.10.0 has requirement torch==1.9.0, but you'll have torch 1.8.1+cu102 which is incompatible.[0m
Installing collected packages: torch, torchvision, torchaudio
  Found existing in

In [None]:
# Upgrade fastcore/fastai and build mish-cuda from GitHub (it provides MishCudaFunction, imported below).
# NOTE(review): none of these installs is version-pinned, so benchmark results may change across releases.
! pip install fastcore --upgrade -qq
! pip install fastai --upgrade -qq
! pip install git+https://github.com/thomasbrandon/mish-cuda -qq

[K     |████████████████████████████████| 61kB 5.7MB/s 
[K     |████████████████████████████████| 194kB 14.6MB/s 
[?25h  Building wheel for mish-cuda (setup.py) ... [?25l[?25hdone


In [None]:
from fastai.vision.all import *
import fastai
from sys import exit
from operator import itemgetter
import re
import torch
from torch.nn import functional as F
import numpy as np
from mish_cuda import MishCudaFunction

In [None]:
def scale(val, spec="#0.4G"):
    """Format ``val`` with an SI prefix so the mantissa normally lies in [1, 1000).

    Parameters
    ----------
    val : float
        Value to format (e.g. a duration in seconds). Negative values are supported.
    spec : str
        Format spec applied to the scaled mantissa (default ``"#0.4G"``).

    Returns
    -------
    str
        The formatted mantissa followed by one prefix character from
        ``"yzafpnµm kMGTPEZY"`` (a space means no prefix), e.g. ``scale(1.5e-4) == "150.0µ"``.
        Zero and non-finite values are formatted unscaled with the space prefix; magnitudes
        beyond the yocto..yotta range are clamped to the outermost prefix.
    """
    PREFIXES = "yzafpnµm kMGTPEZY"
    mid = len(PREFIXES) // 2  # index of the unit (no-prefix) entry
    if val == 0 or not np.isfinite(val):
        # log10 is undefined/infinite here; show the value unscaled instead of indexing garbage.
        return f"{val:{spec}}{PREFIXES[mid]}"
    # Engineering exponent: largest multiple of 3 not exceeding log10(|val|).
    # The sign of val must not enter the exponent -- it only belongs to the mantissa.
    exp = int(np.log10(np.abs(val)) // 3 * 3)
    exp = max(-3 * mid, min(3 * mid, exp))  # stay within the available prefixes
    return f"{val / 10.0**exp:{spec}}{PREFIXES[exp // 3 + mid]}"

def display_times(times):
    """Summarise an array of durations in seconds as ``"mean ± std, min, max"`` with SI prefixes."""
    stats = (times.mean(), times.std(), times.min(), times.max())
    mean_s, std_s, min_s, max_s = (f"{scale(value)}s" for value in stats)
    return f"{mean_s} ± {std_s}, {min_s}, {max_s}"

def profile_cuda(func, inp, n_repeat=100, warmup=10):
    """Time forward and backward passes of ``func`` on a CUDA tensor using CUDA events.

    Parameters
    ----------
    func : callable
        Function to profile, called as ``func(x)``.
    inp : torch.Tensor
        Input tensor on a CUDA device. It is not modified.
    n_repeat : int
        Number of recorded iterations.
    warmup : int
        Number of initial iterations that run but are not recorded.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Forward and backward times in seconds, each of length ``n_repeat``.
    """
    # One leaf tensor requiring grad, reused every iteration. This avoids copying the input
    # on each iteration and chaining successive clones into an ever-growing autograd history.
    x = inp.detach().clone().requires_grad_()
    fwd_times, bwd_times = [], []
    for i in range(n_repeat + warmup):
        # Forward: timed with autograd recording enabled, on every iteration alike.
        start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        start.record()
        func(x)  # result not kept, so its graph is freed before the backward measurement
        end.record()
        torch.cuda.synchronize()
        if i >= warmup:
            fwd_times.append(start.elapsed_time(end))

        # Backward: the forward and mean are enqueued before `start`, so only grad is timed.
        start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        loss = func(x).mean()
        start.record()
        torch.autograd.grad(loss, x)
        end.record()
        torch.cuda.synchronize()
        if i >= warmup:
            bwd_times.append(start.elapsed_time(end))
    # Event.elapsed_time reports milliseconds; convert to seconds.
    return (np.array(fwd_times) / 1000,
            np.array(bwd_times) / 1000)

def mish_pt(x):
    """Unfused reference Mish built from plain PyTorch ops: ``x * tanh(softplus(x))``."""
    return x * torch.tanh(F.softplus(x))

def _profile_cpu(func, inp, n_repeat=100, warmup=10):
    """Wall-clock timing of ``func``'s forward/backward on a non-CUDA tensor.

    The forward is timed on an input that requires grad, so autograd records the graph.
    The backward times only ``torch.autograd.grad`` of the output's mean.
    Returns two numpy arrays of durations in seconds, each of length ``n_repeat``.
    """
    import time
    x = inp.detach().clone().requires_grad_()
    fwd_times, bwd_times = [], []
    for i in range(n_repeat + warmup):
        t0 = time.perf_counter()
        func(x)
        fwd = time.perf_counter() - t0
        loss = func(x).mean()
        t0 = time.perf_counter()
        torch.autograd.grad(loss, x)
        bwd = time.perf_counter() - t0
        if i >= warmup:
            fwd_times.append(fwd)
            bwd_times.append(bwd)
    return np.array(fwd_times), np.array(bwd_times)

def profile(device='cuda', n_repeat=100, warmup=10, size='(16,10,256,256)', baseline=True, types='all'):
    """Benchmark forward/backward time of several activation functions and print the results.

    Parameters
    ----------
    device : str
        Device type for ``torch.device`` (e.g. ``'cuda'`` or ``'cpu'``). On CUDA, timing uses
        ``profile_cuda``. Elsewhere it uses wall-clock timing, and ``mish_cuda`` is left out.
    n_repeat, warmup : int
        Recorded iterations and unrecorded warm-up iterations per function.
    size : str
        Input shape as a bracketed, comma-separated list, e.g. ``"(16,10,256,256)"``.
    baseline : bool
        Currently unused; kept for backward compatibility.
    types : str or torch.dtype
        ``'all'`` to test float16, bfloat16, float32 and float64. Otherwise a single dtype,
        given either as a ``torch.dtype`` or as its name in the ``torch`` namespace
        (e.g. ``'float32'``).

    Calls ``sys.exit`` with a message when ``types`` or ``size`` is invalid.
    """
    if types == 'all':
        dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
    elif isinstance(types, torch.dtype):
        dtypes = [types]
    elif isinstance(types, str) and isinstance(getattr(torch, types, None), torch.dtype):
        dtypes = [getattr(torch, types)]
    else:
        exit(f"Invalid data type, expected torch type or 'all', got {types}")
    dev = torch.device(type=device)
    sz_str = size.replace(' ','')
    # fullmatch (not match): reject trailing junk such as "(1,2)x" before int() parsing.
    if not re.fullmatch(r"[\(\[]\d+(,\d+)*[\)\]]", sz_str):
        exit("Badly formatted size, should be a list or tuple such as \"(1,2,3)\".")
    sz = [int(d) for d in sz_str[1:-1].split(',')]

    funcs = dict(relu = torch.nn.functional.relu,
                 leaky_relu = torch.nn.functional.leaky_relu,
                 softplus = torch.nn.functional.softplus,
                 silu_jit = fastai.layers.swish,
                 silu_native = torch.nn.functional.silu,
                 mish_naive = mish_pt,
                 mish_jit = fastai.layers.mish,
                 mish_cuda = MishCudaFunction.apply)
    on_cuda = dev.type == 'cuda'
    if not on_cuda:
        funcs.pop('mish_cuda')
    # Previously the non-CUDA path never assigned any timings and failed with a NameError.
    run = profile_cuda if on_cuda else _profile_cpu
    max_name = max(map(len, funcs)) + 6

    print(f"Profiling over {n_repeat} runs after {warmup} warmup runs.")
    for dtype in dtypes:
        if len(dtypes) > 1:
            print(f"Testing on {dtype}:")
            ind = ' '
        else:
            ind = ''
        inp = torch.randn(*sz, dtype=dtype, device=dev)
        for name, func in funcs.items():
            if name == 'mish_cuda' and dtype == torch.bfloat16:
                # This combination is skipped. Report it explicitly rather than re-printing the
                # previous function's timings under this name.
                # NOTE(review): presumably mish-cuda lacks a bfloat16 kernel — confirm.
                print(ind + (name + ':').ljust(max_name) + f"skipped for {dtype}")
                continue
            fwd_times, bwd_times = run(func, inp, n_repeat, warmup)
            print(ind+(name+'_fwd:').ljust(max_name) + display_times(fwd_times))
            print(ind+(name+'_bwd:').ljust(max_name) + display_times(bwd_times))
            if on_cuda:
                torch.cuda.empty_cache()

In [None]:
# Benchmark all activations on the GPU with the default input shape (16,10,256,256).
profile(device='cuda')

Profiling over 100 runs after 10 warmup runs.
Testing on torch.float16:
 relu_fwd:        96.16µs ± 2.513µs, 94.21µs, 107.5µs
 relu_bwd:        174.2µs ± 10.19µs, 159.7µs, 217.1µs
 leaky_relu_fwd:  99.91µs ± 23.17µs, 88.06µs, 293.9µs
 leaky_relu_bwd:  248.1µs ± 421.7µs, 155.6µs, 3.763ms
 softplus_fwd:    99.59µs ± 3.792µs, 96.26µs, 120.8µs
 softplus_bwd:    192.7µs ± 179.1µs, 158.7µs, 1.964ms
 silu_jit_fwd:    110.8µs ± 4.060µs, 106.5µs, 126.0µs
 silu_jit_bwd:    208.9µs ± 21.59µs, 186.4µs, 347.1µs
 silu_native_fwd: 94.23µs ± 4.252µs, 91.14µs, 111.6µs
 silu_native_bwd: 173.7µs ± 21.12µs, 154.6µs, 303.1µs
 mish_naive_fwd:  241.1µs ± 5.324µs, 235.5µs, 267.3µs
 mish_naive_bwd:  504.7µs ± 359.3µs, 463.9µs, 4.080ms
 mish_jit_fwd:    195.0µs ± 6.865µs, 189.4µs, 244.7µs
 mish_jit_bwd:    261.5µs ± 9.261µs, 254.0µs, 294.9µs
 mish_cuda_fwd:   118.8µs ± 4.640µs, 115.7µs, 149.5µs
 mish_cuda_bwd:   205.0µs ± 75.83µs, 173.1µs, 840.7µs
Testing on torch.bfloat16:
 relu_fwd:        85.61µs ± 3.666µs, 

In [None]:
# Same benchmark with a 4x larger batch, to see how each activation scales with input size.
profile(device='cuda', size='(64,10,256,256)')

Profiling over 100 runs after 10 warmup runs.
Testing on torch.float16:
 relu_fwd:        261.8µs ± 3.682µs, 258.0µs, 280.6µs
 relu_bwd:        540.9µs ± 707.3ns, 539.6µs, 542.7µs
 leaky_relu_fwd:  264.8µs ± 9.853µs, 259.1µs, 322.6µs
 leaky_relu_bwd:  584.7µs ± 226.3µs, 539.6µs, 2.373ms
 softplus_fwd:    301.7µs ± 8.958µs, 293.9µs, 343.0µs
 softplus_bwd:    583.2µs ± 173.1µs, 547.8µs, 2.006ms
 silu_jit_fwd:    585.6µs ± 7.832µs, 577.5µs, 624.6µs
 silu_jit_bwd:    2.127ms ± 304.0µs, 2.079ms, 4.933ms
 silu_native_fwd: 264.4µs ± 3.303µs, 261.1µs, 278.5µs
 silu_native_bwd: 543.1µs ± 955.3ns, 541.7µs, 549.9µs
 mish_naive_fwd:  839.8µs ± 12.17µs, 832.5µs, 949.2µs
 mish_naive_bwd:  1.756ms ± 20.98µs, 1.751ms, 1.964ms
 mish_jit_fwd:    852.6µs ± 7.428µs, 844.8µs, 892.9µs
 mish_jit_bwd:    2.904ms ± 145.6µs, 2.887ms, 4.352ms
 mish_cuda_fwd:   371.7µs ± 3.792µs, 368.6µs, 389.1µs
 mish_cuda_bwd:   647.7µs ± 3.224µs, 644.1µs, 655.4µs
Testing on torch.bfloat16:
 relu_fwd:        265.0µs ± 31.70µs, 