# Demo of optimization without weights in Relax

In some use cases, the model weights are not available at compile time. This notebook demonstrates how to still apply, in Relax, the optimizations that change the model weights. The idea is to identify the transformations applied to the weights, then lift and compile them into a separate `transform_params` function. This function ships as part of the package, and users invoke it at runtime with the real weights.

There are several advantages of this approach:
* `transform_params` is called by the users after they load the real weights. Re-packaging or re-compilation is not needed.
* `transform_params` only needs to be invoked once, so it adds no overhead during model benchmarking or inference.
* The same set of optimizations can be performed, so the model performs the same as when the weights are provided at compile time.

## Example 
Let's start with a simple model that contains one layer of conv2d.
In this example, the weight of `conv2d`, `w1`, is an input tensor of the entry function `main` and is not available at compile time. We use the annotations `param_begin` and `param_end` to mark the range of function-input indices that are model parameters.

In [1]:
import tvm
from tvm import relax
from tvm.script import relax as R, tir as T
import numpy as np

target = tvm.target.Target("llvm")
dev = tvm.cpu(0)

@tvm.script.ir_module
class Module0:
    @R.function
    def main(x: R.Tensor((1, 3, 224, 224), "float32"), w1: R.Tensor((3, 16, 3, 3), "float32")) -> R.Tensor((1, 16, 224, 224), "float32"):
        # Inputs with index in [param_begin, param_end) are model parameters; here that is only w1.
        R.func_attr({'param_begin': 1, 'param_end': 2})
        with R.dataflow():
            conv1 = R.nn.conv2d(x, w1, padding=(1, 1), data_layout="NCHW", kernel_layout="IOHW")
            R.output(conv1)
        return conv1
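
To make the annotation concrete, the following is a minimal sketch in plain Python (no TVM required) of how such a range splits a function's inputs into data inputs and parameters. It assumes the range is half-open, `[param_begin, param_end)`, which is what the values above suggest, since `w1` is the only parameter and sits at index 1. The helper name `split_inputs` is hypothetical and not part of the TVM API.

```python
def split_inputs(input_names, param_begin, param_end):
    """Split input names into (data inputs, parameters).

    Assumption: indices in the half-open range [param_begin, param_end)
    are model parameters; everything else is a regular data input.
    """
    params = input_names[param_begin:param_end]
    data = input_names[:param_begin] + input_names[param_end:]
    return data, params


# Mirrors `main(x, w1)` with param_begin=1, param_end=2.
data, params = split_inputs(["x", "w1"], param_begin=1, param_end=2)
print(data, params)
```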


Next, we apply graph transformations. Taking weight layout rewriting as an example, such a pass inserts layout-transform functions into the IRModule.

In [2]:
@tvm.script.ir_module
class Module1:
    @T.prim_func
    def transform_layout_IOHW_to_OIHW(w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")) -> None:
        for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
            with T.block("layout_transform"):
                o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                out[o, i, h, w] = w1[i, o, h, w]

    @R.function
    def main(x: R.Tensor((1, 3, 224, 224), "float32"), w1: R.Tensor((3, 16, 3, 3), "float32")) -> R.Tensor((1, 16, 224, 224), "float32"):
        # Annotate the inputs that are parameters: indices in [param_begin, param_end).
        R.func_attr({'param_begin': 1, 'param_end': 2})
        with R.dataflow():
            # Weight transformation generated by passes such as RewriteWeightLayout.
            w1_transformed = R.call_tir(transform_layout_IOHW_to_OIHW, w1, R.Tensor((16, 3, 3, 3), "float32"))
            conv1 = R.nn.conv2d(x, w1_transformed, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW")
            R.output(conv1)
        return conv1
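
The index mapping performed by `transform_layout_IOHW_to_OIHW` can be illustrated without TVM. The sketch below applies the same rule, `out[o][i][h][w] = w1[i][o][h][w]`, to nested Python lists. The helper names `make_tensor` and `iohw_to_oihw` are hypothetical and exist only for this illustration.

```python
def make_tensor(d0, d1, d2, d3):
    """Build a nested-list tensor whose entries record their own indices."""
    return [[[[(a, b, c, d) for d in range(d3)]
              for c in range(d2)]
             for b in range(d1)]
            for a in range(d0)]


def iohw_to_oihw(w):
    """Swap the first two axes: out[o][i][h][w] = w[i][o][h][w]."""
    num_i, num_o = len(w), len(w[0])
    return [[w[i][o] for i in range(num_i)] for o in range(num_o)]


w1 = make_tensor(3, 16, 3, 3)  # IOHW: I=3, O=16, H=3, W=3
out = iohw_to_oihw(w1)         # OIHW: O=16, I=3, H=3, W=3

# The element at out[o][i][h][w] is the one originally stored at w1[i][o][h][w].
assert out[5][2][1][0] == (2, 5, 1, 0)
print(len(out), len(out[0]), len(out[0][0]), len(out[0][0][0]))
```

In NumPy terms this is the same axis swap as `w1.transpose(1, 0, 2, 3)`.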

After graph optimizations have been applied, we will invoke the `LiftTransformParams` pass to lift the transformations.

In [3]:
with tvm.transform.PassContext(opt_level=1):
    seq = tvm.transform.Sequential([relax.transform.LiftTransformParams()])
    mod = seq(Module1)
assert relax.analysis.well_formed(mod)

mod.show()

The new module contains a `transform_params` function, which takes a tuple of weights as input and returns a tuple of the optimized weights.
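
Conceptually, lifting splits one function into two. The following plain-Python sketch illustrates the idea only. It is not how `LiftTransformParams` is implemented, and the names `weight_transform`, `compute`, `original_main`, and `lifted_main` are hypothetical stand-ins.

```python
def weight_transform(w):
    """Stand-in for a weight-only transformation (here: reverse the list)."""
    return list(reversed(w))


def compute(x, w_t):
    """Stand-in for the rest of the model: a dot product."""
    return sum(a * b for a, b in zip(x, w_t))


def original_main(x, w):
    # Before lifting: the weight transformation runs on every call.
    return compute(x, weight_transform(w))


def transform_params(params):
    # After lifting: takes a tuple of weights and returns a tuple of
    # transformed weights. It depends only on the weights, not on x.
    (w,) = params
    return (weight_transform(w),)


def lifted_main(x, transformed_params):
    # After lifting: consumes the already-transformed weights.
    (w_t,) = transformed_params
    return compute(x, w_t)


x, w = [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]
transformed = transform_params((w,))  # done once, reused for every inference call
assert lifted_main(x, transformed) == original_main(x, w)
```

Because `transform_params` never reads `x`, its result can be computed once per set of weights and reused across inference calls.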

We will then run the usual compilation flow to build the Relax VM.

In [4]:
mod = relax.transform.LegalizeOps()(mod)
# optimize and compile the model without params (named `ex` to avoid shadowing the built-in `exec`)
ex = relax.vm.build(mod, target, params=None)
vm = relax.vm.VirtualMachine(ex, dev)

With the built package, users can feed in the real weights to obtain the weights for the optimized model. This step happens solely on the user side.

In [5]:
# these weights are only available at runtime
w1 = tvm.nd.array(np.random.uniform(size=(3, 16, 3, 3)).astype(dtype="float32"), dev)
params = tvm.runtime.container.tuple_object((w1,))
transformed_params = vm["transform_params"](params)

With the transformed parameters, users can run the model inference in the usual way.

In [6]:
x = tvm.nd.array(np.random.uniform(size=(1, 3, 224, 224)).astype(dtype="float32"), dev)
out = vm["main"](x, transformed_params)

In [7]:
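# check the result against a Python reference implementation;
# the reference expects OIHW weights, so w1 is transposed from IOHW below
import tvm.testing
import tvm.topi.testing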
ref_out = tvm.topi.testing.conv2d_nchw_python(x.numpy(), w1.numpy().transpose(1, 0, 2, 3), stride=1, padding=1)
tvm.testing.assert_allclose(ref_out, out.numpy(), atol=1e-4, rtol=1e-4)

## Integration with the end-to-end workflow
This approach can be integrated into the end-to-end workflow with the following changes:
- The frontend needs to annotate which model inputs are weights. For example, when the weights are not provided, the frontend may produce a function with the input list `(x, w1, w2)`; it needs to detect `w1` and `w2` and annotate them as weights.
- The `LiftTransformParams` pass is added to the compilation pipeline after the graph optimization passes.
- Users need to invoke `transform_params` to obtain the weights for the optimized model before running inference.
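
As an illustration of the frontend annotation step, a frontend could derive the annotation from the input list and the set of inputs it has identified as weights. The sketch below is plain Python with a hypothetical helper, `find_param_range`. It assumes that the weights occupy one contiguous run of inputs and that the annotated range is half-open, `[param_begin, param_end)`; neither assumption is confirmed by this notebook beyond the single-weight example.

```python
def find_param_range(input_names, weight_names):
    """Return (param_begin, param_end) for the weights in input_names.

    Assumptions: the weights form one contiguous run of inputs, and the
    returned range is half-open, i.e. [param_begin, param_end).
    """
    indices = [i for i, name in enumerate(input_names) if name in weight_names]
    if not indices:
        raise ValueError("no weight inputs found")
    begin, end = indices[0], indices[-1] + 1
    if indices != list(range(begin, end)):
        raise ValueError("weight inputs are not contiguous")
    return begin, end


begin, end = find_param_range(["x", "w1", "w2"], {"w1", "w2"})
attrs = {"param_begin": begin, "param_end": end}
print(attrs)
```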