This repository was archived by the owner on Aug 1, 2025. It is now read-only.

[ddp] AssertionError: torch.* op returned non-Tensor MaskedLMOutput call_module self_model #1236

@davidberard98

Description

DistributedDataParallel is a torch.nn module, but it doesn't conform to some of the expectations dynamo has for torch.nn modules (e.g., that the forward return value is a Tensor). The branch at https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/variables/nn_module.py#L190 is taken because is_allowed(mod.__class__) is true; dynamo then errors out because it expects a Tensor but gets something else (here, a MaskedLMOutput).
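
For illustration, a minimal sketch of the mismatch (assuming a working torch.distributed setup; BertForMaskedLM and MaskedLMOutput are HuggingFace names, everything else here is hypothetical): a DDP-wrapped HF model still returns a MaskedLMOutput, not a Tensor.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import BertConfig, BertForMaskedLM

# Sketch only: assumes a process group is already initialized, which DDP requires.
model = BertForMaskedLM(BertConfig())
ddp_model = DDP(model)

input_ids = torch.randint(0, 100, (1, 16))
out = ddp_model(input_ids=input_ids)
print(type(out))  # transformers.modeling_outputs.MaskedLMOutput, not a torch.Tensor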

Repro: run https://github.com/pytorch/benchmark/blob/wconstab/ddp_experiments/ddp_experiments.py on hf_Bert with 2 nodes and the inductor backend.
With pytorch at pytorch/pytorch#83333 and dynamo at #628. In addition, patch pytorch by replacing https://github.com/pytorch/pytorch/blob/d05f07494a9a32c63f9218c0e703764a02033bb9/torch/nn/parallel/distributed.py#L981 with a nullcontext (to work around pytorch/pytorch#93668). A sketch of that patch follows.
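
The body below is a hypothetical sketch of the patch, not the exact source at that commit: the idea is to swap the context manager entered around the forward call for a no-op context.

from contextlib import nullcontext

# Hypothetical sketch of the edit to torch/nn/parallel/distributed.py#L981:
# replace the context manager wrapping the forward call with a no-op context
# so dynamo does not have to trace through it.
with nullcontext():  # was: the context manager at distributed.py#L981
    output = self.module(*inputs, **kwargs)  # illustrative, not the exact source line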

Error:

 ========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/convert_frame.py", line 313, in _convert_frame_assert
    code = transform_code_object(frame.f_code, transform)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/bytecode_transformation.py", line 338, in transform_code_object
    transformations(instructions, code_options)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/convert_frame.py", line 301, in transform
    tracer.run()
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 331, in run
    and self.step()
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 304, in step
    getattr(self, inst.opname)(inst)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 154, in wrapper
    return inner_fn(self, inst)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 731, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/symbolic_convert.py", line 241, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/variables/nn_module.py", line 195, in call_function
    return variables.TensorVariable.create(
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/torchdynamo/variables/tensor.py", line 278, in create
    assert (
AssertionError: torch.* op returned non-Tensor MaskedLMOutput call_module self_model

========== The above exception occurred while processing the following code ==========

  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/_submit.py", line 11, in <module>
    submitit_main()
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/submission.py", line 72, in submitit_main
    process_job(args.folder)
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/submission.py", line 54, in process_job
    result = delayed.result()
  File "/data/home/dberard/miniconda/envs/bench-fast/lib/python3.8/site-packages/submitit/core/utils.py", line 133, in result
    self._result = self.function(*self.args, **self.kwargs)
  File "ddp_experiments.py", line 151, in __call__
    return trainer_class(self.args, model_class, model_args=self.model_args).measure()
  File "/fsx/users/dberard/scratch-local/bench-fast/benchmark/torchbenchmark/util/distributed/core_model/trainer.py", line 79, in measure
    self.benchmark.invoke()
  File "/fsx/users/dberard/scratch-local/bench-fast/benchmark/torchbenchmark/util/model.py", line 190, in invoke
    self.train()
  File "/fsx/users/dberard/scratch-local/bench-fast/benchmark/torchbenchmark/util/framework/huggingface/model_factory.py", line 119, in train
    def train(self):
  File "/fsx/users/dberard/scratch-local/bench-fast/benchmark/torchbenchmark/util/framework/huggingface/model_factory.py", line 120, in train
    outputs = self.model(**self.example_inputs)
==========

Full error: https://gist.github.com/davidberard98/e5054d628c0855cb560837600cd35399

This is my best effort at a minimal repro, but it fails with a different error:

import torch
import torchdynamo

import logging
torchdynamo.config.verbose = True
torchdynamo.config.log_level = logging.DEBUG

class CustomObject:
    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c


class MyModule(torch.nn.Module):
    def __init__(self, rand):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.rand = rand

    def forward(self, x, y, obj):
        # return a non-Tensor object, as MaskedLMOutput does
        return CustomObject(
            x + y * obj.a,
            (x + y).relu() - obj.b,
            (x + y).sin() * self.rand + obj.c,
        )

# simulate the fact that is_allowed(DistributedDataParallel) returns true
torchdynamo.allow_in_graph(MyModule)

mod = MyModule(torch.rand((10, 10), device='cuda'))
x, y = [torch.rand((10, 10), device='cuda') for _ in range(2)]

obj = CustomObject(*[torch.rand((10, 10), device='cuda') for _ in range(3)])
args = [x, y, obj]

@torchdynamo.optimize("inductor")
def fn(args):
    return mod(*args)

fn(args)
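
For context, the check that fires is roughly the following (a hand-written paraphrase of the assert in torchdynamo/variables/tensor.py, not the actual source). A MaskedLMOutput trips it because it is a ModelOutput dataclass, not a Tensor.

import torch

# Rough paraphrase (not actual torchdynamo source): for a module class on the
# allowed list, dynamo emits a call_module node and wraps the traced result as
# a TensorVariable, asserting that the example output is a Tensor.
def wrap_call_module_output(example_value, proxy_name):
    assert isinstance(example_value, torch.Tensor), (
        f"torch.* op returned non-Tensor {type(example_value).__name__} "
        f"call_module {proxy_name}"
    )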
