In [20]:
import torch
import gc

from torch import nn
from torch.nn.utils import weight_norm

import timeit

torch.cuda.empty_cache()
gc.collect()

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


conv1d = WNConv1d(128, 128, 7, dilation=1, padding=3).to('mps')

torch.manual_seed(0)
data = torch.randn(1, 128, 220500).to('mps')

# time
execution_time = timeit.timeit(lambda: conv1d(data), number=1000)

print(f"Average execution time: {execution_time / 1000:.6f} seconds per call")

Average execution time: 0.003750 seconds per call


In [21]:
import tvm

from tvm import relax
from tvm.relax.frontend import nn

from typing import Optional
from tvm import te
from tvm import dlight as dl
from tvm.target import Target
import numpy as np

from mlc_dac.layers import WNConv1d

conv1d = WNConv1d(128, 128, 7, dilation=1, padding=3)
mod, params = conv1d.export_tvm(
    {"forward": {"x": nn.spec.Tensor((1, 128, 220500), "float32")}}
)

target = Target.from_device("metal")
seq = tvm.transform.Sequential(
    [
        tvm.relax.transform.LegalizeOps(),
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FoldConstant(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
        dl.ApplyDefaultSchedule(
            dl.gpu.Matmul(),
            dl.gpu.GEMV(),
            dl.gpu.Reduction(),
            dl.gpu.GeneralReduction(),
            dl.gpu.Fallback(),
        ),
    ]
)
with target:
    mod = seq(mod)

mod.show()

In [24]:
ex = relax.build(mod, target)
device = tvm.metal()

np.random.seed(0)
vm = relax.VirtualMachine(ex, device, profile=True)
tvm_data = tvm.nd.array(data.cpu(), device=device)
tvm_params = [np.random.randn(*param.shape).astype("float32") for _, param in params]
tvm_params = [tvm.nd.array(param, device=device) for param in tvm_params]

# output_tvm = vm["forward"](tvm_data, *tvm_params)
# output_tvm = output_tvm.asnumpy()

# output_tvm

report = vm.profile("forward", tvm_data, *tvm_params)
csv = report.csv()

with open("profile_conv1d.csv", "w", encoding="utf-8") as f:
    f.write(csv)
    print("Profile saved to profile_conv1d.csv")


Profile saved to profile_conv1d.csv


In [26]:
vm_eval = relax.VirtualMachine(ex, device)
timing_res = vm_eval.time_evaluator("forward", device, number=3, repeat=10, min_repeat_ms=100)(tvm_data, *tvm_params)
print(timing_res)

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  11.1271      11.1268      11.1406      11.1159       0.0076                  
