
FX support for torch.memory_format #62498

@cpuhrsch

🚀 Feature

Add support in torch.fx.symbolic_trace for code that passes torch.memory_format values (e.g. torch.channels_last) as arguments.

Motivation

torch.memory_format is used by multiple layers to support channels-last tensors (among other memory layouts), which can improve performance, for example for convolutions.
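Here is a brief illustration of what a memory format changes. The tensor sizes are arbitrary choices. Converting a 4-D tensor to torch.channels_last keeps its shape and values but reorders the underlying storage so that channels vary fastest, and this shows up in the strides:

```python
import torch

# A 4-D NCHW tensor: batch N=2, channels C=3, height H=4, width W=5
x = torch.randn(2, 3, 4, 5)

# Same logical shape, but stored with channels as the innermost (fastest-varying) dimension
y = x.to(memory_format=torch.channels_last)

print(x.shape == y.shape)  # True: the logical shape is unchanged
print(x.stride())  # (60, 20, 5, 1): default contiguous (NCHW) layout
print(y.stride())  # (60, 1, 15, 3): channels-last (NHWC) layout
print(y.is_contiguous())  # False under the default contiguous_format
print(y.is_contiguous(memory_format=torch.channels_last))  # True
print(torch.equal(x, y))  # True: same values, different storage order
```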

Additional context

Here is a script to reproduce the behavior.

```python
import torch
from torch.fx import symbolic_trace

class MyConv2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(64, 3, (3, 3))
        pass

    def forward(self, inp):
        if inp.is_contiguous(memory_format=torch.contiguous_format):
            inp = inp.to(memory_format=torch.channels_last)
        return self.conv2d(inp)

m = MyConv2d()
trace  = symbolic_trace(m)
```

Running it fails with the following traceback:

```
Traceback (most recent call last):
  File "/data/home/cpuhrsch/tmp/asdf.py", line 16, in <module>
    trace  = symbolic_trace(m)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/_symbolic_trace.py", line 874, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/_symbolic_trace.py", line 584, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "/data/home/cpuhrsch/tmp/asdf.py", line 11, in forward
    if inp.is_contiguous(memory_format=torch.contiguous_format):
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/proxy.py", line 272, in __call__
    return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/proxy.py", line 46, in create_proxy
    kwargs_ = self.create_arg(kwargs)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/_symbolic_trace.py", line 320, in create_arg
    return super().create_arg(a)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/proxy.py", line 121, in create_arg
    r[k] = self.create_arg(v)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/_symbolic_trace.py", line 320, in create_arg
    return super().create_arg(a)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/proxy.py", line 132, in create_arg
    raise NotImplementedError(f"argument of type: {type(a)}")
NotImplementedError: argument of type: <class 'torch.memory_format'>
```
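Until symbolic_trace accepts torch.memory_format arguments, one possible workaround is sketched below. It is not the requested feature. The idea is to move the layout check into a free function registered with torch.fx.wrap. FX then records that function as a single call_function node and never traces the memory_format arguments inside it. The same move also sidesteps a second limitation in the reproducer. The `if` in forward branches on the result of a traced call, and symbolic tracing rejects that as data-dependent control flow. The helper name to_channels_last_if_contiguous is hypothetical. torch.fx.wrap must be called at module top level, and the wrapped function runs eagerly on every call, so its body is opaque to FX graph transformations.

```python
import torch
import torch.fx


# Hypothetical helper: performs the same layout check as the reproducer's forward().
# Runs eagerly when called; the memory_format arguments never reach the FX tracer.
def to_channels_last_if_contiguous(inp):
    if inp.is_contiguous(memory_format=torch.contiguous_format):
        inp = inp.to(memory_format=torch.channels_last)
    return inp


# Register the helper (by name, at module top level) so symbolic_trace records
# it as one call_function node instead of tracing into its body.
torch.fx.wrap("to_channels_last_if_contiguous")


class MyConv2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(64, 3, (3, 3))

    def forward(self, inp):
        # No memory_format values or branches are traced here; both live in the wrapped helper.
        inp = to_channels_last_if_contiguous(inp)
        return self.conv2d(inp)


m = MyConv2d()
traced = torch.fx.symbolic_trace(m)  # succeeds: the helper is a leaf call
print(traced.code)

x = torch.randn(1, 64, 8, 8)
out = traced(x)
print(out.shape)  # torch.Size([1, 3, 6, 6])
```

The trade-off is that the layout conversion is invisible to passes operating on the traced graph. That is acceptable for a workaround but is why first-class support for torch.memory_format arguments is preferable.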
