Skip to content

Commit

Permalink
[Unity][WEB] Relax vm on web runtime (apache#14131)
Browse files Browse the repository at this point in the history
This PR brings initial relax vm support on web runtime
  • Loading branch information
tqchen committed Feb 25, 2023
1 parent 82578c3 commit 678d01d
Show file tree
Hide file tree
Showing 16 changed files with 825 additions and 52 deletions.
4 changes: 4 additions & 0 deletions include/tvm/runtime/relax_vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#ifndef TVM_RUNTIME_RELAX_VM_VM_H_
#define TVM_RUNTIME_RELAX_VM_VM_H_

// Profiler support defaults to enabled (1). Because of the #ifndef guard, a
// build can override the default by defining this macro before this point.
#ifndef TVM_RELAX_VM_ENABLE_PROFILER
#define TVM_RELAX_VM_ENABLE_PROFILER 1
#endif

#include <memory>
#include <string>
#include <vector>
Expand Down
119 changes: 119 additions & 0 deletions python/tvm/contrib/tvmjs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace to store utilities for building web runtime."""
# pylint: disable=unused-import
import sys
import os
import json
from typing import Mapping, Union

import numpy as np

import tvm
from .emcc import create_tvmjs_wasm


def _convert_f32_to_bf16(value):
    """Encode an array of float32 values as bfloat16 bit patterns.

    Parameters
    ----------
    value: np.ndarray
        Input array; its elements are reinterpreted as 32-bit words, so it is
        expected to hold float32 data.

    Returns
    -------
    result: np.ndarray
        ``uint16`` array containing the upper 16 bits of every (rounded) input.
    """
    f32_max = np.finfo("float32").max
    assert -np.finfo("float32").max == np.finfo("float32").min

    # Clear the low 16 bits of the float32 maximum to get the bf16 limit.
    limit_bits = np.array([f32_max.view("uint32")])
    limit_bits = (limit_bits >> 16) << 16
    bf16_limit = limit_bits.view("float32")[0]

    # Values strictly inside (-bf16_limit, bf16_limit) are rounded to nearest
    # even; doing so at dump time reduces the overall rounding error.
    #
    # Everything else (usually mask values in attention) is truncated, which
    # is equivalent to clipping to the limit values.
    bits = value.view("uint32")
    inside_limit = (value > -bf16_limit) & (value < bf16_limit)
    nearest_even_bias = ((bits >> 16) & 1) + 0x7FFF
    rounding_bias = np.where(inside_limit, nearest_even_bias, np.zeros_like(bits))

    rounded = bits + rounding_bias
    return (rounded >> 16).astype("uint16")


def dump_ndarray_cache(
    params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
    cachedir: str,
    encode_format="f32-to-bf16",
):
    """Dump parameters to NDArray cache.

    Every parameter is written to ``<cachedir>/<name>.bin`` and described by
    one record in ``<cachedir>/ndarray-cache.json``. When at least one array
    was encoded as bf16, a second index ``ndarray-cache-b16.json`` is written
    that describes those same files as raw ``bfloat16`` data.

    Parameters
    ----------
    params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]]
        The parameter dictionary

    cachedir: str
        The path to the cache

    encode_format: {"f32-to-bf16", "raw"}
        Encoding format. With "f32-to-bf16" only float32 arrays are encoded;
        every other array is stored (and recorded) as "raw".
    """
    records = []
    total = len(params)
    max_out_length = 0

    # exist_ok avoids the race between a separate existence check and creation.
    os.makedirs(cachedir, exist_ok=True)

    f32_to_bf16_triggered = False

    print("Start storing to cache %s" % cachedir)
    for counter, (k, v) in enumerate(params.items(), start=1):
        fname = k + ".bin"
        out_path = os.path.join(cachedir, fname)
        shape = list(v.shape)

        if not isinstance(v, np.ndarray):
            v = v.numpy()

        # convert fp32 to bf16
        if encode_format == "f32-to-bf16" and v.dtype == "float32":
            _convert_f32_to_bf16(v).tofile(out_path)
            rec_format = "f32-to-bf16"
            f32_to_bf16_triggered = True
        else:
            v.tofile(out_path)
            # The bytes on disk are unencoded, so the record must say so;
            # labelling them with the requested format would misdescribe them.
            rec_format = "raw"

        # The main index keeps the logical dtype; "format" tells a reader how
        # the stored bytes have to be decoded into that dtype.
        records.append(
            {
                "name": k,
                "shape": shape,
                "dtype": str(v.dtype),
                "dataPath": fname,
                "format": rec_format,
            }
        )

        last_cmd = "[%04d/%04d] saving %s" % (counter, total, out_path)
        # Blank out the previous progress line before overwriting it.
        flush = "\r" + (" " * max_out_length) + "\r"
        max_out_length = max(len(last_cmd), max_out_length)
        sys.stdout.write(flush + last_cmd)
        # The progress line has no newline, so push it out explicitly.
        sys.stdout.flush()

    nd_cache_json = os.path.join(cachedir, "ndarray-cache.json")
    with open(nd_cache_json, "w") as outfile:
        json.dump(records, outfile, indent=4)
    print("\nAll finished, record saved to %s" % nd_cache_json)

    if f32_to_bf16_triggered:
        # Describe the encoded files as plain bfloat16 data. New dicts are
        # created so the records written above are not mutated.
        rec_bf16 = [
            dict(item, format="raw", dtype="bfloat16") if item["dtype"] == "float32" else item
            for item in records
        ]
        b16_nd_cache_json = os.path.join(cachedir, "ndarray-cache-b16.json")
        # also dump a file that contains bf16
        with open(b16_nd_cache_json, "w") as outfile:
            json.dump(rec_bf16, outfile, indent=4)
        print("Also saved a bf16 record to %s" % b16_nd_cache_json)
32 changes: 23 additions & 9 deletions python/tvm/exec/rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import argparse
import os
import glob
from tvm.rpc.proxy import Proxy


Expand All @@ -28,16 +29,29 @@ def find_example_resource():
base_path = os.path.abspath(os.path.join(curr_path, "..", "..", ".."))
index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html")
resource_files = [
os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"),
os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js"),
("/", os.path.join(base_path, "web", "dist", "tvmjs.bundle.js")),
("/", os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js")),
("/", index_page),
]
resource_base = os.path.join(base_path, "web", "dist", "www")
if os.path.isdir(resource_base):
for fname in os.listdir(resource_base):
full_name = os.path.join(resource_base, fname)
if os.path.isfile(full_name):
resource_files.append(full_name)
for fname in [index_page] + resource_files:
allow_format = ("json", "bin", "js", "wasm")

# recursively append the servable files found under each resource base
resource_bases = [
os.path.join(base_path, "web", "dist", "www"),
os.path.join(base_path, "web", ".ndarray_cache"),
]
for base in resource_bases:
if not os.path.isdir(base):
continue
for full_name in glob.glob("%s/**" % base, recursive=True):
fname = os.path.relpath(full_name, base)
dirname = os.path.dirname(fname)
fmt = fname.rsplit(".", 1)[-1]
if os.path.isfile(full_name) and fmt in allow_format:
resource_files.append((dirname, full_name))

for item in resource_files:
fname = item[-1]
if not os.path.exists(fname):
raise RuntimeError("Cannot find %s" % fname)
return index_page, resource_files
Expand Down
14 changes: 13 additions & 1 deletion python/tvm/relax/vm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,18 @@ def _vmcodegen(
raise ValueError("Unknown exec_mode %s" % exec_mode)


def _autodetect_system_lib_req(target: tvm.target.Target):
    """Decide whether the build needs a system-lib runtime.

    The host of ``target`` is inspected (``target`` itself when it has no
    host). If its ``mtriple`` attribute mentions "wasm", a "cpp" runtime with
    ``system-lib`` enabled is returned; otherwise ``None`` is returned.
    """
    host_target = target.host if target.host is not None else target
    if "wasm" not in host_target.attrs.get("mtriple", ""):
        return None
    # use packed-func to avoid relay dep.
    create_runtime = tvm.get_global_func("relay.backend.CreateRuntime")
    return create_runtime("cpp", {"system-lib": True})


def _vmlink(
builder: "relax.ExecBuilder",
target: Union[str, tvm.target.Target],
Expand Down Expand Up @@ -224,7 +236,7 @@ def _vmlink(
ext_libs = []
lib = None
if tir_mod is not None:
lib = tvm.build(tir_mod, target=target)
lib = tvm.build(tir_mod, target=target, runtime=_autodetect_system_lib_req(target))
return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore


Expand Down
21 changes: 19 additions & 2 deletions python/tvm/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,20 @@ def signal_close(self):
self.close()


# Content-Type header values keyed by file extension. RequestHandler.get sets
# the header only when the served file's extension is listed here.
MIME_MAP = {
"js": "application/javascript",
"wasm": "application/wasm",
"json": "application/json",
}


class RequestHandler(tornado.web.RequestHandler):
"""Handles html request."""

def __init__(self, *args, **kwargs):
file_path = kwargs.pop("file_path")
self.format = file_path.split(".")[-1]

if file_path.endswith("html"):
self.page = open(file_path).read()
web_port = kwargs.pop("rpc_web_port", None)
Expand All @@ -217,12 +226,15 @@ def __init__(self, *args, **kwargs):
)
else:
self.page = open(file_path, "rb").read()

super(RequestHandler, self).__init__(*args, **kwargs)

def data_received(self, _):
pass

def get(self, *args, **kwargs):
if self.format in MIME_MAP:
self.set_header("Content-Type", MIME_MAP[self.format])
self.write(self.page)


Expand Down Expand Up @@ -254,9 +266,14 @@ def __init__(
)
logging.info("Serving RPC index html page at http://localhost:%d", web_port)
resource_files = resource_files if resource_files else []
for fname in resource_files:
for item in resource_files:
prefix, fname = item
if not prefix.endswith("/"):
prefix += "/"
if not prefix.startswith("/"):
prefix = "/" + prefix
basename = os.path.basename(fname)
pair = (r"/%s" % basename, RequestHandler, {"file_path": fname})
pair = (r"%s%s" % (prefix, basename), RequestHandler, {"file_path": fname})
handlers.append(pair)
logging.info(pair)
self.app = tornado.web.Application(handlers)
Expand Down
11 changes: 11 additions & 0 deletions src/runtime/relax_vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,11 @@ void VirtualMachineImpl::RunLoop() {

/*! \brief Create a virtual machine backed by a new VirtualMachineImpl object. */
ObjectPtr<VirtualMachine> VirtualMachine::Create() { return make_object<VirtualMachineImpl>(); }

//----------------------------------------------------------------
// Profiler can be optionally disabled via a macro to reduce dep.
//----------------------------------------------------------------
#if TVM_RELAX_VM_ENABLE_PROFILER

/*!
* \brief An extension of VirtualMachineImpl to support per-op profiling
* It overrides RunInstrCall to add instrumentations around it.
Expand Down Expand Up @@ -927,6 +932,12 @@ ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler() {
return make_object<VirtualMachineProfiler>();
}

#else
// Fallback compiled when TVM_RELAX_VM_ENABLE_PROFILER evaluates to false:
// no profiling VM is available, so asking for one is reported as a fatal error.
ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler() {
LOG(FATAL) << "Profiler support is disabled";
// NOTE(review): presumably unreachable because LOG(FATAL) does not return;
// kept so that every path yields a value -- confirm.
return nullptr;
}
#endif // TVM_RELAX_VM_ENABLE_PROFILER
} // namespace relax_vm
} // namespace runtime
} // namespace tvm
1 change: 1 addition & 0 deletions web/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ out
node_modules
build
debug
.ndarray_cache
Loading

0 comments on commit 678d01d

Please sign in to comment.