Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions core/conversion/converters/impl/expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
"Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");

// Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
for (int i = expandedDims.nbDims - 1; i >= 0; --i) {
for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) {
int64_t offset = expandedDims.nbDims - 1 - i;
int64_t dim = input_dims.nbDims - 1 - offset;
int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
Expand All @@ -41,10 +41,10 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor
if (num_expand_dims > 0) {
nvinfer1::Dims reshape_dims;
reshape_dims.nbDims = expandedDims.nbDims;
for (int i = 0; i < num_expand_dims; i++) {
for (int64_t i = 0; i < num_expand_dims; i++) {
reshape_dims.d[i] = 1;
}
for (int i = 0; i < input_dims.nbDims; i++) {
for (int64_t i = 0; i < input_dims.nbDims; i++) {
reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
}
// Add a reshape layer to expand dims
Expand All @@ -60,7 +60,7 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor

// Set the stride of non singleton dimension to 1
std::vector<int64_t> strides_vec(expandedDims.nbDims, 0);
for (int i = 0; i < expandedDims.nbDims; i++) {
for (int64_t i = 0; i < expandedDims.nbDims; i++) {
strides_vec[i] = (in->getDimensions().d[i] != 1);
}

Expand Down Expand Up @@ -104,16 +104,16 @@ auto expand_registrations TRTORCH_UNUSED =
auto input_dims = in->getDimensions();
auto repeats = args[1].unwrapToIntList().vec();
TRTORCH_CHECK(
repeats.size() >= input_dims.nbDims,
static_cast<int64_t>(repeats.size()) >= input_dims.nbDims,
"Number of repeat dimensions cannot be smaller than number of input dimensions");
auto num_expand_dims = repeats.size() - input_dims.nbDims;
if (num_expand_dims > 0) {
nvinfer1::Dims reshape_dims;
reshape_dims.nbDims = repeats.size();
for (int i = 0; i < num_expand_dims; i++) {
for (size_t i = 0; i < num_expand_dims; i++) {
reshape_dims.d[i] = 1;
}
for (int i = 0; i < input_dims.nbDims; i++) {
for (int64_t i = 0; i < input_dims.nbDims; i++) {
reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
}
// Add a reshape layer to expand dims
Expand All @@ -127,9 +127,9 @@ auto expand_registrations TRTORCH_UNUSED =

// Concat across all repeat axes.
// TODO: Implementation might not be performant. Explore other strategies to improve performance.
for (int i = repeats.size() - 1; i >= 0; --i) {
for (int64_t i = repeats.size() - 1; i >= 0; --i) {
std::vector<nvinfer1::ITensor*> tensors_vec;
for (int j = 0; j < repeats[i]; j++) {
for (int64_t j = 0; j < repeats[i]; j++) {
tensors_vec.push_back(in);
}
auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
Expand All @@ -139,7 +139,7 @@ auto expand_registrations TRTORCH_UNUSED =

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

LOG_DEBUG("Repeat layer output tensor shape: " << in->getDimensions());
LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());

return true;
}});
Expand Down