In [32]:
from tvm import relay
from tvm.relay.prelude import Prelude
from tvm.relay.testing import layers, create_workload

In [33]:
def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
    builder = relay.ScopeBuilder()
    
    input_type = relay.TensorType((batch_size, num_hidden), dtype)
    weight_type = relay.TensorType((4*num_hidden, num_hidden), dtype)
    bias_type = relay.TensorType((4*num_hidden,), dtype)
    
    dense_type = relay.TensorType((batch_size, 4*num_hidden), dtype)
    slice_type = relay.TupleType([input_type, input_type,
                                  input_type, input_type])
    ret_type = relay.TupleType([input_type,
                                relay.TupleType([input_type, input_type])])
    
    inputs = relay.Var("inputs", input_type)
    states = relay.Var("states", 
                       relay.TupleType([input_type, input_type]))
    
    i2h_weight = relay.Var("i2h_weight", weight_type)
    i2h_bias = relay.Var("i2h_bias", bias_type)
    
    h2h_weight = relay.Var("h2h_weight", weight_type)
    h2h_bias = relay.Var("h2h_bias", bias_type)
    
    i2h = builder.let(("i2h", dense_type),
                      layers.dense_add_bias(
                          data=inputs,
                          units=num_hidden * 4,
                          weight=i2h_weight, bias=i2h_bias,
                          name="%si2h" % name))
    h2h = builder.let(("h2h", dense_type),
                      layers.dense_add_bias(
                          data=relay.TupleGetItem(states, 0),
                          units=num_hidden * 4,
                          name="%sh2h" % name))
    gates = builder.let(("gates", dense_type), relay.add(i2h, h2h))
    slice_gates = builder.let(("slice_gates", slice_type),
                              relay.split(gates,
                                          indices_or_sections=4,
                                          axis=1).astuple())
    
    in_gate = builder.let(("in_gate", input_type),
                          relay.sigmoid(relay.TupleGetItem(slice_gates, 0)))
    forget_gate = builder.let(("forget_gate", input_type),
                              relay.sigmoid(relay.TupleGetItem(slice_gates, 1)))
    in_transform = builder.let(("in_transform", input_type),
                               relay.tanh(relay.TupleGetItem(slice_gates, 2)))
    out_gate = builder.let(("out_gate", input_type),
                           relay.sigmoid(relay.TupleGetItem(slice_gates, 3)))
    
    next_c = builder.let(("next_c", input_type),
                         relay.add(relay.multiply(forget_gate,
                                                  relay.TupleGetItem(states, 1)),
                                   relay.multiply(in_gate, in_transform)))
    next_h = builder.let(("next_h", input_type),
                         relay.multiply(out_gate, relay.tanh(next_c)))
    ret = builder.let(("ret", ret_type),
                      relay.Tuple([next_h, relay.Tuple([next_h, next_c])]))
    builder.ret(ret)

    body = builder.get()

    return relay.Function([inputs, states, i2h_weight,
                           i2h_bias, h2h_weight, h2h_bias],
                          body, ret_type)

In [34]:
def get_types(num_hidden, batch_size, dtype):
    input_type = relay.TensorType((batch_size, num_hidden), dtype)
    weight_type = relay.TensorType((4*num_hidden, num_hidden), dtype)
    bias_type = relay.TensorType((4*num_hidden,), dtype)
    state_type = relay.TupleType([input_type, input_type])
    cell_type = relay.TupleType([input_type, state_type])
    return input_type, weight_type, bias_type, state_type, cell_type

In [35]:
def unroll_lstm(iterations, num_hidden, batch_size=1, dtype="float32"):
    input_type, weight_type, bias_type, state_type, cell_type = get_types(num_hidden, batch_size, dtype)
    
    builder = relay.ScopeBuilder()

    zeros = builder.let(("zeros", input_type),
                        relay.zeros((batch_size, num_hidden), dtype))
    init_states = builder.let(("init_states", state_type),
                              relay.Tuple([zeros, zeros]))

    states = init_states
    out = None

    for i in range(iterations):
        inputs = relay.Var("data", input_type)
        i2h_weight = relay.Var("i2h_%s_weight" % i, weight_type)
        i2h_bias = relay.Var("i2h_%i_bias" % i, bias_type)
        h2h_weight = relay.Var("h2h_%s_weight" % i, weight_type)
        h2h_bias = relay.Var("h2h_%s_bias" % i, bias_type)

        cell_fn = lstm_cell(num_hidden, batch_size, dtype, "lstm_%s" % i)

        call = builder.let(("call_%s" % i, cell_type),
                           relay.Call(cell_fn,
                                      [inputs, states, i2h_weight,
                                       i2h_bias, h2h_weight, h2h_bias]))
        new_out = builder.let(("out_%s" % i, input_type),
                              relay.TupleGetItem(call, 0))
        new_states = builder.let(("states_%s" % i, state_type),
                                 relay.TupleGetItem(call, 1))
        states = new_states
        out = new_out

    builder.ret(out)
    body = builder.get()
    args = relay.analysis.free_vars(body)
    return relay.Function(args, body, input_type)

In [36]:
mod = relay.Module()
p = Prelude(mod)
l = p.l
nil = p.nil
cons = p.cons
hd = p.hd
tl = p.tl
nth = p.nth
# x = relay.Var('x', l(relay.TensorType(shape=(), dtype='float32')))
# print(x)

In [37]:
def get_cell_builder(num_hidden, batch_size, dtype, args, states=None, previous_list=None, lstm_rec=None):
    input_type, weight_type, bias_type, state_type, cell_type = get_types(num_hidden, batch_size, dtype)
    cell_fn = lstm_cell(num_hidden, batch_size, dtype, "lstm_cell")
    builder = relay.ScopeBuilder()
    if states == None:
        assert(previous_list != None and lstm_rec != None)
        previous_res = builder.let(("previous_res", cell_type), lstm_rec(previous_list))
        states = builder.let(("states", state_type), relay.TupleGetItem(previous_res, 1))
    inputs = builder.let(("data", input_type), relay.TupleGetItem(args, 0))
    i2h_weight = builder.let(("i2h_weight", weight_type), relay.TupleGetItem(args, 1))
    i2h_bias = builder.let(("i2h_bias", bias_type), relay.TupleGetItem(args, 2))
    h2h_weight = builder.let(("h2h_weight", weight_type), relay.TupleGetItem(args, 3))
    h2h_bias = builder.let(("h2h_bias", bias_type), relay.TupleGetItem(args, 4))
    res = builder.let(("res", cell_type), 
                      relay.Call(cell_fn, [inputs, states, i2h_weight, i2h_bias, h2h_weight, h2h_bias]))
    builder.ret(res)
    return builder

In [38]:
def recursive_lstm(iterations, num_hidden, batch_size=1, dtype="float32"):
    input_type, weight_type, bias_type, state_type, cell_type = get_types(num_hidden, batch_size, dtype)
    args_type = relay.TupleType([input_type, weight_type, bias_type, weight_type, bias_type])
    
#     args_type = relay.TensorType((2, 2), "float32")
    
    args_list = nil()
    for i in range(iterations):
        inputs = relay.Var("data", input_type)
        i2h_weight = relay.Var("i2h_%s_weight" % i, weight_type)
        i2h_bias = relay.Var("i2h_%i_bias" % i, bias_type)
        h2h_weight = relay.Var("h2h_%s_weight" % i, weight_type)
        h2h_bias = relay.Var("h2h_%s_bias" % i, bias_type)
        args = relay.Tuple([inputs, i2h_weight, i2h_bias, h2h_weight, h2h_bias])
        args_list = cons(args, args_list)

#     for i in range(iterations):
#         args = relay.Var("data", args_type)
#         args_list = cons(args, args_list)
    
    builder = relay.scope_builder.ScopeBuilder()
    
    lstm_rec = relay.Var("lstm_rec")
    
    zeros = builder.let(("zeros", input_type), relay.zeros((batch_size, num_hidden), dtype))
    init_states = builder.let(("init_states", state_type), relay.Tuple([zeros, zeros]))
    first_args = relay.Var("first_args", args_type)
    first_br = get_cell_builder(num_hidden, batch_size, dtype, first_args, states=init_states)
    
    previous_list = relay.Var("previous_list", l(args_type))
    current_args = relay.Var("current_args", args_type)
    current_br = get_cell_builder(num_hidden, batch_size, dtype, current_args, 
                                  previous_list=previous_list, lstm_rec=lstm_rec)
    
    current_list = relay.Var("current_list", l(args_type))
    match = builder.let(("match", cell_type), 
        relay.Match(
            current_list,
            [relay.Clause(relay.PatternConstructor(cons, 
                                                   [relay.PatternVar(first_args), relay.PatternConstructor(nil)]), 
                          first_br.get()
                         ),
             relay.Clause(relay.PatternConstructor(cons, 
                                                   [relay.PatternVar(current_args), relay.PatternVar(previous_list)]), 
                          current_br.get()
                         )
            ],
            complete=False))
    
#     current_list = relay.Var("current_list", l(args_type))
#     match = builder.let(("match", relay.TensorType((), "int32")), 
#         relay.Match(
#             current_list,
#             [relay.Clause(relay.PatternConstructor(cons, 
#                                                    [relay.PatternVar(first_args), relay.PatternConstructor(nil)]), 
#                           relay.const(1, dtype="int32")
#                          ),
#              relay.Clause(relay.PatternConstructor(cons, 
#                                                    [relay.PatternVar(current_args), relay.PatternVar(previous_list)]), 
#                           relay.const(1, dtype="int32") + lstm_rec(previous_list)
#                          )
#             ],
#             complete=False))

    builder.ret(match)
    func = relay.Function([current_list], builder.get())
    ret = relay.Let(lstm_rec, func, lstm_rec(args_list))
    out = relay.TupleGetItem(ret, 0)
    args = relay.analysis.free_vars(out)
    
#     builder.ret(match)
#     func = relay.Function([current_list], builder.get())
#     ret = relay.Let(lstm_rec, func, lstm_rec(args_list))
#     args = relay.analysis.free_vars(ret)

    return relay.Function(args, out)

In [39]:
# ll = mod.get_global_type_var("List")
# print(ll)
# print(mod.get_global_type_vars())
mod["main"] = recursive_lstm(2, 2)
print(mod["main"])
mod = relay.transform.InferType()(mod)
print(mod["main"].params)

v0.0.4
fn (%lstm_cellh2h_weight: Tensor[(8, 2), float32], %lstm_cellh2h_bias: Tensor[(8), float32], %lstm_cellh2h_weight1: Tensor[(8, 2), float32], %lstm_cellh2h_bias1: Tensor[(8), float32], %data: Tensor[(1, 2), float32], %i2h_1_weight: Tensor[(8, 2), float32], %i2h_1_bias: Tensor[(8), float32], %h2h_1_weight: Tensor[(8, 2), float32], %h2h_1_bias: Tensor[(8), float32], %data1: Tensor[(1, 2), float32], %i2h_0_weight: Tensor[(8, 2), float32], %i2h_0_bias: Tensor[(8), float32], %h2h_0_weight: Tensor[(8, 2), float32], %h2h_0_bias: Tensor[(8), float32]) -> Tensor[(1, 2), float32] {
  %31 = (
    let %lstm_rec-malformed-ir = fn (%current_list: List[(Tensor[(1, 2), float32], Tensor[(8, 2), float32], Tensor[(8), float32], Tensor[(8, 2), float32], Tensor[(8), float32])]) -> (Tensor[(1, 2), float32], (Tensor[(1, 2), float32], Tensor[(1, 2), float32])) {
      let %zeros: Tensor[(1, 2), float32] = zeros(shape=[1, 2], dtype="float32") /* ty=Tensor[(1, 2), float32] */;
      let %init_states: (Ten

In [28]:
def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"):
    net =  unroll_lstm(iterations, num_hidden, batch_size, dtype)
    return create_workload(net)

In [29]:
iterations = 2
num_hidden = 2
net, params = get_workload(iterations, num_hidden)

In [31]:
print(net)
print(net["main"].params)
print(params)

v0.0.4
def @main(%lstm_0h2h_weight: Tensor[(8, 2), float32], %lstm_0h2h_bias: Tensor[(8), float32], %data: Tensor[(1, 2), float32], %i2h_0_weight: Tensor[(8, 2), float32], %i2h_0_bias: Tensor[(8), float32], %h2h_0_weight: Tensor[(8, 2), float32], %h2h_0_bias: Tensor[(8), float32], %lstm_1h2h_weight: Tensor[(8, 2), float32], %lstm_1h2h_bias: Tensor[(8), float32], %data1: Tensor[(1, 2), float32], %i2h_1_weight: Tensor[(8, 2), float32], %i2h_1_bias: Tensor[(8), float32], %h2h_1_weight: Tensor[(8, 2), float32], %h2h_1_bias: Tensor[(8), float32]) -> Tensor[(1, 2), float32] {
  let %zeros: Tensor[(1, 2), float32] = zeros(shape=[1, 2], dtype="float32") /* ty=Tensor[(1, 2), float32] */;
  let %init_states: (Tensor[(1, 2), float32], Tensor[(1, 2), float32]) = (%zeros, %zeros);
  %12 = fn (%inputs: Tensor[(1, 2), float32], %states: (Tensor[(1, 2), float32], Tensor[(1, 2), float32]), %i2h_weight: Tensor[(8, 2), float32], %i2h_bias: Tensor[(8), float32], %h2h_weight: Tensor[(8, 2), float32], %h2h_

In [1]:
print(net)

NameError: name 'net' is not defined

In [11]:
a = 1
b = 2
assert(a == None and b != None)

AssertionError: 