In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

torch.manual_seed(0)

# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)

In [2]:
import torch
import torch.nn as nn
import time

data = torch.randn(1000, 1000, 1000)

snake_activation = Snake1d(channels=1000)

start = time.time()
output = snake_activation(data)
print("Elapsed time: ", time.time() - start)

print(output.shape)
print(output)

Elapsed time:  1.5504400730133057
torch.Size([1000, 1000, 1000])
tensor([[[-0.3111, -0.3175, -0.1891,  ...,  0.2749,  1.0751,  0.7834],
         [ 0.4745,  0.8813,  0.0383,  ..., -0.2348,  0.0898, -0.2729],
         [-0.0769,  2.2389,  1.1538,  ...,  1.1998, -0.2854, -0.1621],
         ...,
         [ 2.6884,  2.6729,  0.1439,  ...,  2.7138,  0.8706,  1.7031],
         [-0.2212, -0.2856, -0.2735,  ..., -0.1797,  0.8550,  1.0921],
         [-0.2160, -0.1768, -0.1945,  ..., -0.2739, -0.0183, -0.0271]],

        [[ 0.2260,  0.7894, -2.3582,  ...,  1.7809, -0.2740,  0.2679],
         [ 0.1190,  1.1279,  0.2607,  ...,  1.4119,  1.3101,  0.1124],
         [ 1.3305,  2.7944, -0.2990,  ..., -0.2374, -0.3410, -0.2848],
         ...,
         [-0.2783,  0.5186,  1.4522,  ..., -0.1043,  1.9434,  2.7862],
         [-0.2862,  2.7186,  2.6963,  ...,  1.3948, -0.2653, -0.2763],
         [ 0.5810, -0.2649,  0.1593,  ..., -0.2854, -0.2772, -0.2845]],

        [[ 2.0975, -0.2862, -0.2888,  ..., -0.2949,

In [4]:
from typing import Optional
import tvm
from tvm import relax
from tvm.relax.frontend import nn
from tvm import te
from tvm import dlight as dl
from tvm.target import Target
import numpy as np

class Snake1d(nn.Module):
    def __init__(self, channels, dtype: Optional[str] = None):
        super().__init__()
        self.alpha = nn.Parameter((1, channels, 1), dtype)

    def forward(self, x: nn.Tensor):
        shape_x = x.shape
        x = nn.op.reshape(x, (shape_x[0], shape_x[1], -1))

        b, c, w = x.shape

        x = nn.op.tensor_expr_op(
            lambda x, alpha: te.compute(
                (b, c, w),
                lambda i, j, k: x[i, j, k] + 1 / (alpha[0, j, 0] + 1e-9) * te.power(te.sin(alpha[0, j, 0] * x[i, j, k]), 2),
                name="snake_compute",
            ),
            "snake",
            args=[x, self.alpha],
        )

        x = nn.op.reshape(x, shape_x)
        return x


mod_from_relax, params_from_relax = Snake1d(1000).export_tvm(
    {"forward": {"x": nn.spec.Tensor((1000, 1000, 1000), "float32")}}, debug=True
)
mod_from_relax.show()
print(params_from_relax)

[('alpha', Tensor([1, 1000, 1], "float32"))]


In [5]:
target = Target.from_device("metal")
print(target)
with target:
    mod = relax.transform.LegalizeOps()(mod_from_relax)
    mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(mod)
ex = relax.build(mod, target)
device = tvm.metal()

vm = relax.VirtualMachine(ex, device)
tvm_data = tvm.nd.array(data, device=device)
params = [np.ones(param.shape).astype("float32") for _, param in params_from_relax]
params = [tvm.nd.array(param, device=device) for param in params]

start = time.time()
effects = vm["_initialize_effect"]()
output_tvm = vm["forward"](tvm_data, *effects, *params)[0].numpy()
print("Elapsed time: ", time.time() - start)

print(output_tvm.shape)
print(output_tvm)

metal -keys=metal,gpu -max_function_args=31 -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -thread_warp_size=32
Elapsed time:  0.7192351818084717
(1000, 1000, 1000)
[[[-0.31110036 -0.3174659  -0.18909222 ...  0.274942    1.0750501
    0.7834442 ]
  [ 0.47453836  0.8813273   0.03826318 ... -0.23482114  0.08977266
   -0.27293307]
  [-0.07691161  2.2388506   1.1538075  ...  1.1998394  -0.28540426
   -0.16207218]
  ...
  [ 2.6884463   2.6729436   0.14385073 ...  2.7138073   0.87058294
    1.7030883 ]
  [-0.22119421 -0.28558815 -0.2735336  ... -0.17965236  0.8549696
    1.0920608 ]
  [-0.21595624 -0.17683856 -0.19451383 ... -0.2739008  -0.01827374
   -0.02706373]]

 [[ 0.22597633  0.7893648  -2.3581793  ...  1.780864   -0.2739746
    0.2679049 ]
  [ 0.11902615  1.1279249   0.26066315 ...  1.4119387   1.3100795
    0.11241668]
  [ 1.3304498   2.7944255  -0.2990268  ... -0.23743066 -0.34098005
   -0.28483236]
  ...
  [-0.2782858   0.51855546  1.4522367  ..

In [5]:
np.testing.assert_allclose(output.detach().numpy(), output_tvm, atol=1e-5, rtol=1e-5)

In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    conv1d = nn.Conv1d(*args, **kwargs)
    print("original shape: ", conv1d.weight.shape)
    return weight_norm(conv1d)

In [7]:
wn_conv1d = WNConv1d(5, 10, 10)

torch.manual_seed(0)

weight_g_data = torch.randn(10, 1, 1)
weight_v_data = torch.randn(10, 5, 10)
bias_data = torch.randn(10)
wn_conv1d.weight_g.data = weight_g_data
wn_conv1d.weight_v.data = weight_v_data
wn_conv1d.bias.data = bias_data

print(wn_conv1d.weight_g.shape, wn_conv1d.weight_v.shape)
print(wn_conv1d.bias.shape)
data = torch.randn(20, 5, 1000)
out = wn_conv1d(data)
print(out.shape)
print(out)

original shape:  torch.Size([10, 5, 10])
torch.Size([10, 1, 1]) torch.Size([10, 5, 10])
torch.Size([10])
torch.Size([20, 10, 991])
tensor([[[ 2.1434e+00, -5.5163e-01,  1.8618e+00,  ...,  5.8507e-01,
           1.2631e+00,  3.9047e+00],
         [ 2.1691e-01,  1.2649e-01, -3.3990e-01,  ...,  2.6183e-02,
          -1.6413e-01, -7.1865e-01],
         [-8.9280e-02,  1.3261e-01, -1.6264e+00,  ...,  2.2662e-02,
           1.9311e+00, -5.4569e-01],
         ...,
         [ 4.1833e-02, -3.0315e-02, -1.4764e+00,  ..., -3.6672e-01,
          -1.7967e+00, -5.0222e-01],
         [ 1.7244e+00,  2.2790e+00,  4.6369e-01,  ...,  1.3250e+00,
           8.8640e-01,  1.9802e+00],
         [-5.5861e-01, -3.4076e-01, -6.6874e-01,  ..., -5.5540e-01,
          -1.1855e+00, -4.3219e-01]],

        [[ 3.9408e+00,  1.7900e+00,  1.3604e+00,  ...,  9.8031e-01,
           3.2566e-01,  3.1830e+00],
         [-3.9300e-01, -6.3680e-01, -4.0502e-01,  ..., -2.1011e-01,
           1.8975e-02, -1.6795e-01],
         [ 1.

  WeightNorm.apply(module, name, dim)


In [7]:
from typing import Optional
import tvm
from tvm import relax
from tvm.relax.frontend import nn
from tvm.relax import expr as _expr
from tvm.relax import op as _op
from tvm import te
from tvm import dlight as dl
from tvm.target import Target
import numpy as np
import time

class WNConv1d(nn.Module):
    """
    Module for conv1d layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        dtype: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.dtype = dtype

        self.weight_g = nn.Parameter(
            (
                self.out_channels,
                1,
                1,
            ),
            dtype,
        )

        self.weight_v = nn.Parameter(
            (
                self.out_channels,
                self.in_channels // self.groups,
                self.kernel_size,
            ),
            dtype,
        )
        if bias:
            self.bias = nn.Parameter((self.out_channels,), dtype)
        else:
            self.bias = None

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        """
        Forward method for conv1d layer.

        Parameters
        ----------
        x : Tensor
            The input tensor.

        Returns
        -------
        ret : Tensor
            The output tensor for the conv1d layer.
        """
        dim = [i for i in range(1, x.ndim)]
        norm_v = _op.sqrt(
            _op.sum(_op.square(self.weight_v._expr), axis=dim, keepdims=True),
        )
        weight = nn.wrap_nested(
            self.weight_g._expr * (self.weight_v._expr / norm_v), name="wnconv1d"
        )
        return nn.op.conv1d(
            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

mod_from_relax, params_from_relax = WNConv1d(5, 10, 10).export_tvm(
    {"forward": {"x": nn.spec.Tensor((20, 5, 1000), "float32")}}, debug=True
)
mod_from_relax.show()
print(params_from_relax)

[('weight_g', Tensor([10, 1, 1], "float32")), ('weight_v', Tensor([10, 5, 10], "float32")), ('bias', Tensor([10], "float32"))]


In [14]:
target = Target.from_device("metal")
print(target)
with target:
    mod = relax.transform.LegalizeOps()(mod_from_relax)
    mod = dl.ApplyDefaultSchedule(
        # dl.gpu.Matmul(),
        # dl.gpu.GEMV(),
        # dl.gpu.Reduction(),
        # dl.gpu.GeneralReduction(),
        # dl.gpu.Fallback(),
    )(mod)
mod.show()
ex = relax.build(mod, target)
print(ex.mod.imported_modules[0].imported_modules[0].get_source())
device = tvm.metal()

vm = relax.VirtualMachine(ex, device)
tvm_data = tvm.nd.array(data, device=device)
print(params_from_relax)
params = [weight_g_data.numpy().astype("float32"), weight_v_data.numpy().astype("float32"), bias_data.numpy().astype("float32")]
params = [tvm.nd.array(param, device=device) for param in params]

start = time.time()
effects = vm["_initialize_effect"]()
output_tvm = vm["forward"](tvm_data, *effects, *params)[0].numpy()
print("Elapsed time: ", time.time() - start)

# print(params)

# print(output_tvm.shape)
# print(output_tvm)

metal -keys=metal,gpu -max_function_args=31 -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -thread_warp_size=32


TVMError: Traceback (most recent call last):
  4: tvm::$_5::operator()(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target) const
        at /Users/cfruan/Documents/tvm-unity/src/driver/driver_api.cc:531
  3: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
        at /Users/cfruan/Documents/tvm-unity/src/driver/driver_api.cc:492
  2: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
        at /Users/cfruan/Documents/tvm-unity/src/driver/driver_api.cc:418
  1: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
        at /Users/cfruan/Documents/tvm-unity/src/driver/driver_api.cc:291
  0: tvm::tir::transform::VerifyMemory()::$_0::operator()(tvm::IRModule, tvm::transform::PassContext) const
        at /Users/cfruan/Documents/tvm-unity/src/tir/analysis/verify_memory.cc:205
  Did you forget to bind?
    Variable `lv2` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/Users/cfruan/Documents/tvm-unity/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def tir_sqrt(lv2: T.Buffer((T.int64(10), T.int64(1), T.int64(1)), "float32"), compute: T.Buffer((T.int64(10), T.int64(1), T.int64(1)), "float32")):
    T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "arm64-apple-darwin22.5.0", "tag": ""}, "keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
    for i0 in range(10):
        compute_1 = T.Buffer((T.int64(10),), data=compute.data)
        lv2_1 = T.Buffer((T.int64(10),), data=lv2.data)
        compute_1[i0] = T.sqrt(lv2_1[i0])

In [10]:
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.ModuleList(
            [
                Snake1d(dim),
                WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
                Snake1d(dim),
                WNConv1d(dim, dim, kernel_size=1),
            ]
        )

    def forward(self, x: nn.Tensor):
        residual = x
        for layer in self.block:
            x = layer(x)
        return x + residual
    
mod_from_relax, params_from_relax = ResidualUnit().export_tvm(
    {"forward": {"x": nn.spec.Tensor((32, 16, 1000), "float32")}}, debug=True
)
mod_from_relax.show()
print(params_from_relax)

[('block.0.alpha', Tensor([1, 16, 1], "float32")), ('block.1.weight_g', Tensor([16, 1, 1], "float32")), ('block.1.weight_v', Tensor([16, 16, 7], "float32")), ('block.1.bias', Tensor([16], "float32")), ('block.2.alpha', Tensor([1, 16, 1], "float32")), ('block.3.weight_g', Tensor([16, 1, 1], "float32")), ('block.3.weight_v', Tensor([16, 16, 1], "float32")), ('block.3.bias', Tensor([16], "float32"))]


In [11]:
import math

class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        self.block = nn.ModuleList(
            [
                ResidualUnit(dim // 2, dilation=1),
                ResidualUnit(dim // 2, dilation=3),
                ResidualUnit(dim // 2, dilation=9),
                Snake1d(dim // 2),
                WNConv1d(
                    dim // 2,
                    dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    padding=math.ceil(stride / 2),
                ),
            ]
        )

    def forward(self, x):
        for layer in self.block:
            x = layer(x)
        return x
    
mod_from_relax, params_from_relax = EncoderBlock().export_tvm(
    {"forward": {"x": nn.spec.Tensor((32, 8, 1000), "float32")}}, debug=True
)
mod_from_relax.show()
print(params_from_relax)

[('block.0.block.0.alpha', Tensor([1, 8, 1], "float32")), ('block.0.block.1.weight_g', Tensor([8, 1, 1], "float32")), ('block.0.block.1.weight_v', Tensor([8, 8, 7], "float32")), ('block.0.block.1.bias', Tensor([8], "float32")), ('block.0.block.2.alpha', Tensor([1, 8, 1], "float32")), ('block.0.block.3.weight_g', Tensor([8, 1, 1], "float32")), ('block.0.block.3.weight_v', Tensor([8, 8, 1], "float32")), ('block.0.block.3.bias', Tensor([8], "float32")), ('block.1.block.0.alpha', Tensor([1, 8, 1], "float32")), ('block.1.block.1.weight_g', Tensor([8, 1, 1], "float32")), ('block.1.block.1.weight_v', Tensor([8, 8, 7], "float32")), ('block.1.block.1.bias', Tensor([8], "float32")), ('block.1.block.2.alpha', Tensor([1, 8, 1], "float32")), ('block.1.block.3.weight_g', Tensor([8, 1, 1], "float32")), ('block.1.block.3.weight_v', Tensor([8, 8, 1], "float32")), ('block.1.block.3.bias', Tensor([8], "float32")), ('block.2.block.0.alpha', Tensor([1, 8, 1], "float32")), ('block.2.block.1.weight_g', Tenso

In [12]:
class Encoder(nn.Module):
    def __init__(self,         
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):

        self.block = [
            WNConv1d(1, d_model, kernel_size=7, padding=3)
        ]

        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap black into nn.Sequential
        self.block = nn.ModuleList(self.block)

    def forward(self, x):
        for layer in self.block:
            x = layer(x)
        return x

In [13]:
from typing import List, Union


class DAC(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
    ):
        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        return x

    
mod_from_relax, params_from_relax = DAC().export_tvm(
    {"forward": {"x": nn.spec.Tensor((32, 1, 100000), "float32")}}, debug=True
)
mod_from_relax.show()
print(params_from_relax)

[('encoder.block.0.weight_g', Tensor([64, 1, 1], "float32")), ('encoder.block.0.weight_v', Tensor([64, 1, 7], "float32")), ('encoder.block.0.bias', Tensor([64], "float32")), ('encoder.block.1.block.0.block.0.alpha', Tensor([1, 64, 1], "float32")), ('encoder.block.1.block.0.block.1.weight_g', Tensor([64, 1, 1], "float32")), ('encoder.block.1.block.0.block.1.weight_v', Tensor([64, 64, 7], "float32")), ('encoder.block.1.block.0.block.1.bias', Tensor([64], "float32")), ('encoder.block.1.block.0.block.2.alpha', Tensor([1, 64, 1], "float32")), ('encoder.block.1.block.0.block.3.weight_g', Tensor([64, 1, 1], "float32")), ('encoder.block.1.block.0.block.3.weight_v', Tensor([64, 64, 1], "float32")), ('encoder.block.1.block.0.block.3.bias', Tensor([64], "float32")), ('encoder.block.1.block.1.block.0.alpha', Tensor([1, 64, 1], "float32")), ('encoder.block.1.block.1.block.1.weight_g', Tensor([64, 1, 1], "float32")), ('encoder.block.1.block.1.block.1.weight_v', Tensor([64, 64, 7], "float32")), ('enc

In [14]:
import torch

file_path = "weights.pth"
state_dict = torch.load(file_path, map_location=torch.device('cpu'))["state_dict"]

# Check if the loaded data is a state dictionary
if isinstance(state_dict, dict):
    # Iterate over the state dictionary to print each weight name and its shape
    print("Weights and their corresponding shapes:")
    for name, tensor in state_dict.items():
        # Check if the item is a tensor to print its shape
        if isinstance(tensor, torch.Tensor):
            print(f"{name}: {tensor.shape}")
        else:
            print(f"{name}: Not a tensor, possibly another data type.")
else:
    print("Loaded file does not contain a state dictionary.")

Weights and their corresponding shapes:
encoder.block.0.bias: torch.Size([64])
encoder.block.0.weight_g: torch.Size([64, 1, 1])
encoder.block.0.weight_v: torch.Size([64, 1, 7])
encoder.block.1.block.0.block.0.alpha: torch.Size([1, 64, 1])
encoder.block.1.block.0.block.1.bias: torch.Size([64])
encoder.block.1.block.0.block.1.weight_g: torch.Size([64, 1, 1])
encoder.block.1.block.0.block.1.weight_v: torch.Size([64, 64, 7])
encoder.block.1.block.0.block.2.alpha: torch.Size([1, 64, 1])
encoder.block.1.block.0.block.3.bias: torch.Size([64])
encoder.block.1.block.0.block.3.weight_g: torch.Size([64, 1, 1])
encoder.block.1.block.0.block.3.weight_v: torch.Size([64, 64, 1])
encoder.block.1.block.1.block.0.alpha: torch.Size([1, 64, 1])
encoder.block.1.block.1.block.1.bias: torch.Size([64])
encoder.block.1.block.1.block.1.weight_g: torch.Size([64, 1, 1])
encoder.block.1.block.1.block.1.weight_v: torch.Size([64, 64, 7])
encoder.block.1.block.1.block.2.alpha: torch.Size([1, 64, 1])
encoder.block.1.b

  state_dict = torch.load(file_path, map_location=torch.device('cpu'))["state_dict"]
