feat(//core/converters): Add conversion support for torch.narrow() #188


Merged 1 commit on Oct 22, 2020

Conversation

@peri044 (Collaborator) commented Oct 7, 2020

Description

Adding conversion support for operator torch.narrow(input, axis, start, length)
The variable start can be a scalar or tensor.
The test case where start is a tensor hasn't been implemented in this PR, because the IR for that case wasn't being parsed by the torch JIT parser (start becomes, e.g., prim::Constant[value={2}]). However, using trtorchexec I was able to verify that this case works as well.
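For reference, narrow's slicing semantics can be sketched in plain Python on nested lists (illustrative only, not TRTorch code; the helper name is made up for this example):

```python
def narrow(data, dim, start, length):
    """Mimic torch.narrow on nested lists: keep `length` elements
    starting at `start` along dimension `dim` (0 or 1 in this 2-D sketch)."""
    if dim == 0:
        return data[start:start + length]
    # dim == 1: slice every row
    return [row[start:start + length] for row in data]

x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
# equivalent of torch.narrow(x, 1, 2, 2): keep columns 2 and 3
print(narrow(x, 1, 2, 2))  # [[3, 4], [7, 8]]
```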

Fixes #151

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes

@peri044 peri044 changed the title Add conversion support for torch.narrow() feat(//core/converters): Add conversion support for torch.narrow() Oct 7, 2020
@narendasan (Collaborator)

Can you explain more about the missing test case issue?

@peri044 (Collaborator, Author) commented Oct 7, 2020

For the following graph

class Narrow(torch.nn.Module):
    def __init__(self):
        super(Narrow, self).__init__()

    def forward(self, x):
        # start (torch.tensor(2)) is a 0-dim tensor here, which ends up
        # as an inline prim::Constant in the lowered graph
        return torch.narrow(x, 1, torch.tensor(2), 2)

The JIT representation obtained from print logs (of trtorchexec after graph lowering) is

graph(%x.1 : Tensor):
              %2 : Long() = prim::Constant[value={2}]()
              %3 : int = prim::Constant[value=1]() # narrow.py:13:31
              %4 : int = prim::Constant[value=2]() # narrow.py:13:47
              %5 : Tensor = aten::narrow(%x.1, %3, %2, %4) # narrow.py:13:15
              return (%5)

The error returned by torch::jit::parseIR(graph, &*g) in the test case is as follows

C++ exception with description "
Could not parse literal{:

      graph(%x.1 : Tensor):
              %2 : Long() = prim::Constant[value={2}]()
                                                 ~ <--- HERE

It looks like this issue occurs when the graph contains a constant tensor ({2}) instead of a scalar.

@narendasan (Collaborator) commented Oct 7, 2020

Create the tensor in the body of the test and pass it in as an additional input to the graph instead of an inline constant.

@peri044 (Collaborator, Author) commented Oct 7, 2020

I tried that, but I hit a different error: args[2] (start) is then an ITensor rather than an IValue, so args[2].ITensor() (an nvinfer1::ITensor instance) would need to be converted to a scalar integer. https://github.com/NVIDIA/TRTorch/pull/188/files#diff-9d97a3640ddc39b7d6f7cf3e769dfa4dR91
The startIdx must be a scalar integer so we can compute the indices to pass to the gather layer.
If we use args[2].ITensor(), we'd need to change the converter code a bit, but I'm not sure that would be general enough for the common use case.

@narendasan (Collaborator) commented Oct 7, 2020

Could we use ITensorOrFreeze? It seems like we need it as a constant anyway; or can we not do the broadcasting in TensorRT?

@peri044 (Collaborator, Author) commented Oct 7, 2020

ITensorOrFreeze returns a constant ITensor. What I need is an integer start value that can be passed to torch::arange to generate a torch tensor of indices running from start to start + length. That indices tensor is then converted to a constant TensorRT weight tensor (nvinfer1::ITensor) via converter::Weights.
We could implement equivalent torch::arange logic directly in TRT, but the torch approach felt simpler.
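The arange-then-gather flow described above can be sketched in plain Python (illustrative only, with a made-up helper name; the actual converter builds the indices with torch::arange and hands them to a TensorRT gather layer as constant weights):

```python
def narrow_via_gather(data, dim, start, length):
    """Illustrative analogue of the converter's approach: materialize an
    explicit index list (like torch::arange(start, start + length)) and
    gather along `dim` (0 or 1 in this 2-D sketch)."""
    indices = list(range(start, start + length))
    if dim == 0:
        return [data[i] for i in indices]
    return [[row[i] for i in indices] for row in data]

x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
print(narrow_via_gather(x, 1, 2, 2))  # [[3, 4], [7, 8]]
```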

@narendasan (Collaborator)

Is the tensor guaranteed to be a 0-dim tensor? Also, if you look at the 2nd concat test, you should find a way to pass the static tensor in as an argument without it becoming an ITensor.

@peri044 (Collaborator, Author) commented Oct 9, 2020

As per our discussion, here is an issue to track the missing test case.

return true;
}
}).pattern({
"aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)",
Collaborator:

Is there not a specialization for this op?

Collaborator Author:

Yeah, for the scalar case of start there is no separate specialization; it's just aten::narrow in the IR.

auto axis = args[1].unwrapToInt();
torch::Tensor start = args[2].IValue()->toTensor().to(torch::kI32);
// TODO: Is there a better way to get data from 0-dim tensor ?
int startIdx = static_cast<int*>(start.data_ptr())[0];
Collaborator:

Use the specialized types int32_t/int64_t instead of int where you can (some TensorRT APIs require int).

Collaborator Author:

Done!

auto axis = args[1].unwrapToInt();
torch::Tensor start = args[2].IValue()->toTensor().to(torch::kI32);
// TODO: Is there a better way to get data from 0-dim tensor ?
int startIdx = static_cast<int*>(start.data_ptr())[0];
Collaborator:

Can you just index on the tensor itself?

Collaborator Author:

You can't index a 0-dim tensor here, but I replaced it with an alternative as per the torch recommendation.
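The usual torch recommendation for reading a 0-dim tensor is Tensor.item() rather than casting data_ptr (in the C++ API, presumably start.item&lt;int32_t&gt;() — an assumption about libtorch, not quoted from this PR's diff). numpy's 0-dim .item() behaves the same way and serves as a runnable stand-in:

```python
import numpy as np

start = np.array(2)       # stand-in for a 0-dim tensor holding the start index
start_idx = start.item()  # extract the scalar without raw data_ptr casts
print(start_idx)          # 2
```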


Cleaner way to access tensor data in narrow op

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
@narendasan (Collaborator)

@peri044 is this good to merge, minus the test that is missing?

@peri044 (Collaborator, Author) commented Oct 22, 2020

Yes, this is good to merge! The issue tracks the details of the missing test case.

@narendasan narendasan merged commit 5e1b842 into pytorch:master Oct 22, 2020
Successfully merging this pull request may close these issues.

Add support for aten::narrow