In [2]:
from mlc_dac.dac import DAC

import tvm
from tvm import relax
from tvm.relax.frontend.torch import dynamo_capture_subgraphs
from tvm.relax.frontend.torch import from_fx
from tvm.script import relax as R

import torch

# Build the DAC model and export it to a TVM IRModule plus its named parameters.
# NOTE(review): the meaning of 512 is not visible here; the benchmark cell feeds
# (1, 1, 512) audio, so it may be a chunk/sample length — confirm against mlc_dac.
dac = DAC(512)
mod, named_params = dac.export_tvm(
    dac.get_default_spec(),
    debug=True,
)

In [3]:
def print_relax_funcnames(mod: tvm.IRModule):
    """Print the name of every Relax function in ``mod``, one per line.

    Non-Relax functions (e.g. TIR PrimFuncs) are skipped. A blank line is
    printed after the list as a separator.

    Args:
        mod: The IRModule to inspect.
    """
    relax_names = [
        gv.name_hint
        for gv, fn in mod.functions.items()
        if isinstance(fn, relax.Function)
    ]
    for name in relax_names:
        print(name)
    print()


print_relax_funcnames(mod)

_initialize_effect
decode
encode



In [5]:
import pickle

import tvm.dlight as dl

# Snapshot the unscheduled module, both as readable TVMScript and as a pickle.
# `with` blocks guarantee the files are flushed and closed.
with open("../dist/before_scheduling.py", "w") as script_file:
    print(mod.script(show_meta=True), file=script_file)

with open("../dist/before_scheduling.pkl", "wb") as pkl_file:
    pickle.dump(mod, pkl_file)

target = tvm.target.Target("apple/m2-gpu")

# Lower with the default Relax pipeline, then give the TIR kernels GPU schedules.
# dlight's ApplyDefaultSchedule is expected to try these rules in the order
# listed, with Fallback as the catch-all — see its docs.
# NOTE: `mod` is rebound to the lowered module, so re-running this cell without
# re-running the export cell would lower an already-lowered module.
with target:
    seq = tvm.transform.Sequential(
        [
            relax.get_pipeline(),
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )
    mod = seq(mod)

ex = relax.build(mod, target=target)
ex.export_library("../dist/deploy-untuned-metal.so")

[22:22:22] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`
[22:22:35] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`


In [4]:
import pickle

# NOTE(review): no cell visible in this notebook writes after_scheduling.pkl
# (the scheduling cell only saves before_scheduling.pkl), so this depends on an
# artifact produced elsewhere — confirm its provenance before re-running.
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open("../dist/after_scheduling.pkl", "rb") as pkl_file:
    mod = pickle.load(pkl_file)

with open("../dist/after_scheduling.py", "w") as script_file:
    print(mod.script(show_meta=True), file=script_file)

In [5]:
# List the Relax functions in the loaded module. Each `<name>_transform_params`
# function is annotated with the entry function `<name>` it computes weights for.
TRANSFORM_SUFFIX = "_transform_params"

for global_var, function in mod.functions.items():
    if not isinstance(function, relax.Function):
        continue
    func_name = global_var.name_hint
    if func_name.endswith(TRANSFORM_SUFFIX):
        # Derive the entry name from the suffix length instead of a magic `17`.
        entry_name = func_name[: -len(TRANSFORM_SUFFIX)]
        print(
            func_name,
            f' # <=== This is the weight parameter computation function for "{entry_name}"',
        )
    else:
        print(func_name)

_initialize_effect
decode
encode_transform_params  # <=== This is the weight parameter computation function for "encode"
decode_transform_params  # <=== This is the weight parameter computation function for "decode"
encode


In [6]:
from typing import Dict, List, Tuple

def split_transform_deploy_mod(
    mod: tvm.IRModule, model_names: List[str]
) -> Tuple[tvm.IRModule, tvm.IRModule]:
    """Split ``mod`` into a weight-transform module and a deploy module.

    The function is routed as follows:

    * Every TIR PrimFunc is copied into both modules.
    * A function named ``<model>_transform_params`` (for a model in
      ``model_names``) goes only to the transform module.
    * Every other function goes only to the deploy module.

    ``relax.transform.DeadCodeElimination`` is then run on each module. It is
    seeded with the ``*_transform_params`` names for the transform module and
    with ``model_names`` for the deploy module.

    Args:
        mod: Module containing the model entry functions and their
            ``*_transform_params`` counterparts.
        model_names: Names of the model entry functions, e.g. ``["encode", "decode"]``.

    Returns:
        ``(mod_transform, mod_deploy)``.
    """
    transform_entries = [f"{model}_transform_params" for model in model_names]

    transform_mod = tvm.IRModule()
    deploy_mod = tvm.IRModule()
    for gvar, fn in mod.functions.items():
        is_prim = isinstance(fn, tvm.tir.PrimFunc)
        if is_prim or gvar.name_hint in transform_entries:
            transform_mod[gvar] = fn
        if is_prim or gvar.name_hint not in transform_entries:
            deploy_mod[gvar] = fn

    transform_mod = relax.transform.DeadCodeElimination(transform_entries)(transform_mod)
    deploy_mod = relax.transform.DeadCodeElimination(model_names)(deploy_mod)
    return transform_mod, deploy_mod


model_names = ["encode", "decode"]

mod_transform, mod_deploy = split_transform_deploy_mod(mod, model_names)

In [7]:
# Show which Relax functions ended up in each half of the split.
for header, stage_mod in (
    ("In IRModule for build stage:", mod_transform),
    ("In IRModule for deployment stage:", mod_deploy),
):
    print(header)
    print_relax_funcnames(stage_mod)

In IRModule for build stage:
encode_transform_params
decode_transform_params

In IRModule for deployment stage:
_initialize_effect
decode
encode



In [13]:
import pickle

# Persist the deploy module as readable TVMScript, and as a pickle that the
# deploy build cell reloads. `with` blocks guarantee the files are closed.
with open("../dist/deploy.py", "w") as script_file:
    print(mod_deploy.script(show_meta=True), file=script_file)

with open("../dist/deploy.pkl", "wb") as pkl_file:
    pickle.dump(mod_deploy, pkl_file)

In [None]:
from tvm.runtime import Device
from tvm.relax.frontend.nn import Parameter
from tvm.contrib import tvmjs
import tvm.dlight as dl


def load_params(
    model_weight_path: str,
    device: Device,
    named_params: List[Tuple[str, Parameter]],
    rename_rules: Tuple[Tuple[str, str], ...] = ((".layers.", "."), (".branches.0.", ".")),
):
    """Load weights from an ndarray cache, ordered to match ``named_params``.

    Each exported parameter name is rewritten with ``rename_rules`` (applied in
    order as ``str.replace`` substitutions) before it is looked up in the cache.
    The rewrite exists because the exported names apparently differ from the
    cache keys.

    Args:
        model_weight_path: Path passed to ``tvmjs.load_ndarray_cache``.
        device: Device the arrays are loaded onto.
        named_params: ``(name, Parameter)`` pairs, e.g. from ``export_tvm``. Only
            the names are used, and their order defines the output order.
        rename_rules: ``(old, new)`` substring replacements that map an exported
            name to its cache key. Defaults to the original hard-coded rules.

    Returns:
        A list of arrays, one per entry of ``named_params``, in the same order.

    Raises:
        KeyError: If a remapped name is missing from the cache. The message
            names both the exported and the remapped key.
    """
    params, _ = tvmjs.load_ndarray_cache(model_weight_path, device)

    plist = []
    for exported_name, _param in named_params:
        cache_key = exported_name
        for old, new in rename_rules:
            cache_key = cache_key.replace(old, new)
        if cache_key not in params:
            raise KeyError(
                f"parameter {exported_name!r} (looked up as {cache_key!r}) "
                f"not found in weight cache at {model_weight_path!r}"
            )
        plist.append(params[cache_key])
    return plist


device = tvm.metal()
params = load_params("../weights", device, named_params)


def transform_params(
    mod_transform: tvm.IRModule,
    model_params,
    target: str = "apple/m2-gpu",
    device=None,
) -> Dict[str, List[tvm.nd.NDArray]]:
    """Build ``mod_transform`` and run each model's ``<name>_transform_params``.

    Args:
        mod_transform: Module holding the ``*_transform_params`` functions, e.g.
            the first result of ``split_transform_deploy_mod``.
        model_params: Mapping from model name (such as ``"encode"``) to that
            model's raw parameter list.
        target: Target the transform functions are compiled for. Defaults to
            the previously hard-coded ``"apple/m2-gpu"``.
        device: Device to run on. ``None`` means ``tvm.metal()``, as before.

    Returns:
        Mapping from model name to the result of its ``*_transform_params`` call.
    """
    tvm_target = tvm.target.Target(target)
    if device is None:
        device = tvm.metal()

    # Only a fallback GPU schedule is applied. Presumably this is enough because
    # the transform runs once, ahead of deployment.
    with tvm_target:
        mod_transform = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod_transform)
    ex = relax.build(mod_transform, target=tvm_target)
    vm = relax.vm.VirtualMachine(ex, device)

    transformed = {}
    # `raw_params` is deliberately not named `params`, to avoid shadowing the
    # notebook-level `params`.
    for model_name, raw_params in model_params.items():
        # The parameter list is wrapped in a list, presumably because the
        # transform function takes all params as a single tuple argument.
        transformed[model_name] = vm[f"{model_name}_transform_params"]([raw_params])
    return transformed


new_params = transform_params(mod_transform, {"encode": params, "decode": params})

[16:38:19] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`
[16:38:19] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`
[16:38:20] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`


In [12]:
def save_params(
    params: Dict[str, List[tvm.nd.NDArray]],
    artifact_path: str,
    model_names: Tuple[str, ...] = ("encode", "decode"),
) -> None:
    """Write transformed parameters to ``<artifact_path>/params`` as an ndarray cache.

    Storage layout:

    * Each array is stored under the key ``"<model>_<index>"``.
    * The array count per model is recorded in the metadata as
      ``"<model>ParamSize"``.

    This lets ``load_transformed_params`` rebuild the per-model lists in order.

    Args:
        params: Mapping from model name to its list of transformed arrays.
        artifact_path: Output directory. The cache goes in its ``params``
            subdirectory.
        model_names: Which entries of ``params`` to save; other keys are ignored.
            Defaults to the previously hard-coded ``("encode", "decode")``.

    Raises:
        KeyError: If a name in ``model_names`` is missing from ``params``.
    """
    from tvm.contrib import tvmjs

    meta_data = {}
    param_dict = {}
    for model in model_names:
        model_arrays = params[model]
        meta_data[f"{model}ParamSize"] = len(model_arrays)
        for i, array in enumerate(model_arrays):
            param_dict[f"{model}_{i}"] = array
    tvmjs.dump_ndarray_cache(param_dict, f"{artifact_path}/params", meta_data=meta_data)


save_params(new_params, artifact_path="../dist")

Start storing to cache ../dist/params
[0241] saving decode_88 
All finished, 6 total shards committed, record saved to ../dist/params/ndarray-cache.json
Also saved a bf16 record to ../dist/params/ndarray-cache-b16.json


In [14]:
import pickle

# Reload the deploy module saved earlier, so this cell does not depend on the
# in-memory `mod_deploy` from previous cells.
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open("../dist/deploy.pkl", "rb") as pkl_file:
    mod_deploy = pickle.load(pkl_file)

target = tvm.target.Target("apple/m2-gpu")

# Apply TVM's default GPU schedule pass to the TIR kernels, then build and export.
with target:
    mod_deploy = tvm.tir.transform.DefaultGPUSchedule()(mod_deploy)

ex = relax.build(mod_deploy, target=target)
ex.export_library("../dist/deploy-metal.so")

[16:50:45] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`
[16:50:47] /Users/cfruan/Documents/tvm-unity/src/target/llvm/llvm_instance.cc:226: Error: Using LLVM 19.1.1 with `-mcpu=apple-latest` is not valid in `-mtriple=arm64-apple-macos`, using default `-mcpu=generic`


In [4]:
from typing import Dict, List, Tuple

import numpy as np
import tvm
from tvm import relax
from tvm.contrib import tvmjs


def load_transformed_params(
    artifact_path: str,
    device,
    model_names: Tuple[str, ...] = ("encode", "decode"),
) -> Dict[str, List[tvm.nd.NDArray]]:
    """Load parameters written by ``save_params`` back into per-model lists.

    Reads the ndarray cache in ``<artifact_path>/params``. For each model, the
    ``"<model>ParamSize"`` metadata entry gives the count of arrays stored under
    ``"<model>_0"``, ``"<model>_1"``, ..., which are collected in index order.

    Args:
        artifact_path: Directory that ``save_params`` wrote to.
        device: Device to load the arrays onto.
        model_names: Models to load. Defaults to the previously hard-coded
            ``("encode", "decode")``; keep in sync with ``save_params``.

    Returns:
        Mapping from model name to its ordered list of arrays.

    Raises:
        KeyError: If a model's size metadata or one of its arrays is missing.
    """
    from tvm.contrib import tvmjs

    params, meta = tvmjs.load_ndarray_cache(f"{artifact_path}/params", device)
    return {
        model: [params[f"{model}_{i}"] for i in range(meta[f"{model}ParamSize"])]
        for model in model_names
    }


device = tvm.metal()
const_params_dict = load_transformed_params("../dist", device)
ex = tvm.runtime.load_module("../dist/deploy-metal.so")

vm = relax.vm.VirtualMachine(ex, device, profile=True)
forward_fn = vm["encode"]

In [9]:
# Benchmark and profile the deployed "encode" entry point on a seeded random input.
np.random.seed(0)

# The entry function takes the effect values as arguments and returns updated
# ones, so `effects` is refreshed after the first call below.
effects = vm["_initialize_effect"]()
encode_params = const_params_dict["encode"]

audio_data = tvm.nd.array(
    np.random.randn(1, 1, 512).astype("float32"),
    device=device,
)
res, effects = forward_fn(audio_data, *effects, *encode_params)

# number = runs averaged per measurement, repeat = number of measurements.
time_eval = vm.time_evaluator("encode", device, number=10, repeat=5)(
    audio_data, *effects, *encode_params
)
print(time_eval)

# Profiling report (the VM was created with profile=True), saved as CSV.
report = vm.profile("encode", audio_data, *effects, *encode_params)
csv = report.csv()

with open("profile_stream.csv", "w", encoding="utf-8") as profile_file:
    profile_file.write(csv)
    print("Profile saved to profile_stream.csv")

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
  13.2346      12.9692      14.0863      12.9023       0.4425                  
Profile saved to profile_stream.csv
