🚀 Feature
Add support to fx.symbolic_trace for code including torch.memory_format
Motivation
torch.memory_format is used by multiple layers to support channels-last (among other) memory layouts, which can substantially improve performance, especially for convolutions.
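For context, a brief sketch of what the memory-format API does outside of tracing: converting a 4-D tensor to channels-last changes its strides without changing its values, and `is_contiguous(memory_format=...)` reports which layout the tensor currently has.

```python
import torch

# A standard NCHW tensor is contiguous in the default (contiguous) format.
x = torch.randn(1, 64, 8, 8)
assert x.is_contiguous(memory_format=torch.contiguous_format)

# Converting to channels-last (NHWC strides) changes the layout, not the data.
x_cl = x.to(memory_format=torch.channels_last)
assert x_cl.is_contiguous(memory_format=torch.channels_last)
assert torch.equal(x, x_cl)  # same values, different strides
```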
Additional context
Here is a script to reproduce the behavior:
```python
import torch
from torch.fx import symbolic_trace

class MyConv2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(64, 3, (3, 3))

    def forward(self, inp):
        if inp.is_contiguous(memory_format=torch.contiguous_format):
            inp = inp.to(memory_format=torch.channels_last)
        return self.conv2d(inp)

m = MyConv2d()
trace = symbolic_trace(m)
```
which fails with:
```
Traceback (most recent call last):
  File "/data/home/cpuhrsch/tmp/asdf.py", line 16, in <module>
    trace = symbolic_trace(m)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/_symbolic_trace.py", line 874, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/_symbolic_trace.py", line 584, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "/data/home/cpuhrsch/tmp/asdf.py", line 11, in forward
    if inp.is_contiguous(memory_format=torch.contiguous_format):
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/proxy.py", line 272, in __call__
    return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/proxy.py", line 46, in create_proxy
    kwargs_ = self.create_arg(kwargs)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/_symbolic_trace.py", line 320, in create_arg
    return super().create_arg(a)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/proxy.py", line 121, in create_arg
    r[k] = self.create_arg(v)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/_symbolic_trace.py", line 320, in create_arg
    return super().create_arg(a)
  File "/fsx/users/cpuhrsch/repos/pytorch/torch/fx/proxy.py", line 132, in create_arg
    raise NotImplementedError(f"argument of type: {type(a)}")
NotImplementedError: argument of type: <class 'torch.memory_format'>
```
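Until `symbolic_trace` accepts `torch.memory_format` arguments, one possible workaround (a sketch, not an endorsed fix) is to keep the memory-format handling out of `forward` entirely, so no `torch.memory_format` value ever reaches the tracer. The `to_channels_last` helper below is hypothetical, introduced only for illustration:

```python
import torch
from torch.fx import symbolic_trace

class MyConv2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(64, 3, (3, 3))

    def forward(self, inp):
        # No memory_format arguments here, so FX can trace this module.
        return self.conv2d(inp)

def to_channels_last(x):
    # Hypothetical helper: do the layout conversion outside the traced graph.
    if x.is_contiguous(memory_format=torch.contiguous_format):
        x = x.to(memory_format=torch.channels_last)
    return x

m = MyConv2d()
traced = symbolic_trace(m)  # succeeds: no torch.memory_format in the graph
out = traced(to_channels_last(torch.randn(1, 64, 8, 8)))
```

The trade-off is that the layout conversion is no longer captured in the FX graph, so any downstream transform that needs to see it will not.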