In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    """Apply the Snake activation ``x + sin(alpha * x)**2 / (alpha + 1e-9)``.

    Args:
        x: Input tensor with at least two dimensions. Every dimension after
            the second is flattened into one for the computation.
        alpha: Tensor that broadcasts against the flattened
            ``(x.shape[0], x.shape[1], -1)`` view of ``x``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    original_shape = x.shape
    # Collapse all trailing dimensions so alpha broadcasts over a 3-D view.
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    periodic = torch.sin(alpha * flat).pow(2)
    # The 1e-9 offset keeps the reciprocal finite when alpha is exactly zero.
    flat = flat + (alpha + 1e-9).reciprocal() * periodic
    return flat.reshape(original_shape)


class Snake1d(nn.Module):
    """Snake activation layer with one learnable ``alpha`` value per channel."""

    def __init__(self, channels):
        """Create the ``alpha`` parameter of shape ``(1, channels, 1)``, filled with ones.

        Args:
            channels: Number of channels, i.e. the size of the middle
                dimension of ``alpha``.
        """
        super().__init__()
        initial_alpha = torch.ones((1, channels, 1))
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        """Return ``snake(x, self.alpha)``; the result has the same shape as ``x``."""
        activated = snake(x, self.alpha)
        return activated

In [2]:
import torch
import torch.nn as nn
import time

# Benchmark input: 1000 x 1000 x 1000 random values (about 4 GB with torch's
# default float32 dtype). The same tensor is fed to the TVM build further down.
data = torch.randn(1000, 1000, 1000)

snake_activation = Snake1d(channels=1000)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# durations; time.time() is the wall clock and can jump or be coarse.
start = time.perf_counter()
# alpha is an nn.Parameter, so without no_grad() autograd would record the
# operation and keep extra references to these multi-gigabyte tensors alive.
with torch.no_grad():
    output = snake_activation(data)
print("Elapsed time: ", time.perf_counter() - start)

print(output.shape)
print(output)

Elapsed time:  1.5212342739105225
torch.Size([1000, 1000, 1000])
tensor([[[-0.2855,  0.9024, -0.5099,  ..., -0.0109,  0.6167,  0.0272],
         [-0.2826, -0.6212, -0.2269,  ..., -0.2854, -0.2887, -2.5758],
         [ 2.5817, -0.1975,  1.8856,  ..., -0.2940,  0.4847,  0.0600],
         ...,
         [ 0.0997,  0.5875, -0.3303,  ..., -0.4338,  1.3538, -0.2609],
         [ 0.0909, -1.0304, -1.7340,  ..., -0.3382, -0.2657, -0.5859],
         [ 0.5640,  1.7057,  1.8762,  ..., -2.4711,  1.9306, -0.2854]],

        [[ 2.2684,  0.8988, -0.2947,  ..., -0.1066, -0.8019,  0.5693],
         [ 0.0729, -1.6316, -0.2045,  ...,  0.0066,  0.0162,  1.5935],
         [-0.2827,  0.1673,  2.4097,  ..., -0.6003,  0.8951, -0.2781],
         ...,
         [ 2.0701,  0.1920, -0.2922,  ..., -0.3188,  1.0702, -0.2519],
         [-0.9260,  1.2002, -0.2854,  ...,  0.8658,  1.2223,  0.8905],
         [-0.3053, -0.3240, -0.3774,  ...,  2.5315,  0.9245,  1.0406]],

        [[-1.4732,  1.0150, -0.2854,  ..., -0.3499,

In [3]:
from typing import Optional
import tvm
from tvm import relax
from tvm.relax.frontend import nn
from tvm import te
from tvm import dlight as dl
from tvm.target import Target
import numpy as np

class Snake1d(nn.Module):
    """Snake activation written against the TVM Relax ``nn`` frontend.

    Computes ``x + 1 / (alpha + 1e-9) * sin(alpha * x) ** 2`` with one
    ``alpha`` value per channel (second dimension of the input).

    NOTE: this class reuses the name ``Snake1d`` (and the name ``nn``) from the
    PyTorch cell at the top of the notebook, so once this cell has run the
    earlier PyTorch definition is shadowed.
    """

    def __init__(self, channels, dtype: Optional[str] = None):
        """Declare the ``alpha`` parameter with shape ``(1, channels, 1)``.

        Args:
            channels: Size of the channel dimension.
            dtype: Passed straight through to ``nn.Parameter``.
        """
        super().__init__()
        alpha_shape = (1, channels, 1)
        self.alpha = nn.Parameter(alpha_shape, dtype)

    def forward(self, x: nn.Tensor):
        """Apply the activation and return a tensor shaped like the input ``x``."""
        original_shape = x.shape
        # Flatten everything after the channel dimension into a single axis.
        x = nn.op.reshape(x, (original_shape[0], original_shape[1], -1))
        batch, chans, width = x.shape

        def snake_te(data, alpha):
            """Build the element-wise tensor expression for the flattened input."""

            def element(i, j, k):
                # alpha is indexed by channel only (j); it broadcasts over i and k.
                a = alpha[0, j, 0]
                v = data[i, j, k]
                return v + 1 / (a + 1e-9) * te.power(te.sin(a * v), 2)

            return te.compute((batch, chans, width), element, name="snake_compute")

        x = nn.op.tensor_expr_op(snake_te, "snake", args=[x, self.alpha])

        # Restore the caller's original shape.
        x = nn.op.reshape(x, original_shape)
        return x


# Export the Relax module for a fixed (1000, 1000, 1000) float32 input.
snake_module = Snake1d(1000)
forward_spec = {"forward": {"x": nn.spec.Tensor((1000, 1000, 1000), "float32")}}
# NOTE(review): debug=True presumably is what adds the "_initialize_effect"
# entry point that the run cell calls later; confirm against the TVM docs.
mod_from_relax, params_from_relax = snake_module.export_tvm(forward_spec, debug=True)

mod_from_relax.show()
print(params_from_relax)

[('alpha', Tensor([1, 1000, 1], "float32"))]


In [4]:
# Build the exported Relax module for the local Metal device.
target = Target.from_device("metal")
print(target)
with target:
    # Lower high-level Relax ops, then apply dlight's default GPU schedules;
    # the rules are listed in the order they are passed to ApplyDefaultSchedule.
    mod = relax.transform.LegalizeOps()(mod_from_relax)
    mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(mod)
ex = relax.build(mod, target)
device = tvm.metal()

vm = relax.VirtualMachine(ex, device)
tvm_data = tvm.nd.array(data, device=device)
# Every parameter is set to ones, matching the torch.ones initialisation of
# alpha in the PyTorch Snake1d, so both implementations see the same weights.
params = [
    tvm.nd.array(np.ones(param.shape, dtype="float32"), device=device)
    for _, param in params_from_relax
]

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# durations. The timed region covers effect initialisation, the forward call
# and the conversion of the result to a NumPy array.
start = time.perf_counter()
effects = vm["_initialize_effect"]()
# NOTE(review): element [0] is presumably the forward output, with the updated
# effects following it; confirm against the debug-mode calling convention.
output_tvm = vm["forward"](tvm_data, *effects, *params)[0].numpy()
print("Elapsed time: ", time.perf_counter() - start)

print(output_tvm.shape)
print(output_tvm)

metal -keys=metal,gpu -max_function_args=31 -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -thread_warp_size=32
Elapsed time:  0.694025993347168
(1000, 1000, 1000)
[[[-0.28546363  0.90242946 -0.509931   ... -0.01088286  0.6167153
    0.02722606]
  [-0.28259847 -0.6212003  -0.22687232 ... -0.28539106 -0.28865224
   -2.57582   ]
  [ 2.5817065  -0.19754422  1.8856481  ... -0.29395473  0.48471504
    0.05997137]
  ...
  [ 0.09973724  0.5875402  -0.33032048 ... -0.4338004   1.3538272
   -0.26093805]
  [ 0.09087335 -1.030388   -1.7340286  ... -0.33820736 -0.2656629
   -0.58592594]
  [ 0.56398153  1.7056551   1.8762392  ... -2.4711318   1.9306073
   -0.28538567]]

 [[ 2.268404    0.89875853 -0.29467398 ... -0.10659833 -0.80191743
    0.5692668 ]
  [ 0.07288542 -1.6315548  -0.20451894 ...  0.00658888  0.01615486
    1.5935036 ]
  [-0.28272206  0.16730908  2.4096553  ... -0.6002601   0.8950604
   -0.27805477]
  ...
  [ 2.0701475   0.19199567 -0.2922448  ... 

In [5]:
# The TVM result must agree with the PyTorch reference within 1e-5 (absolute and relative).
np.testing.assert_allclose(
    actual=output.detach().numpy(),
    desired=output_tvm,
    rtol=1e-5,
    atol=1e-5,
)