
Support dictionary outputs in TorchScript tracer #27743

Closed
@f0k

Description


🚀 Feature

torch.jit.trace() currently requires traced functions to output plain tensors or tuples. It would be helpful if it also allowed functions that return dictionaries.

Minimal failing example:

```python
import torch
torch.jit.trace(lambda x: {'out': x}, torch.arange(3))
```

Motivation

Having modules take and return dictionaries makes it easier to define architectures with multiple inputs and outputs that are not always guaranteed to be present (e.g., for multi-task learning). Currently, such modules cannot be traced, e.g., for visualizing their structure.

Pitch

Functions returning dictionaries should be handled similarly to (named) tuples. For an OrderedDict, the values can be taken in their stored order. For a general dict, the order should probably be made deterministic by sorting the keys, so that it stays the same across interpreter runs.
(Disclaimer: I wanted to use this to add a graph to TensorBoard; I don't know what the implications of this suggestion would be for TorchScript traces.)
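The ordering rule proposed above can be sketched in plain Python as a flatten/unflatten pair. The helper names `dict_to_tuple` and `tuple_to_dict` are hypothetical and are not part of `torch.jit`; the sketch also assumes the keys are mutually comparable (e.g. all strings) so that `sorted` works on a plain dict.

```python
from collections import OrderedDict


def dict_to_tuple(d):
    """Flatten a dict into (keys, values) using the proposed ordering.

    Hypothetical helper for illustration only; not part of torch.jit.
    """
    if isinstance(d, OrderedDict):
        # OrderedDict: keep the order in which the entries were stored.
        keys = tuple(d.keys())
    else:
        # Plain dict: sort the keys so the order is deterministic
        # (assumes the keys are mutually comparable, e.g. all strings).
        keys = tuple(sorted(d.keys()))
    return keys, tuple(d[k] for k in keys)


def tuple_to_dict(keys, values, ordered=False):
    """Rebuild the dictionary from its flattened (keys, values) form."""
    pairs = zip(keys, values)
    return OrderedDict(pairs) if ordered else dict(pairs)


# Plain dict: keys come out sorted regardless of insertion order.
print(dict_to_tuple({'b': 2, 'a': 1}))                   # → (('a', 'b'), (1, 2))
# OrderedDict: keys come out in their stored order.
print(dict_to_tuple(OrderedDict([('b', 2), ('a', 1)])))  # → (('b', 'a'), (2, 1))
```

A tracer following this rule would only need to record the key tuple alongside the traced values in order to rebuild the dictionary with `tuple_to_dict`.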

Alternatives

The alternative would be for me to wrap the module/function in a function that converts the dictionary to a (named?) tuple before returning it, just for the torch.jit.trace() call. This would not give me a trace that creates a dictionary, but at least a trace useful for visualization.
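The wrapping workaround described above might look roughly like the following sketch. `tuple_outputs` is a hypothetical helper, not a PyTorch API; it applies the same ordering rule as the pitch (stored order for an OrderedDict, sorted keys for a plain dict) and remembers the key order so that the tuple elements can be mapped back to names afterwards.

```python
from collections import OrderedDict


def tuple_outputs(fn):
    """Wrap a dict-returning callable so that it returns a plain tuple.

    Hypothetical helper for illustration only. The key order used for the
    most recent call is stored in `wrapper.output_keys`.
    """
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, OrderedDict):
            keys = tuple(out.keys())          # keep the stored order
        else:
            keys = tuple(sorted(out.keys()))  # deterministic order for plain dicts
        wrapper.output_keys = keys
        return tuple(out[k] for k in keys)

    wrapper.output_keys = None
    return wrapper


# Stand-in for a multi-output model; with PyTorch the values would be tensors.
multi_task = tuple_outputs(lambda x: {'out': x, 'aux': x + 1})
print(multi_task(1))           # → (2, 1)
print(multi_task.output_keys)  # → ('aux', 'out')
```

With PyTorch installed, the wrapped callable would be passed to `torch.jit.trace()` instead of the original one, e.g. `torch.jit.trace(tuple_outputs(lambda x: {'out': x}), torch.arange(3))`. As noted above, the resulting trace returns a tuple rather than constructing a dictionary.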

Additional context

This was originally part of issue #16453, which also asked for dictionary inputs; those are supported by now:

```python
import torch
torch.jit.trace(lambda x: x['input'], {'input': torch.arange(3)})
```

cc @suo

Metadata

Labels

oncall: jit, triaged

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests