Description
🚀 Feature
torch.jit.trace() currently requires traced functions to return plain tensors or tuples. It would be helpful if it also allowed functions that return dictionaries.
Minimal failing example:

```python
import torch
torch.jit.trace(lambda x: {'out': x}, torch.arange(3))
```
Motivation
Having modules take and return dictionaries makes it easier to define architectures with multiple inputs and outputs that are not always guaranteed to be present (e.g., for multi-task learning). Currently, such modules cannot be traced, e.g., for visualizing their structure.
Pitch
Functions returning dictionaries should be handled similarly to (named) tuples. For an OrderedDict, values can be taken in their stored order; for a general dict, the order should probably be made deterministic by sorting the keys, so it stays consistent across interpreter runs.
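To illustrate the proposed ordering rule, here is a minimal sketch (the helper name output_order is made up for illustration; it is not part of any PyTorch API):

```python
from collections import OrderedDict

def output_order(d):
    # OrderedDict: trust the stored order the user chose.
    if isinstance(d, OrderedDict):
        return list(d.keys())
    # Plain dict: sort keys so the flattened output layout is
    # deterministic across interpreter runs.
    return sorted(d.keys())

print(output_order(OrderedDict([('b', 1), ('a', 2)])))  # ['b', 'a']
print(output_order({'b': 1, 'a': 2}))                   # ['a', 'b']
```

(On Python 3.7+ plain dicts preserve insertion order too, but sorting keeps the layout independent of construction order.)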
(Disclaimer: I wanted to use this to add a graph to tensorboard; I don't know what the implications of this suggestion would be for TorchScript traces.)
Alternatives
The alternative would be for me to wrap the module/function in a function that converts the dictionary to a (named?) tuple before returning it, just for the torch.jit.trace() call. This would not give me a trace that creates a dictionary, but at least one that is useful for visualization.
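A minimal sketch of that workaround, assuming a hypothetical model function with dictionary output (the names model and wrapped are illustrative):

```python
import torch

# Hypothetical function returning a dict; not traceable directly.
def model(x):
    return {'out': x * 2, 'aux': x + 1}

# Wrapper that flattens the dict into a plain tuple, using sorted
# key order so the output layout is deterministic.
def wrapped(x):
    d = model(x)
    return tuple(d[k] for k in sorted(d))

traced = torch.jit.trace(wrapped, torch.arange(3))
out = traced(torch.arange(5))  # tuple in key order: (aux, out)
```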
Additional context
This was part of issue #16453, which also asked for dictionary inputs, which are possible by now:

```python
import torch
torch.jit.trace(lambda x: x['input'], {'input': torch.arange(3)})
```
cc @suo