Cifar10 example not working for dynamo training #1930
Description
🐛 Describe the bug
Tried converting the CIFAR-10 model example to use dynamo for training. After decorating the train loop with @dynamo.optimize(), it falls back to eager mode with the following warnings during training in a Bento notebook (N2847095):
[2022-11-26 23:05:49,311] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-11-26 23:05:49,314] torch._inductor.compile_fx: [WARNING] Aot Autograd is not safe to run, so falling back to eager
[2022-11-26 23:05:49,567] torch._dynamo.optimizations.training: [WARNING] Unable to use Aot Autograd because of presence of mutation
[2022-11-26 23:05:49,568] torch._inductor.compile_fx: [WARNING] Aot Autograd is not safe to run, so falling back to eager
Trying it in a Google Colab notebook gives the following error:
AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.7/dist-packages/torchvision/datasets/cifar.py", line 118, in __getitem__
img = self.transform(img)
File "/usr/local/lib/python3.7/dist-packages/torchvision/transforms/transforms.py", line 95, in __call__
img = t(img)
File "/usr/local/lib/python3.7/dist-packages/torchvision/transforms/transforms.py", line 135, in __call__
return F.to_tensor(pic)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/eval_frame.py", line 307, in catch_errors
return callback(frame, cache_size)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/convert_frame.py", line 457, in _convert_frame
result = inner_convert(frame, cache_size)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/convert_frame.py", line 101, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/utils.py", line 90, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
frame,
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/convert_frame.py", line 385, in _compile
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/convert_frame.py", line 373, in transform
tracer.run()
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/symbolic_convert.py", line 1615, in run
super().run()
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/symbolic_convert.py", line 484, in run
and self.step()
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/symbolic_convert.py", line 454, in step
getattr(self, inst.opname)(inst)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/symbolic_convert.py", line 281, in wrapper
return inner_fn(self, inst)
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/symbolic_convert.py", line 911, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/symbolic_convert.py", line 389, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/variables/torch.py", line 439, in call_function
**options,
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/variables/builder.py", line 653, in wrap_fx_proxy
*options,
File "/usr/local/lib/python3.7/dist-packages/torch/_dynamo/variables/builder.py", line 808, in wrap_fx_proxy_cls
+ f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}"
AssertionError: torch.* op returned non-Tensor dtype call_function
from user code:
File "/usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py", line 145, in to_tensor
default_float_dtype = torch.get_default_dtype()
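The traceback above shows dynamo tracing torchvision's to_tensor inside a DataLoader worker, which suggests the decorated train loop pulled data loading into the traced code. A minimal sketch of one possible way to narrow the scope, assuming it is acceptable to wrap only the model and leave the loop in plain Python. This is not a confirmed fix for this issue: build_model, train_one_epoch, and run_demo are hypothetical names, the tiny network and random tensors stand in for the real CIFAR-10 model and dataset, and the "eager" backend is used only so the sketch runs without a compiler toolchain.

```python
import torch
import torch.nn as nn
import torch._dynamo as dynamo
from torch.utils.data import DataLoader, TensorDataset


def build_model():
    # Tiny stand-in for the CIFAR-10 network (hypothetical, not the example's model).
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )


def train_one_epoch(model, loader, optimizer, loss_fn):
    # The loop is NOT decorated, so iterating the DataLoader (and any
    # dataset transforms it runs) stays in ordinary eager Python.
    last_loss = None
    for images, labels in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        last_loss = loss.item()
    return last_loss


def run_demo(backend="eager"):
    torch.manual_seed(0)
    # Random tensors shaped like CIFAR-10 batches (assumption: 3x32x32 inputs, 10 classes).
    images = torch.randn(16, 3, 32, 32)
    labels = torch.randint(0, 10, (16,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=8, num_workers=0)

    model = build_model()
    # Only the model's forward pass is handed to dynamo.
    compiled_model = dynamo.optimize(backend)(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    return train_one_epoch(compiled_model, loader, optimizer, nn.CrossEntropyLoss())


if __name__ == "__main__":
    print("last batch loss:", run_demo())
```

Passing "inductor" instead of "eager" to run_demo would exercise the backend named in the warnings above. Whether the "presence of mutation" fallback still appears with this narrower scope is not something this sketch establishes.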
Error logs
No response
Minified repro
No response