❓ [Question] Is there support for optional arguments in model's forward()?
#772
Comments
@lhai37 I don't think we support optional tensors at the moment. cc @narendasan. We expect the inputs and outputs of a module to be torch::Tensors. Can you share what your TorchScript model looks like? I tried to convert the following to TS but
@peri044 Your code fails because you are doing
This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
This is not currently supported; it requires the next phase of the collections feature (#629). The issue is that we need to be able to generate the TorchScript code that maps function inputs to TensorRT inputs when potentially any arbitrary input could be None.
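Until that mapping can be generated automatically, one possible workaround is to keep the None check in ordinary Python and route each call to one of two callables that each have a fixed, all-tensor signature. The sketch below only illustrates that idea and is not something proposed in this thread: OptionalArgDispatcher, add_xy, and identity are hypothetical names, and the two plain functions stand in for what would, in practice, be two separately compiled modules (one taking x and y, one taking only x).

```python
from typing import Any, Callable, Optional


class OptionalArgDispatcher:
    """Route a call to one of two fixed-signature callables.

    Hypothetical helper for illustration; not part of Torch-TensorRT.
    """

    def __init__(
        self,
        with_y: Callable[[Any, Any], Any],
        without_y: Callable[[Any], Any],
    ) -> None:
        self.with_y = with_y
        self.without_y = without_y

    def __call__(self, x: Any, y: Optional[Any] = None) -> Any:
        # The None check runs in Python, so neither underlying callable
        # ever has to accept an optional argument.
        if y is None:
            return self.without_y(x)
        return self.with_y(x, y)


# Stand-ins (assumptions) for two separately compiled modules.
def add_xy(x, y):
    return x + y


def identity(x):
    return x


model = OptionalArgDispatcher(with_y=add_xy, without_y=identity)
```

Calling model(x, y) uses the two-input path, while model(x) or model(x, None) uses the single-input path. The trade-off is that each combination of present and absent inputs needs its own compiled variant, so this only scales to a small number of optional arguments.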
❓ Question
Is there support for optional arguments in a model's forward()? For example, I have the following: def forward(self, x, y: Optional[Tensor] = None): where y is an optional tensor. The return result is x + y if y is provided, otherwise just x.

What you have already tried
I added a second torch_tensorrt.Input() in the input spec, then at inference time got the error: Expected dimension specifications for all input tensors, but found 1 input tensors and 2 dimension specs

I then removed the Optional annotation and just passed in None or the actual tensor for y. When None is passed in, I got the error: RuntimeError: forward() Expected a value of type 'Tensor' for argument 'input_1' but instead found type 'NoneType'.

I also tried passing in just 1 argument for x, and got: RuntimeError: forward() is missing value for argument 'input_1'

Environment
Install method (conda, pip, libtorch, source): pip

Additional context