diff --git a/core/conversion/converters/impl/lstm_cell.cpp b/core/conversion/converters/impl/lstm_cell.cpp index edb8388dd0..1b205937da 100755 --- a/core/conversion/converters/impl/lstm_cell.cpp +++ b/core/conversion/converters/impl/lstm_cell.cpp @@ -14,6 +14,29 @@ namespace converters { namespace impl { namespace { +nvinfer1::ITensor* add_bias(nvinfer1::ITensor* a, nvinfer1::ITensor* b, std::string b_name, ConversionCtx* ctx, const torch::jit::Node* n) { + auto a_dim = a->getDimensions(); + auto b_dim = b->getDimensions(); + + LOG_DEBUG(b_name << " tensor shape: " << b_dim); + + TRTORCH_CHECK(util::broadcastable(a_dim, b_dim, false), "bias " << b_name << " is not broadcastable - can't be added to previous matmul operation."); + + if (util::toVec(a_dim) != util::toVec(b_dim)) { + LOG_DEBUG(b_name << "'s dimensions need to be reshaped"); + + auto shuffle = ctx->net->addShuffle(*b); + TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n); + shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_dim), a_dim.nbDims)); + b = shuffle->getOutput(0); + } + + auto add = ctx->net->addElementWise(*a, *b, nvinfer1::ElementWiseOperation::kSUM); + TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n); + + return add->getOutput(0); +} + auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() .pattern({ "aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)", @@ -21,15 +44,11 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() auto input = args[0].ITensorOrFreeze(ctx); auto w_ih = args[2].ITensorOrFreeze(ctx); auto w_hh = args[3].ITensorOrFreeze(ctx); - auto b_ih = args[4].ITensorOrFreeze(ctx); - auto b_hh = args[5].ITensorOrFreeze(ctx); LOG_DEBUG("Input tensor shape: " << input->getDimensions()); LOG_DEBUG("w_ih tensor shape: " << w_ih->getDimensions()); LOG_DEBUG("w_hh tensor shape: " << w_hh->getDimensions()); - LOG_DEBUG("b_ih tensor shape: " << b_ih->getDimensions()); - LOG_DEBUG("b_hh tensor shape: " << b_hh->getDimensions()); - + std::vector state; auto hx = args[1].IValue()->toListRef(); for (unsigned int i = 0; i < hx.size(); i++) { @@ -51,81 +70,56 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() // calculate first half of gates auto mm1 = ctx->net->addMatrixMultiply(*input, nvinfer1::MatrixOperation::kNONE, *w_ih, nvinfer1::MatrixOperation::kTRANSPOSE); TRTORCH_CHECK(mm1, "Unable to create matrix multiplication node: " << *n); - auto mm1_out = mm1->getOutput(0); - auto mm1_dim = mm1_out->getDimensions(); - auto b_ih_dim = b_ih->getDimensions(); - - TRTORCH_CHECK(util::broadcastable(mm1_dim, b_ih_dim, false)); - if (util::toVec(mm1_dim) != util::toVec(b_ih_dim)) { - LOG_DEBUG("b_ih dimensions need to be reshaped"); - - auto shuffle = ctx->net->addShuffle(*b_ih); - TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n); - shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_ih_dim), mm1_dim.nbDims)); - b_ih = shuffle->getOutput(0); - } - - auto add1 = ctx->net->addElementWise(*mm1_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM); - TRTORCH_CHECK(add1, "Unable to create ElementWise layer from node: " << *n); - auto add1_out = add2->getOutput(0); + auto out1 = !args[4].IValue()->isNone() ? add_bias(mm1_out, args[4].ITensorOrFreeze(ctx), "b_ih", ctx, n) : mm1_out; // calculate second half of gates - auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOE); + auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOSE); TRTORCH_CHECK(mm2, "Unable to create matrix multiplication node: " << *n); - auto mm2_out = mm2->getOutput(0); - auto mm2_dim = mm2_out->getDimensions(); - auto b_hh_dim = b_hh->getDimensions(); - - TRTORCH_CHECK(util::broadcastable(mm2_dim, b_hh_dim, false)); - if (util::toVec(mm2_dim) != util::toVec(b_hh_dim)) { - LOG_DEBUG("b_hh dimensions need to be reshaped"); - - auto shuffle = ctx->net->addShuffle(*b_hh); - TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n); - shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_hh_dim), mm2_dim.nbDims)); - b_hh = shuffle->getOutput(0); - } - - auto add2 = ctx->net->addElementWise(*mm2_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM); - TRTORCH_CHECK(add2, "Unable to create ElementWise layer from node: " << *n); - auto add2_out = add2->getOutput(0); + auto out2 = !args[5].IValue()->isNone() ? add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n) : mm2_out; // gates - auto add3 = ctx->net->addElementWise(*add1_out, *add2_out, nvinfer1::ElementWiseOperation::kSUM); - TRTORCH_CHECK(add3, "Unable to create ElementWise layer from node: " << *n); - auto add3_out = add3->getOutput(0); + auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM); + TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n); + auto add_out = add->getOutput(0); // chunk Tensor into 4 parts and apply activation functions - auto dims = util::toVec(add3_out->getDimensions()); + auto dims = util::toVec(add_out->getDimensions()); auto batch = dims[0]; auto hidden = dims[1]/4; - auto size = util::toDims(std::vector({batch, hidden})); - auto stride = util::toDims(std::vector({1, 1})); + std::vector size_vec = {batch, hidden}; + std::vector stride_vec = {1, 1}; + std::vector offset0 = {0, 0}; + std::vector offset1 = {0, hidden}; + std::vector offset2 = {0, 2*hidden}; + std::vector offset3 = {0, 3*hidden}; + + auto size = util::toDims(size_vec); + auto stride = util::toDims(stride_vec); - auto slice1 = ctx->net->addSlice(*add3_out, util::toDims(std::vector({0, 0})), size, stride); + auto slice1 = ctx->net->addSlice(*add_out, util::toDims(offset0), size, stride); TRTORCH_CHECK(slice1, "Unable to create Slice layer from node: " << *n); auto activ1 = ctx->net->addActivation(*slice1->getOutput(0), nvinfer1::ActivationType::kSIGMOID); TRTORCH_CHECK(activ1, "Unable to create sigmoid activation layer from node: " << *n); auto ingate = activ1->getOutput(0); - auto slice2 = ctx->net->addSlice(*add3_out, util::toDims(std::vector({0, hidden})), size, stride); + auto slice2 = ctx->net->addSlice(*add_out, util::toDims(offset1), size, stride); TRTORCH_CHECK(slice2, "Unable to create Slice layer from node: " << *n); auto activ2 = ctx->net->addActivation(*slice2->getOutput(0), nvinfer1::ActivationType::kSIGMOID); TRTORCH_CHECK(activ2, "Unable to create sigmoid activation layer from node: " << *n); auto forgetgate = activ2->getOutput(0); - auto slice3 = ctx->net->addSlice(*add3_out, util::toDims(std::vector({0, 2*hidden})), size, stride); + auto slice3 = ctx->net->addSlice(*add_out, util::toDims(offset2), size, stride); TRTORCH_CHECK(slice3, "Unable to create Slice layer from node: " << *n); auto activ3 = ctx->net->addActivation(*slice3->getOutput(0), nvinfer1::ActivationType::kTANH); TRTORCH_CHECK(activ3, "Unable to create tanh activation layer from node: " << *n); auto cellgate = activ3->getOutput(0); - auto slice4 = ctx->net->addSlice(*add3_out, util::toDims(std::vector({0, 3*hidden})), size, stride); + auto slice4 = ctx->net->addSlice(*add_out, util::toDims(offset3), size, stride); TRTORCH_CHECK(slice4, "Unable to create Slice layer from node: " << *n); auto activ4 = ctx->net->addActivation(*slice4->getOutput(0), nvinfer1::ActivationType::kSIGMOID); TRTORCH_CHECK(activ4, "Unable to create sigmoid activation layer from node: " << *n);