In [1]:
import tvm
from tvm import relay
from tvm.relay.prelude import Prelude
from tvm.relay.testing import layers
import numpy as np

In [2]:
def get_types(input_size, hidden_size, batch_size=1, dtype="float32"):
    input_type = relay.TensorType((batch_size, input_size), dtype)
    hidden_type = relay.TensorType((batch_size, hidden_size), dtype)
    i2h_weight_type = relay.TensorType((4 * hidden_size, input_size), dtype)
    h2h_weight_type = relay.TensorType((4 * hidden_size, hidden_size), dtype)
    bias_type = relay.TensorType((4 * hidden_size,), dtype)
    dense_type = relay.TensorType((batch_size, 4 * hidden_size), dtype)
    slice_type = relay.TupleType([hidden_type, hidden_type, hidden_type, hidden_type])
    state_type = relay.TupleType([hidden_type, hidden_type])
    return input_type, hidden_type, i2h_weight_type, h2h_weight_type, bias_type, dense_type, slice_type, state_type

In [3]:
def lstm_cell(input_size, hidden_size, batch_size=1, dtype="float32"): 
    input_type, hidden_type, i2h_weight_type, h2h_weight_type, bias_type, dense_type, \
        slice_type, state_type = get_types(input_size, hidden_size, batch_size, dtype)
    inputs = relay.Var("inputs", input_type)
    states = relay.Var("states", state_type)
    i2h_weight = relay.Var("i2h_weight", i2h_weight_type)
    i2h_bias = relay.Var("i2h_bias", bias_type)
    h2h_weight = relay.Var("h2h_weight", h2h_weight_type)
    h2h_bias = relay.Var("h2h_bias", bias_type)
    
    builder = relay.ScopeBuilder()
    old_h = builder.let(("old_h", hidden_type), relay.TupleGetItem(states, 0))
    old_c = builder.let(("old_c", hidden_type), relay.TupleGetItem(states, 1))
    i2h = builder.let(("i2h", dense_type),
                      layers.dense_add_bias(
                          data=inputs,
                          units=hidden_size * 4,
                          weight=i2h_weight, bias=i2h_bias,
                          name="i2h"))
    h2h = builder.let(("h2h", dense_type),
                      layers.dense_add_bias(
                          data=old_h,
                          units=hidden_size * 4,
                          weight=h2h_weight, bias=h2h_bias,
                          name="h2h"))
    gates = builder.let(("gates", dense_type), relay.add(i2h, h2h))
    slice_gates = builder.let(("slice_gates", slice_type),
                              relay.split(gates,
                                          indices_or_sections=4,
                                          axis=1).astuple())
    in_gate = builder.let(("in_gate", hidden_type),
                          relay.sigmoid(relay.TupleGetItem(slice_gates, 0)))
    forget_gate = builder.let(("forget_gate", hidden_type),
                              relay.sigmoid(relay.TupleGetItem(slice_gates, 1)))
    in_transform = builder.let(("in_transform", hidden_type),
                               relay.tanh(relay.TupleGetItem(slice_gates, 2)))
    out_gate = builder.let(("out_gate", hidden_type),
                           relay.sigmoid(relay.TupleGetItem(slice_gates, 3)))
    next_c = builder.let(("next_c", hidden_type),
                         relay.add(relay.multiply(forget_gate, old_c),
                                   relay.multiply(in_gate, in_transform)))
    next_h = builder.let(("next_h", input_type),
                         relay.multiply(out_gate, relay.tanh(next_c)))
    ret = builder.let(("ret", state_type), relay.Tuple([next_h, next_c]))
    builder.ret(ret)

    return relay.Function([inputs, states, i2h_weight, i2h_bias, h2h_weight, h2h_bias],
                          builder.get())

In [4]:
def unroll_lstm(seq_len, input_size, hidden_size, batch_size=1, dtype="float32"):
    input_type, hidden_type, i2h_weight_type, h2h_weight_type, bias_type, _, \
        _, state_type = get_types(input_size, hidden_size, batch_size, dtype)
    
    i2h_weight = relay.Var("i2h_weight", i2h_weight_type)
    i2h_bias = relay.Var("i2h_bias", bias_type)
    h2h_weight = relay.Var("h2h_weight", h2h_weight_type)
    h2h_bias = relay.Var("h2h_bias", bias_type)
    
    cell_fn = lstm_cell(input_size, hidden_size, batch_size, dtype)
    
    builder = relay.ScopeBuilder()
    zeros = builder.let(("zeros", hidden_type), relay.zeros((batch_size, hidden_size), dtype))
    states = builder.let(("init_states", state_type), relay.Tuple([zeros, zeros]))
    
    for i in range(seq_len):
        inputs = relay.Var("data_" + str(i), input_type)
        new_states = builder.let(("call", state_type),
                                 relay.Call(cell_fn, [inputs, states, i2h_weight, 
                                                      i2h_bias, h2h_weight, h2h_bias]))
        states = new_states
    
    out = builder.let(("out", hidden_type), relay.TupleGetItem(states, 0))
    builder.ret(out)
    body = builder.get()
    args = relay.analysis.free_vars(body)
    return relay.Function(args, body)

In [5]:
mod = relay.Module()
mod["main"] = unroll_lstm(2, 2, 2)
mod = relay.transform.InferType()(mod)
print(mod["main"])
print(mod["main"].params)
   
shape_dict = {
    v.name_hint : v.checked_type for v in mod["main"].params}
np.random.seed(0)
initializer = relay.testing.init.Xavier()
params = {}
for k, v in shape_dict.items():
    if k.startswith("data"):
        continue
    init_value = np.zeros(v.concrete_shape).astype(v.dtype)
    initializer(k, init_value)
    params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0))
print(params)

v0.0.4
fn (%data_0: Tensor[(1, 2), float32], %i2h_weight: Tensor[(8, 2), float32], %i2h_bias: Tensor[(8), float32], %h2h_weight: Tensor[(8, 2), float32], %h2h_bias: Tensor[(8), float32], %data_1: Tensor[(1, 2), float32]) -> Tensor[(1, 2), float32] {
  let %zeros: Tensor[(1, 2), float32] = zeros(shape=[1, 2], dtype="float32") /* ty=Tensor[(1, 2), float32] */;
  let %init_states: (Tensor[(1, 2), float32], Tensor[(1, 2), float32]) = (%zeros, %zeros);
  %9 = fn (%inputs: Tensor[(1, 2), float32], %states: (Tensor[(1, 2), float32], Tensor[(1, 2), float32]), %i2h_weight1: Tensor[(8, 2), float32], %i2h_bias1: Tensor[(8), float32], %h2h_weight1: Tensor[(8, 2), float32], %h2h_bias1: Tensor[(8), float32]) -> (Tensor[(1, 2), float32], Tensor[(1, 2), float32]) {
    let %old_h: Tensor[(1, 2), float32] = %states.0;
    let %old_c: Tensor[(1, 2), float32] = %states.1;
    %0 = nn.dense(%inputs, %i2h_weight1, units=8) /* ty=Tensor[(1, 8), float32] */;
    let %i2h: Tensor[(1, 8), float32] = nn.bias_ad

In [7]:
opt_level = 1
target = "llvm"
with relay.build_config(opt_level=opt_level):
    graph, lib, params = relay.build_module.build(
        mod, target, params=params)

ctx = tvm.cpu()
data_dict = {}
seq_len = 2
for i in range(seq_len):
    k = "data_" + str(i)
    v = shape_dict[k]
    data_dict[k] = np.random.uniform(
        -1, 1, size=v.concrete_shape).astype(v.dtype)
module = graph_runtime.create(graph, lib, ctx)
for k, v in data_dict:
    module.set_input(k, v)
module.set_input(**params)
module.run()
outshape = (1, 2)
out = module.get_output(0, tvm.nd.empty(out_shape)).asnumpy()
print(out)

TVMError: Traceback (most recent call last):
  [bt] (8) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr(tvm::relay::Expr const&)+0x47f) [0x7f5bb4d3464f]
  [bt] (7) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::LetNode const*)+0x18b) [0x7f5bb4d33c4b]
  [bt] (6) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr(tvm::relay::Expr const&)+0x47f) [0x7f5bb4d3464f]
  [bt] (5) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::LetNode const*)+0x18b) [0x7f5bb4d33c4b]
  [bt] (4) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr(tvm::relay::Expr const&)+0x47f) [0x7f5bb4d3464f]
  [bt] (3) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::LetNode const*)+0x147) [0x7f5bb4d33c07]
  [bt] (2) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr(tvm::relay::Expr const&)+0x42d) [0x7f5bb4d345fd]
  [bt] (1) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::CallNode const*)+0xd92) [0x7f5bb4d3a062]
  [bt] (0) /home/ubuntu/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f5bb46c2352]
  File "/home/ubuntu/tvm/src/relay/backend/graph_runtime_codegen.cc", line 397
TVMError: TVM only support calls to primitive functions (i.e functions composed of fusable operator invocations)

In [4]:
mod = relay.Module()
p = Prelude(mod)
l = p.l
nil = p.nil
cons = p.cons

In [8]:
def recursive_lstm(seq_len, input_size, hidden_size, batch_size=1, dtype="float32"):
    input_type, hidden_type, i2h_weight_type, h2h_weight_type, bias_type, _, \
        _, state_type = get_types(input_size, hidden_size, batch_size, dtype)
    
    input_seq = nil()
    for i in range(seq_len):
        inputs = relay.Var("data", input_type)
        input_seq = cons(inputs, input_seq)
    
    i2h_weight = relay.Var("i2h_weight", i2h_weight_type)
    i2h_bias = relay.Var("i2h_bias", bias_type)
    h2h_weight = relay.Var("h2h_weight", h2h_weight_type)
    h2h_bias = relay.Var("h2h_bias", bias_type)
    
    current_seq = relay.Var("current_seq", l(input_type))
    current_inputs = relay.Var("current_inputs", input_type)
    tail_seq = relay.Var("tail_seq", l(input_type))
    
    rec_fn = relay.Var("rec_fn")
    cell_fn = lstm_cell(input_size, hidden_size, batch_size, dtype)
    
    builder = relay.ScopeBuilder()
    zeros = builder.let(("zeros", hidden_type), relay.zeros((batch_size, hidden_size), dtype))
    init_states = builder.let(("init_states", state_type), relay.Tuple([zeros, zeros]))
    match = builder.let(("match", state_type), 
        relay.Match(
            current_seq,
            [relay.Clause(relay.PatternConstructor(nil), 
                          init_states),
             relay.Clause(relay.PatternConstructor(cons,
                                                   [relay.PatternVar(current_inputs), 
                                                    relay.PatternVar(tail_seq)]),
                          relay.Call(cell_fn, [current_inputs, rec_fn(tail_seq),
                                               i2h_weight, i2h_bias, h2h_weight, h2h_bias]))
            ]))
    builder.ret(match)
    func = relay.Function([current_seq], builder.get())
    ret = relay.Let(rec_fn, func, rec_fn(input_seq))
    out = relay.TupleGetItem(ret, 0)
    args = relay.analysis.free_vars(out)
    return relay.Function(args, out)

In [9]:
mod["main"] = recursive_lstm(2, 2, 2)
mod = relay.transform.InferType()(mod)
print(mod["main"])
print(mod["main"].params)

NameError: name 'nil' is not defined

In [6]:
print(mod["main"])

v0.0.4
fn (%data: Tensor[(1, 2), float32], %i2h_weight: Tensor[(8, 2), float32], %i2h_bias: Tensor[(8), float32], %h2h_weight: Tensor[(8, 2), float32], %h2h_bias: Tensor[(8), float32], %data1: Tensor[(1, 2), float32]) -> Tensor[(1, 2), float32] {
  let %zeros: Tensor[(1, 2), float32] = zeros(shape=[1, 2], dtype="float32") /* ty=Tensor[(1, 2), float32] */;
  let %init_states: (Tensor[(1, 2), float32], Tensor[(1, 2), float32]) = (%zeros, %zeros);
  %9 = fn (%inputs: Tensor[(1, 2), float32], %states: (Tensor[(1, 2), float32], Tensor[(1, 2), float32]), %i2h_weight1: Tensor[(8, 2), float32], %i2h_bias1: Tensor[(8), float32], %h2h_weight1: Tensor[(8, 2), float32], %h2h_bias1: Tensor[(8), float32]) -> (Tensor[(1, 2), float32], Tensor[(1, 2), float32]) {
    let %old_h: Tensor[(1, 2), float32] = %states.0;
    let %old_c: Tensor[(1, 2), float32] = %states.1;
    %0 = nn.dense(%inputs, %i2h_weight1, units=8) /* ty=Tensor[(1, 8), float32] */;
    let %i2h: Tensor[(1, 8), float32] = nn.bias_add(%