In [19]:
import tvm
from tvm import relay
from tvm.relay.prelude import Prelude
from tvm.relay.testing import layers
import numpy as np

In [20]:
def get_types(input_size, hidden_size, batch_size=1, dtype="float32"):
    """Build the Relay types used when constructing the LSTM.

    Parameters
    ----------
    input_size : int
        Width of one input row.
    hidden_size : int
        Width of the hidden / cell state.
    batch_size : int
        Leading dimension of the input and state tensors.
    dtype : str
        Element type shared by every tensor type.

    Returns
    -------
    tuple
        ``(input_type, hidden_type, i2h_weight_type, h2h_weight_type,
        bias_type, dense_type, slice_type, state_type)``.
    """
    def tensor(*shape):
        # All tensor types here share the same dtype; only the shape varies.
        return relay.TensorType(shape, dtype)

    # The four gates are packed side by side, hence the factor of 4.
    gate_width = 4 * hidden_size

    input_type = tensor(batch_size, input_size)
    hidden_type = tensor(batch_size, hidden_size)
    i2h_weight_type = tensor(gate_width, input_size)
    h2h_weight_type = tensor(gate_width, hidden_size)
    bias_type = tensor(gate_width)
    dense_type = tensor(batch_size, gate_width)
    # One hidden-shaped tensor per gate slice.
    slice_type = relay.TupleType([hidden_type] * 4)
    # Pair of hidden-shaped tensors.
    state_type = relay.TupleType([hidden_type] * 2)

    return (
        input_type,
        hidden_type,
        i2h_weight_type,
        h2h_weight_type,
        bias_type,
        dense_type,
        slice_type,
        state_type,
    )

In [21]:
def lstm_cell(input_size, hidden_size, batch_size=1, dtype="float32"):
    """Build a Relay function computing one LSTM cell step.

    Parameters
    ----------
    input_size : int
        Width of the ``inputs`` tensor.
    hidden_size : int
        Width of the hidden and cell state tensors.
    batch_size : int
        Leading dimension of the input and state tensors.
    dtype : str
        Element type of every tensor.

    Returns
    -------
    relay.Function
        Function with parameters ``(inputs, states, i2h_weight, i2h_bias,
        h2h_weight, h2h_bias)`` where ``states`` is the tuple
        ``(old_h, old_c)``. Its body evaluates to the tuple
        ``(next_h, next_c)``.
    """
    input_type, hidden_type, i2h_weight_type, h2h_weight_type, bias_type, dense_type, \
        slice_type, state_type = get_types(input_size, hidden_size, batch_size, dtype)
    inputs = relay.Var("inputs", input_type)
    states = relay.Var("states", state_type)
    i2h_weight = relay.Var("i2h_weight", i2h_weight_type)
    i2h_bias = relay.Var("i2h_bias", bias_type)
    h2h_weight = relay.Var("h2h_weight", h2h_weight_type)
    h2h_bias = relay.Var("h2h_bias", bias_type)

    builder = relay.ScopeBuilder()
    # Unpack the incoming state tuple: element 0 is h, element 1 is c.
    old_h = builder.let(("old_h", hidden_type), relay.TupleGetItem(states, 0))
    old_c = builder.let(("old_c", hidden_type), relay.TupleGetItem(states, 1))
    # Input-to-hidden and hidden-to-hidden projections, each 4 * hidden_size wide.
    i2h = builder.let(("i2h", dense_type),
                      layers.dense_add_bias(
                          data=inputs,
                          units=hidden_size * 4,
                          weight=i2h_weight, bias=i2h_bias,
                          name="i2h"))
    h2h = builder.let(("h2h", dense_type),
                      layers.dense_add_bias(
                          data=old_h,
                          units=hidden_size * 4,
                          weight=h2h_weight, bias=h2h_bias,
                          name="h2h"))
    gates = builder.let(("gates", dense_type), relay.add(i2h, h2h))
    # Split the packed pre-activations along axis 1 into four equal slices.
    slice_gates = builder.let(("slice_gates", slice_type),
                              relay.split(gates,
                                          indices_or_sections=4,
                                          axis=1).astuple())
    in_gate = builder.let(("in_gate", hidden_type),
                          relay.sigmoid(relay.TupleGetItem(slice_gates, 0)))
    forget_gate = builder.let(("forget_gate", hidden_type),
                              relay.sigmoid(relay.TupleGetItem(slice_gates, 1)))
    in_transform = builder.let(("in_transform", hidden_type),
                               relay.tanh(relay.TupleGetItem(slice_gates, 2)))
    out_gate = builder.let(("out_gate", hidden_type),
                           relay.sigmoid(relay.TupleGetItem(slice_gates, 3)))
    next_c = builder.let(("next_c", hidden_type),
                         relay.add(relay.multiply(forget_gate, old_c),
                                   relay.multiply(in_gate, in_transform)))
    # next_h is a product of hidden-shaped tensors, so it must be annotated
    # with hidden_type; annotating it with input_type only type-checks when
    # input_size == hidden_size.
    next_h = builder.let(("next_h", hidden_type),
                         relay.multiply(out_gate, relay.tanh(next_c)))
    ret = builder.let(("ret", state_type), relay.Tuple([next_h, next_c]))
    builder.ret(ret)

    return relay.Function([inputs, states, i2h_weight, i2h_bias, h2h_weight, h2h_bias],
                          builder.get())

In [22]:
# Create an empty Relay module and load the Prelude into it.
mod = relay.Module()
p = Prelude(mod)
# Pull out the three Prelude members used by the cells below: `l` is applied
# to an element type to form a list type, and `nil` / `cons` are used both to
# build list values and as constructors in match patterns.
l, nil, cons = p.l, p.nil, p.cons

In [23]:
def recursive_lstm(seq_len, input_size, hidden_size, batch_size=1, dtype="float32"):
    """Build a Relay function that runs an LSTM cell recursively over a list.

    A Prelude list holding ``seq_len`` fresh ``"data"`` variables is built and
    consumed by a let-bound recursive function: the empty list yields an
    all-zero ``(h, c)`` state, and a non-empty list applies the LSTM cell to
    the head element and the state obtained by recursing on the tail.

    Parameters
    ----------
    seq_len : int
        Number of input variables placed in the list.
    input_size : int
        Width of each input tensor.
    hidden_size : int
        Width of the hidden and cell state tensors.
    batch_size : int
        Leading dimension of the input and state tensors.
    dtype : str
        Element type of every tensor.

    Returns
    -------
    relay.Function
        Function whose parameters are the free variables of the expression
        (as reported by ``relay.analysis.free_vars``) and whose body is
        element 0 of the final state tuple.
    """
    (input_type, hidden_type, i2h_weight_type, h2h_weight_type,
     bias_type, _, _, state_type) = get_types(input_size, hidden_size, batch_size, dtype)
    seq_type = l(input_type)

    # Build the input list by prepending one new variable per step.
    input_seq = nil()
    for _ in range(seq_len):
        input_seq = cons(relay.Var("data", input_type), input_seq)

    # Cell weights, in the order the cell function expects after its first
    # two arguments (inputs, states).
    weight_args = [
        relay.Var("i2h_weight", i2h_weight_type),
        relay.Var("i2h_bias", bias_type),
        relay.Var("h2h_weight", h2h_weight_type),
        relay.Var("h2h_bias", bias_type),
    ]

    # Variables bound by the recursive function and by its cons pattern.
    current_seq = relay.Var("current_seq", seq_type)
    current_inputs = relay.Var("current_inputs", input_type)
    tail_seq = relay.Var("tail_seq", seq_type)

    rec_fn = relay.Var("rec_fn")
    cell_fn = lstm_cell(input_size, hidden_size, batch_size, dtype)

    sb = relay.ScopeBuilder()
    zeros = sb.let(("zeros", hidden_type), relay.zeros((batch_size, hidden_size), dtype))
    init_states = sb.let(("init_states", state_type), relay.Tuple([zeros, zeros]))

    # Base case: an empty list produces the zero-initialised state.
    nil_clause = relay.Clause(relay.PatternConstructor(nil), init_states)
    # Recursive case: run the cell on the head with the state from the tail.
    cons_pattern = relay.PatternConstructor(
        cons, [relay.PatternVar(current_inputs), relay.PatternVar(tail_seq)])
    step_call = relay.Call(cell_fn, [current_inputs, rec_fn(tail_seq)] + weight_args)
    cons_clause = relay.Clause(cons_pattern, step_call)

    match = sb.let(("match", state_type),
                   relay.Match(current_seq, [nil_clause, cons_clause]))
    sb.ret(match)

    func = relay.Function([current_seq], sb.get())
    ret = relay.Let(rec_fn, func, rec_fn(input_seq))
    # Only the first element of the final state tuple is returned.
    out = relay.TupleGetItem(ret, 0)
    return relay.Function(relay.analysis.free_vars(out), out)

In [24]:
# Register the recursive LSTM (seq_len=2, input_size=2, hidden_size=2) as the
# module entry point, then run type inference so that every parameter of
# `main` carries a checked type.
mod["main"] = recursive_lstm(2, 2, 2)
mod = relay.transform.InferType()(mod)

# Single device/target definition reused for both input placement and execution.
ctx = tvm.cpu(0)
target = "llvm"

# Fixed seed so the random inputs (and the printed result) are reproducible
# across Restart & Run All.
rng = np.random.RandomState(0)

# One random tensor per parameter of `main`, shaped and typed from the
# inferred parameter type.
inputs = []
for v in mod["main"].params:
    t = v.checked_type
    rand_value = rng.normal(size=t.concrete_shape).astype(t.dtype)
    inputs.append(tvm.nd.array(rand_value, ctx=ctx))

# NOTE: despite its name, `vm` is created with the 'debug' executor kind.
vm = relay.create_executor('debug', ctx=ctx, target=target, mod=mod)
result = vm.evaluate()(*inputs)
print(result.asnumpy())

TVMError: Traceback (most recent call last):
  [bt] (8) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::CompileTreeNode(std::shared_ptr<tvm::relay::TreeNode<std::shared_ptr<tvm::relay::vm::ConditionNode> > >)+0x334) [0x7ff9d1aedc04]
  [bt] (7) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::CompileTreeNode(std::shared_ptr<tvm::relay::TreeNode<std::shared_ptr<tvm::relay::vm::ConditionNode> > >)+0x505) [0x7ff9d1aeddd5]
  [bt] (6) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::CompileTreeNode(std::shared_ptr<tvm::relay::TreeNode<std::shared_ptr<tvm::relay::vm::ConditionNode> > >)+0x505) [0x7ff9d1aeddd5]
  [bt] (5) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::CompileTreeNode(std::shared_ptr<tvm::relay::TreeNode<std::shared_ptr<tvm::relay::vm::ConditionNode> > >)+0xca) [0x7ff9d1aed99a]
  [bt] (4) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0x445) [0x7ff9d198dca5]
  [bt] (3) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::VisitExpr_(tvm::relay::LetNode const*)+0x3f) [0x7ff9d1aedf3f]
  [bt] (2) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&)+0x445) [0x7ff9d198dca5]
  [bt] (1) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::VisitExpr_(tvm::relay::CallNode const*)+0xaac) [0x7ff9d1af011c]
  [bt] (0) /home/ubuntu/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7ff9d1444352]
  File "/home/ubuntu/tvm/src/relay/backend/vm/compiler.cc", line 635
TVMError: internal error: unreachable code,should be transformed away by previous passesfree_var %tail_seq: List[Tensor[(1, 2), float32]]
free_var %i2h_weight: Tensor[(8, 2), float32]
free_var %i2h_bias: Tensor[(8), float32]
free_var %h2h_weight: Tensor[(8, 2), float32]
free_var %h2h_bias: Tensor[(8), float32]
%0 = @lifted_name7877187650838540133(%i2h_weight, %i2h_bias, %h2h_weight, %h2h_bias) /* ty=fn (List[Tensor[(1, 2), float32]]) -> (Tensor[(1, 2), float32], Tensor[(1, 2), float32]) */;
%0(%tail_seq) /* ty=(Tensor[(1, 2), float32], Tensor[(1, 2), float32]) */