Support dictionary outputs in TorchScript tracer #27743

Open
f0k opened this issue Oct 11, 2019 · 6 comments

@f0k f0k commented Oct 11, 2019

🚀 Feature

torch.jit.trace() currently requires traced functions to output plain tensors or tuples. It would be helpful if it also allowed functions that return dictionaries.

Minimal failing example:

import torch
torch.jit.trace(lambda x: {'out': x}, torch.arange(3))

Motivation

Having modules take and return dictionaries makes it easier to define architectures with multiple inputs and outputs that are not always guaranteed to be present (e.g., for multi-task learning). Currently, such modules cannot be traced, e.g., for visualizing their structure.

Pitch

Functions returning dictionaries should be handled similarly to (named) tuples. For an OrderedDict, values can be taken in their stored order. For a general dict, the order should probably be made deterministic by sorting the keys, so it stays consistent across re-running the interpreter.
(Disclaimer: I wanted to use this to add a graph to tensorboard; I don't know what the implications of this suggestion would be for TorchScript traces.)
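
To make the ordering rule concrete, here is a minimal sketch in plain Python. `flatten_outputs` is a hypothetical helper name, not part of PyTorch, and it assumes that the keys of a general dict are mutually comparable (e.g., all strings) so that they can be sorted.

```python
from collections import OrderedDict


def flatten_outputs(outputs):
    """Return (keys, values) for a dict of outputs in a deterministic order.

    Hypothetical helper illustrating the rule proposed above; it is not
    an existing PyTorch function.
    """
    if isinstance(outputs, OrderedDict):
        # An OrderedDict already defines an order: keep it as stored.
        keys = list(outputs.keys())
    else:
        # For a general dict, sort the keys so the order does not depend
        # on how the dict happened to be constructed.
        keys = sorted(outputs.keys())
    return keys, tuple(outputs[k] for k in keys)
```

The returned keys would let the caller rebuild the dictionary from the flat tuple of values afterwards.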

Alternatives

The alternative would be for me to wrap the module/function in a function that converts the dictionary to a (named?) tuple before returning it, just for the torch.jit.trace() call. This would not give me a trace that creates a dictionary, but at least a trace useful for visualization.
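
A rough sketch of that workaround, assuming the output keys are known up front. `with_tuple_output` is a hypothetical name introduced only for illustration; the wrapper itself is plain Python, and the `torch.jit.trace()` call is shown only as a comment.

```python
def with_tuple_output(fn, keys):
    """Wrap `fn`, which returns a dict, so that it returns a plain tuple.

    `keys` fixes which entries are returned and in which order.
    Hypothetical helper for illustration only.
    """
    def wrapper(*args, **kwargs):
        outputs = fn(*args, **kwargs)
        # Pick the values in the caller-specified order.
        return tuple(outputs[k] for k in keys)
    return wrapper


# Intended use (requires PyTorch, not executed here):
#   traced = torch.jit.trace(with_tuple_output(fn, ['out']), torch.arange(3))
```

The resulting trace returns a tuple rather than a dictionary, so this only helps where the graph structure matters more than the output type.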

Additional context

This was part of issue #16453, which also asked for dictionary inputs, which are possible by now:

import torch
torch.jit.trace(lambda x: x['input'], {'input': torch.arange(3)})

cc @suo


@eellison eellison commented Oct 14, 2019

@wanchaol, you were working on supporting named tuple outputs, right? That would cover this use case.


@thuyen thuyen commented Nov 26, 2019

@eellison: will this be part of the 1.4 release?


@wanchaol wanchaol commented Nov 26, 2019

@thuyen you can already pass a named tuple as the output of the tracer; it's just that the result of the traced function will be a plain tuple.

#29751 will enable the tracer to take a named tuple as output and preserve the names. But there are ongoing discussions/issues on that PR, so it will not go into the upcoming 1.4 release candidate. It will go into the next release after the 1.4 RC.


@thuyen thuyen commented Nov 26, 2019

@wanchaol: thanks, I care more about tracing dictionary output (as mentioned in the issue title). Does #29751 support that as well?


@wanchaol wanchaol commented Nov 26, 2019

> @wanchaol: thanks, I care more about tracing dictionary output (as mentioned in the issues title). Does #29751 support that as well?

No, supporting dictionaries as output will be a separate PR. I can spend some time investigating how to do it after I merge #29751.


@thuyen thuyen commented Nov 26, 2019

Got it. Thanks!
