# Bring Your Own Datatypes to TVM

In this tutorial, we will show you how to use your own custom datatypes in TVM with TVM's Bring Your Own Datatypes framework.
Note that the framework currently handles only **software-emulated versions of datatypes**, which is what we'll be discussing in this tutorial.

## Datatype Libraries

The central idea of the Bring Your Own Datatypes framework is to make datatypes in TVM more abstract, so that users can bring their own software-emulated datatype implementations.

In the wild, these datatype implementations usually appear as libraries. For example:
- [libposit](https://github.com/cjdelisle/libposit), a posit library
- [Stillwater Universal](https://github.com/stillwater-sc/universal), a library with posits, fixed-point numbers, and other types
- [SoftFloat](https://github.com/ucb-bar/berkeley-softfloat-3), Berkeley's software implementation of IEEE 754

The Bring Your Own Datatypes framework allows libraries such as these to be easily plugged into TVM!

In this section, we will explore our example library, [libposit](https://github.com/cjdelisle/libposit).
**Posits** are a datatype developed to compete with IEEE 754 floating point numbers.
We won't go into much detail about the datatype itself.
If you'd like to learn more, read through John Gustafson's [Beating Floating Point at its Own Game](https://posithub.org/docs/BeatingFloatingPoint.pdf).

First, let's clone and build the library. Note that we're using my branch of libposit which includes some fixes.

In [None]:
! git clone https://github.com/uwsampl/libposit
! cd libposit && git checkout 7c1788f291c1b5f74ded9acf4ffae7911c2df28c && autoreconf -f -i && ./configure && make

The library contains operations over posits.
Let's see what we can do with 16-bit posits (which are much like 16-bit floats):

In [None]:
! nm libposit/libposit.a | grep posit16

We see many operations that we might expect for datatypes: `posit16_abs` for absolute value, `posit16_add` for addition, `posit16_cmp` for comparison, etc.

In [None]:
libposit_wrapper_source = """
#include "posit.h"
"""
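# Write the wrapper source from Python; passing it through `echo` would let
# the shell treat the `#include` line as a comment and produce an empty file.
with open("libposit-wrapper.cc", "w") as wrapper_file:
    wrapper_file.write(libposit_wrapper_source)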
! cc -Ilibposit/generated -lmpfr -lgmp --std=c++14 -shared -o libposit-wrapper.so -fPIC libposit-wrapper.cc libposit/libposit.a
! ls -alh libposit-wrapper.so

## A Simple TVM Program
We'll begin by writing a simple program in TVM; afterwards, we will re-write it to use custom datatypes.

In [None]:
import tvm

# Our basic program: Z = X + Y
X = tvm.placeholder((3, ))
Y = tvm.placeholder((3, ))
Z = X + Y

Next, we compile for LLVM. The process of compiling in TVM is broken into scheduling, lowering, and finally, building:

In [None]:
target = "llvm"
schedule = tvm.create_schedule([Z.op])
lowered_func = tvm.lower(schedule, [X, Y, Z])
built_program = tvm.build(lowered_func, target=target)

# Print the lowered IR (simple mode makes it cleaner to read!)
print(tvm.lower(schedule, [X, Y, Z], simple_mode=True))

Now, we create random inputs to feed into this program using `numpy`.

In [None]:
import numpy as np

# Create a device context
context = tvm.context(target, 0)

# Create random input arrays on the above context
x = tvm.nd.array(np.array([1.0, 1.0, 1.0]).astype("float32"), ctx=context)
y = tvm.nd.array(np.array([1.333, -0.75, 10.9]).astype("float32"), ctx=context)
print("x: {}".format(x))
print("y: {}".format(y))


This empty array will hold our output.

In [None]:
z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=context)

Finally, we're ready to run the program:

In [None]:
built_program(x, y, z)
print("z: {}".format(z))

## Interlude: `bfloat16`

Before we rewrite our program using custom datatypes, let's introduce the custom datatype we will use: `bfloat16`. `bfloat16` is a very straightforward datatype: it is simply a 32-bit IEEE float chopped in half! Specifically, the 16 least-significant bits of the fraction are chopped off, leaving the sign bit, the 8 exponent bits, and the top 7 fraction bits. The result is a format which
- is straightforward to convert to and from 32-bit IEEE float
- has the same dynamic range as a 32-bit IEEE float, but with less precision
- takes up half the space!
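
The "chopped in half" description can be sketched in a few lines of standard-library Python. This is only an illustration under one assumption: it uses plain truncation, whereas a real `bfloat16` library may round instead. The helper names `float_to_bfloat16_bits` and `bfloat16_bits_to_float` are hypothetical and are not part of TVM or TensorFlow.

```python
import struct


def float_to_bfloat16_bits(value):
    """Return the upper 16 bits of the float32 encoding of `value` (truncation)."""
    (bits32,) = struct.unpack("<I", struct.pack("<f", value))
    return bits32 >> 16


def bfloat16_bits_to_float(bits16):
    """Re-expand 16 bfloat16 bits to a float by zero-filling the low 16 bits."""
    (value,) = struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))
    return value


# Round-trip the values used for `y` in this tutorial.
for original in [1.333, -0.75, 10.9]:
    bits = float_to_bfloat16_bits(original)
    print(original, hex(bits), bfloat16_bits_to_float(bits))
```

Values such as `-0.75`, which need only a few fraction bits, survive the round trip exactly; values such as `1.333` lose their low-order fraction bits.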

`bfloat16` is built into TensorFlow and is used natively on deep learning hardware such as the TPU. Training deep learning models with `bfloat16` often reaches the same converged accuracy as training with 32-bit floats, [according to the TensorFlow docs](https://cloud.google.com/tpu/docs/bfloat16).

TVM has a toy `bfloat16` library built-in for testing and demonstration purposes at [3rdparty/bfloat16/bfloat16.cc](https://github.com/dmlc/tvm/blob/master/3rdparty/bfloat16/bfloat16.cc). The `float->bfloat16` and `bfloat16->float` functions are taken from TensorFlow, while the other functions simply convert to `float` and use the native implementations of the functions they implement. Thus, it is not a true `bfloat16` implementation, but serves perfectly well for demonstration.

## Adding Custom Datatypes

Now, we will write the same program again, but this time we will use a custom datatype for our intermediate computation.

We use the same input placeholders `X` and `Y` as above, but before adding `X + Y`, we first cast both `X` and `Y` to a custom datatype via the `topi.cast(...)` call.

Note how we specify the custom datatype: we indicate it using the special `custom[...]` syntax. Additionally, note the "16" after the datatype: this is the bitwidth of the custom datatype. This tells TVM that each instance of `bfloat` is 16 bits wide.

In [None]:
import topi

try:
    Z = topi.cast(
        topi.cast(X, dtype="custom[bfloat]16") +
        topi.cast(Y, dtype="custom[bfloat]16"),
        dtype="float32")
except tvm.TVMError as e:
    # Print last line of error
    print(str(e).split('\n')[-1])

Trying to generate this program throws an error from TVM:
`TVMError: Check failed: name_to_code_.find(type_name) != name_to_code_.end(): Type name bfloat not registered`.
Unsurprisingly, TVM does not know how to handle any custom datatype out of the box. We first have to register the custom type with TVM, giving it a name and a type code:

In [None]:
tvm.datatype.register("bfloat", 129)

Note that the type code, 129, is currently chosen manually by the programmer. See `TVMTypeCode::kCustomBegin` in [include/tvm/runtime/c_runtime_api.h](https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/c_runtime_api.h).
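
To make the relationship between the `custom[...]` string, the type name, and the type code concrete, here is a toy model of a name-to-code registry. It is a sketch, not TVM's implementation: `register`, `parse_custom_dtype`, and the `CUSTOM_BEGIN` constant are hypothetical, and the value 129 is simply the code used in this tutorial.

```python
import re

# Assumption: custom type codes start at 129, the code used in this tutorial.
CUSTOM_BEGIN = 129

_name_to_code = {}


def register(type_name, type_code):
    """Associate a custom type name with a manually chosen type code."""
    if type_code < CUSTOM_BEGIN:
        raise ValueError("custom type codes must be >= {}".format(CUSTOM_BEGIN))
    _name_to_code[type_name] = type_code


def parse_custom_dtype(dtype):
    """Split 'custom[name]bits' into (name, type_code, bits)."""
    match = re.fullmatch(r"custom\[(\w+)\](\d+)", dtype)
    if match is None:
        raise ValueError("not a custom dtype string: {!r}".format(dtype))
    name, bits = match.group(1), int(match.group(2))
    if name not in _name_to_code:
        raise KeyError("Type name {} not registered".format(name))
    return name, _name_to_code[name], bits


# Before registration the lookup fails, much like the error above.
try:
    parse_custom_dtype("custom[bfloat]16")
except KeyError as err:
    print(err)

register("bfloat", 129)
print(parse_custom_dtype("custom[bfloat]16"))
```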

Now we can generate our program again:

In [None]:
Z = topi.cast(
    topi.cast(X, dtype="custom[bfloat]16") +
    topi.cast(Y, dtype="custom[bfloat]16"),
    dtype="float32")

Next, we again compile our program by scheduling, lowering, and building.

Note that we currently have to manually lower custom datatypes via the `tvm.ir_pass.LowerCustomDatatypes(...)` call. This is simply because we have not incorporated the custom datatypes lowering pass into the primary TVM build passes. Once custom datatype lowering is incorporated into these passes, we will not need to do this manually.

In [None]:
try:
    schedule = tvm.create_schedule([Z.op])
    lowered_func = tvm.lower(schedule, [X, Y, Z])
    lowered_func = tvm.ir_pass.LowerCustomDatatypes(lowered_func, target)
    built_program = tvm.build(lowered_func, target=target)
except tvm.TVMError as e:
    # Print last line of error
    print(str(e).split('\n')[-1])

Now, trying to compile this program throws an error:
`TVMError: Check failed: lower: Cast lowering function for target llvm destination type 129 source type 2 not found`.
Let's dissect this error.

The error is occurring during our `LowerCustomDatatypes(...)` call. TVM is telling us that it cannot find a _lowering function_ for the `Cast` operation, when casting from source type 2 (`float`, in TVM), to destination type 129 (our custom datatype). When lowering custom datatypes, if TVM encounters an operation over a custom datatype, it looks for a user-registered _lowering function_, which tells it how to lower the operation to an operation over datatypes it understands. We have not told TVM how to lower `Cast` operations for our custom datatypes; thus, the source of this error.
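
The lookup described above can be pictured as a table keyed by operation, target, destination type, and (for casts) source type. The following is a toy model under that assumption; the `register_op` and `find_lowering_func` functions here are hypothetical stand-ins written for illustration and do not reproduce TVM's internals.

```python
_lowering_funcs = {}


def register_op(lower_func, op_name, target, type_name, src_type_name=None):
    """Record a lowering function for one (operation, target, types) combination."""
    _lowering_funcs[(op_name, target, type_name, src_type_name)] = lower_func


def find_lowering_func(op_name, target, type_name, src_type_name=None):
    """Return the registered lowering function, or fail loudly if there is none."""
    key = (op_name, target, type_name, src_type_name)
    if key not in _lowering_funcs:
        raise LookupError(
            "{} lowering function for target {} destination type {} "
            "source type {} not found".format(op_name, target, type_name, src_type_name)
        )
    return _lowering_funcs[key]


# Nothing is registered yet, so the lookup for the float -> bfloat cast fails.
try:
    find_lowering_func("Cast", "llvm", "bfloat", "float")
except LookupError as err:
    print(err)

# After registering a (placeholder) lowering function, the same lookup succeeds.
register_op(lambda op: op, "Cast", "llvm", "bfloat", "float")
print(find_lowering_func("Cast", "llvm", "bfloat", "float") is not None)
```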

To fix this error, we simply need to specify a lowering function:

In [None]:
tvm.datatype.register_op(tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"),
                         "Cast", target, "bfloat", "float")

The `register_op(...)` call takes a lowering function, and a number of parameters which specify exactly the operation which should be lowered with the provided lowering function. In this case, the arguments we pass specify that this lowering function is for lowering a `Cast` from `float` to `bfloat` for target `"llvm"`.

The lowering function passed into this call is very general: it should take an operation of the specified type (in this case, `Cast`) and return another operation which only uses datatypes which TVM understands.

In the general case, we expect users to implement operations over their custom datatypes using calls to an external library. In our example, our `bfloat16` library (which, remember, is built into TVM) implements a `Cast` from `float` to `bfloat` in the function `FloatToBFloat16_wrapper`. To provide for the general case, we have made a helper function, `create_lower_func(...)`, which does just this: given a function name, it replaces the given operation with a `Call` to the function name provided. It additionally removes usages of the custom datatype by storing the custom datatype in an opaque `uint` of the appropriate width; in our case, a `uint16_t`.
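
As a rough picture of what such a helper does, the sketch below rewrites a toy `Cast` node into a toy `Call` node and swaps the custom datatype for an opaque unsigned integer of the same width. The `Cast` and `Call` classes and the `storage_dtype` helper are hypothetical simplifications made for this example; the `create_lower_func` here only imitates the behaviour described above and is not TVM's code.

```python
import re
from dataclasses import dataclass


@dataclass
class Cast:
    dtype: str   # destination datatype of the cast
    value: str   # name of the value being cast (simplified to a string)


@dataclass
class Call:
    dtype: str   # datatype of the call's result
    name: str    # name of the external function to call
    args: tuple  # arguments passed to the external function


def storage_dtype(dtype):
    """Map 'custom[name]bits' to 'uint<bits>'; leave built-in dtypes unchanged."""
    match = re.fullmatch(r"custom\[\w+\](\d+)", dtype)
    return "uint" + match.group(1) if match else dtype


def create_lower_func(extern_func_name):
    """Build a lowering function that replaces an operation with an external call."""
    def lower(op):
        return Call(storage_dtype(op.dtype), extern_func_name, (op.value,))
    return lower


lower_cast = create_lower_func("FloatToBFloat16_wrapper")
lowered = lower_cast(Cast("custom[bfloat]16", "x"))
print(lowered)
```

After lowering, the node no longer mentions `custom[bfloat]16`: the result type is `uint16`, and the work is delegated to the named library function.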

We can now re-try our build:

In [None]:
try:
    schedule = tvm.create_schedule([Z.op])
    lowered_func = tvm.lower(schedule, [X, Y, Z])
    lowered_func = tvm.ir_pass.LowerCustomDatatypes(lowered_func, target)
    built_program = tvm.build(lowered_func, target=target)
except tvm.TVMError as e:
    # Print last line of error
    print(str(e).split('\n')[-1])

This new error tells us that the `Add` lowering function is not found, which is good news, as it's no longer complaining about the `Cast`! We know what to do from here: we just need to register the lowering functions for the other operations in our program.

In [None]:
tvm.datatype.register_op(tvm.datatype.create_lower_func("BFloat16ToFloat_wrapper"),
                         "Cast", target, "float", "bfloat")
tvm.datatype.register_op(tvm.datatype.create_lower_func("BFloat16Add_wrapper"),
                         "Add", target, "bfloat")

Now, we can build our program without errors.

In [None]:
schedule = tvm.create_schedule([Z.op])
lowered_func = tvm.lower(schedule, [X, Y, Z])
lowered_func = tvm.ir_pass.LowerCustomDatatypes(lowered_func, target)
built_program = tvm.build(lowered_func, target=target)

print(lowered_func.body)

You can see that the IR contains our program, and implements the program using calls to our library.

Finally, we'll run the resulting program.

In [None]:
z_bfloat = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=context)
built_program(x, y, z_bfloat)
print("z_bfloat: {}".format(z_bfloat))

**NOTE:** The external library functions implementing your datatype (e.g. `FloatToBFloat16_wrapper`, in our example) must be loaded into the process space and visible for lookup at runtime. In our example, this happens automatically, as the library is built into the TVM shared library object. However, in other cases, you can use `CDLL` to load your library in global mode:

In [None]:
# import ctypes
# ctypes.CDLL(library_name, ctypes.RTLD_GLOBAL) 

We can now look at the results of the two programs side-by-side:

In [None]:
print("x:\t\t{}".format(x))
print("y:\t\t{}".format(y))
print("z:\t\t{}".format(z))
print("z_bfloat:\t{}".format(z_bfloat))

Perhaps as expected, the `bfloat16` results are very close to the `float` results, but with some loss in precision!

## Preview: Running Models With Custom Datatypes

**Note:** All code up to this point will work if you build TVM's `master` branch. This section of the notebook, however, uses code which is not yet merged into mainline TVM. This code still has bugs (specifically, some numerical stability issues with posit softmax), and is meant more to demonstrate at a high level how you will be able to change a model to a custom datatype.

In this final section of the notebook, we will demo additions to the Bring Your Own Datatypes framework which make it easy to run entire models using custom datatypes.

### Interlude 2: Posits

[Posits](https://posithub.org/docs/Posits4.pdf) are a new numerical datatype developed by John Gustafson; see [posithub.org](https://posithub.org/) for a central repository of all information about posits. Posits encode real numbers in a familiar way, using a sign, exponent, and fraction field. However, posits also add a _regime_ field, which provides an additional scaling factor on the number. There are a number of interesting practical results of this. A few highlights are:
- Posits have greater precision near ±1, and less precision at very large values
- Posits represent more numbers around ±1, and fewer numbers at very large values
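
To make the regime field concrete, here is a minimal decoding sketch. It assumes the standard posit layout from Gustafson's paper (sign bit, a run of identical regime bits, one terminating bit, `es` exponent bits, then the fraction) and `es = 1`, matching the `Posit16es1` function names used later in this tutorial. `decode_posit` is a hypothetical helper and is not part of libposit, Universal, or TVM.

```python
def decode_posit(bits, nbits=16, es=1):
    """Decode an `nbits`-wide posit bit pattern into a Python float."""
    mask = (1 << nbits) - 1
    bits &= mask
    if bits == 0:
        return 0.0
    if bits == 1 << (nbits - 1):
        return float("nan")  # NaR ("not a real")
    sign = bits >> (nbits - 1)
    if sign:
        bits = (-bits) & mask  # negative posits are stored in two's complement
    body = format(bits, "0{}b".format(nbits))[1:]  # drop the sign bit
    first = body[0]
    run = len(body) - len(body.lstrip(first))      # length of the regime run
    k = run - 1 if first == "1" else -run
    rest = body[run + 1:]                          # skip the terminating bit
    exp_bits = rest[:es].ljust(es, "0")            # missing exponent bits read as 0
    frac_bits = rest[es:]
    exponent = int(exp_bits, 2) if es else 0
    fraction = int(frac_bits, 2) / (1 << len(frac_bits)) if frac_bits else 0.0
    value = 2.0 ** (k * (1 << es) + exponent) * (1.0 + fraction)
    return -value if sign else value


one = 0x4000  # bit pattern of 1.0 for 16-bit posits
print(decode_posit(one + 1) - decode_posit(one))    # gap just above 1.0
print(decode_posit(0x7FFF) - decode_posit(0x7FFE))  # gap just below the largest posit
```

Near 1.0 the regime is short, leaving 12 fraction bits, so neighbouring posits are `2**-12` apart. At the top of the range the regime consumes every bit, and the gap between the two largest values is roughly 2×10^8.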

Both of these features make posits very attractive for deep learning! [Deep Positron](https://arxiv.org/abs/1812.01762) is an example of posit research in the deep learning space, demonstrating promising results when using posits instead of floats.

### Converting Models to Custom Datatypes

We will first choose the model which we would like to run with posits. In this case we use [Mobilenet](https://arxiv.org/abs/1704.04861). We choose Mobilenet due to its small size. In this alpha state of the Bring Your Own Datatypes framework, we have not implemented any software optimizations for running software emulations of custom datatypes; the result is poor performance due to many calls into our datatype emulation library.

Relay has packaged up many models within its [python/tvm/relay/testing](https://github.com/dmlc/tvm/tree/master/python/tvm/relay/testing) directory. We will go ahead and grab Mobilenet:

In [None]:
from tvm.relay.testing.mobilenet import get_workload as get_mobilenet

module, params = get_mobilenet()

We can execute Mobilenet easily using the Relay graph execution engine:

In [None]:
ex = tvm.relay.create_executor("graph", mod=module)
input = tvm.nd.array(np.random.rand(3, 224, 224).astype("float32"))
result = ex.evaluate()(input, **params)
print(result)

Now, we would like to change the model to use posits internally. To do so, we first must register posits:

In [None]:
tvm.datatype.register("posit", 130)

Next, we need to convert the network. To do this, we first define a function which will help us convert tensors:


In [None]:
def convert_ndarray(dst_dtype, array, executor):
    x = tvm.relay.var('x', shape=array.shape, dtype=str(array.dtype))
    cast = tvm.relay.Function([x], x.astype(dst_dtype))
    return executor.evaluate(cast)(array)

Now, to actually convert the entire network, we have written [a pass in Relay](https://github.com/gussmith23/tvm/blob/ea174c01c54a2529e19ca71e125f5884e728da6e/python/tvm/relay/frontend/change_datatype.py#L21) which simply converts all nodes within the model to use the new datatype.
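
The idea behind such a pass can be sketched on a toy expression tree: walk every node and, wherever the source datatype appears, substitute the destination datatype. The `Node` class and `change_datatype` function below are hypothetical and far simpler than the linked Relay pass; they only illustrate the traversal.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class Node:
    op: str                         # operation or leaf name
    dtype: str                      # datatype of this node's result
    args: Tuple["Node", ...] = ()   # child expressions


def change_datatype(node, src, dst):
    """Return a copy of the tree with every `src` dtype replaced by `dst`."""
    new_args = tuple(change_datatype(arg, src, dst) for arg in node.args)
    new_dtype = dst if node.dtype == src else node.dtype
    return replace(node, dtype=new_dtype, args=new_args)


expr = Node("add", "float32", (Node("var:x", "float32"), Node("const:1", "int32")))
converted = change_datatype(expr, "float32", "custom[posit]16")
print(converted)
```

Nodes whose datatype is not the source datatype (the `int32` constant here) are left untouched, and the original tree is not modified.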

In [None]:
from tvm.relay.frontend.change_datatype import ChangeDatatype

src_dtype = 'float32'
dst_dtype = 'custom[posit]16'

# Currently, custom datatypes only work if you run simplify_inference beforehand
module = tvm.relay.transform.SimplifyInference()(module)

# Run type inference before changing datatype
module = tvm.relay.transform.InferType()(module)

# Change datatype from float to posit and re-infer types
cdtype = ChangeDatatype(src_dtype, dst_dtype)
expr = cdtype.visit(module['main'])
module = tvm.relay.transform.InferType()(tvm.relay.Module.from_expr(expr))

# Finally, try to convert the parameters:
try:
    params = dict(
        (p, convert_ndarray(dst_dtype, params[p], ex)) for p in params)
except tvm.TVMError as e:
    # Print last line of error
    print(str(e).split('\n')[-1])

When we attempt to convert the parameters, we get a familiar error. We need to implement our posit functions! 

Because this is a neural network, many more operations are required. We have implemented these operations in a small posit library at [3rdparty/bfloat16/bfloat16.cc](https://github.com/gussmith23/tvm/blob/a335e112dfa36fb7b460619401264ddb90007f55/3rdparty/bfloat16/bfloat16.cc). This small library depends on [Stillwater Supercomputing's Universal library](https://github.com/stillwater-sc/universal), which implements posits (and other _universal numbers_) in great detail.

Here, we register all the needed functions:

In [None]:
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("FloatToPosit16es1"), "Cast",
    "llvm", "posit", "float")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("Posit16es1ToFloat"), "Cast",
    "llvm", "float", "posit")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("IntToPosit16es1"), "Cast",
    "llvm", "posit", "int")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("Posit16es1Add"), "Add",
    "llvm", "posit")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("Posit16es1Sub"), "Sub",
    "llvm", "posit")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("FloatToPosit16es1"),
    "FloatImm", "llvm", "posit")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("Posit16es1Mul"), "Mul",
    "llvm", "posit")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("Posit16es1Div"), "Div",
    "llvm", "posit")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("Posit16es1Max"), "Max",
    "llvm", "posit")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("Posit16es1Sqrt"),
    "Call",
    "llvm",
    "posit",
    intrinsic_name="sqrt")
# TODO(gus) not sure if this will work...
tvm.datatype.register_op(
    tvm.datatype.lower_ite,
    "Call",
    "llvm",
    "posit",
    intrinsic_name="tvm_if_then_else")
tvm.datatype.register_op(
    tvm.datatype.create_lower_func("Posit16es1Exp"),
    "Call",
    "llvm",
    "posit",
    intrinsic_name="exp")
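
Since most of these registrations differ only in the operation name and the library function name, they could also be driven from a table. The sketch below shows one way to do that; `register_table` and `POSIT_BINARY_OPS` are hypothetical names, and the two registration functions are passed in as arguments so that the example runs without TVM.

```python
# Operation name -> external library function, copied from the calls above.
POSIT_BINARY_OPS = {
    "Add": "Posit16es1Add",
    "Sub": "Posit16es1Sub",
    "Mul": "Posit16es1Mul",
    "Div": "Posit16es1Div",
    "Max": "Posit16es1Max",
}


def register_table(register_op, create_lower_func, table,
                   target="llvm", type_name="posit"):
    """Register one lowering function per table entry."""
    for op_name, extern_name in table.items():
        register_op(create_lower_func(extern_name), op_name, target, type_name)


# With TVM available, the call would presumably look like this:
#   register_table(tvm.datatype.register_op,
#                  tvm.datatype.create_lower_func,
#                  POSIT_BINARY_OPS)

# Stand-in functions that just record what would be registered.
recorded = []
register_table(lambda func, op, target, type_name: recorded.append((func, op, target, type_name)),
               lambda extern_name: extern_name,
               POSIT_BINARY_OPS)
for entry in recorded:
    print(entry)
```

Casts and intrinsic calls take extra arguments (a source type or an `intrinsic_name`), so they are easier to keep as explicit calls.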

Now, we can convert our params:

In [None]:
params = dict(
    (p, convert_ndarray(dst_dtype, params[p], ex)) for p in params)

We also need to convert our input:

In [None]:
input = convert_ndarray(dst_dtype, input, ex)

Finally, we can run the converted model:

In [None]:
# Vectorization is not implemented with custom datatypes.
with tvm.build_config(disable_vectorize=True):
    result_posit = ex.evaluate(expr)(input, **params)
    # Print the result! (first, convert back to float)
    print(convert_ndarray(src_dtype, result_posit, ex))