-
Notifications
You must be signed in to change notification settings - Fork 372
feat(//core/converters): Add conversion support for torch.narrow() #188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Can you explain more about the missing test case issue? |
For the following graph
The JIT representation obtained from print logs (of trtorchexec after graph lowering) is
The error returned by
Looks like when there is a constant tensor ({2}) instead of scalar, this issue occurs. |
Create the tensor outside in the body of the test and pass it in as an additional input to the graph instead of an inline constant |
I tried that but I faced a different error since args[2] (start) won't be an IValue, but an ITensor. So args[2].ITensor() ( |
Could we use ITensorOrFreeze? Seems like we need it as a constant anyway or can we not do the broadcasting in TensorRT? |
|
Is the Tensor guaranteed to be a 0d tensor? Also if you look at the 2nd concat test you should find a way to pass the static tensor as a argument that should not be a ITensor |
As per our discussion, here is an issue to track the missing test case. |
return true; | ||
} | ||
}).pattern({ | ||
"aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there not a specialization for this op?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah for the scalar case of start
, there is no separate specialization. It's just aten::narrow
in the IR.
auto axis = args[1].unwrapToInt(); | ||
torch::Tensor start = args[2].IValue()->toTensor().to(torch::kI32); | ||
// TODO: Is there a better way to get data from 0-dim tensor ? | ||
int startIdx = static_cast<int*>(start.data_ptr())[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use the specialized type int32_t/int64_t
vs int
where you can (some TRT stuff requires int)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done !!
auto axis = args[1].unwrapToInt(); | ||
torch::Tensor start = args[2].IValue()->toTensor().to(torch::kI32); | ||
// TODO: Is there a better way to get data from 0-dim tensor ? | ||
int startIdx = static_cast<int*>(start.data_ptr())[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you just index on the tensor itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cannot index a 0-dim tensor here, but I replaced it with an alternative as per torch recommendation.
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com> Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com> Cleaner way to access tensor data in narrow op Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
@peri044 is this good to merge, minus the test that is missing? |
Yes. This is good to merge !! The issue tracks and details about the missing test case |
Description
Adding conversion support for operator
torch.narrow(input, axis, start, length)
The variable
start
can be a scalar or tensor.The test case when
start
is a tensor hasn't been implemented in the PR as the IR for this testcase wasn't being parsed by torch jit parser (due to start being prim::constant(value={2} for example). However, using trtorchexec, I was able to test that case as well and it works.Fixes #151
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: