
mark_forward_method does not work with ModelParallelStrategy #20710

@tonyf

Bug description

When using the ModelParallelStrategy, methods annotated with mark_forward_method raise an exception if their signature does not match that of the module's forward method. This fails specifically when the number of args/kwargs differs between the two functions.

For example, calling generate on the model below fails in an FSDP2 setting with TypeError: Model.forward() got an unexpected keyword argument 'cfg':

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x

    def generate(self, x, y, cfg: float = 0.5):
        z_1 = self.forward(x, y)
        z_2 = self.forward(x, torch.zeros_like(y))
        ...
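
A minimal sketch of how this would be invoked through Fabric to hit the error (assuming an FSDP2-style ModelParallelStrategy; the parallelize function body and the tensor shapes are placeholders):

from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy

def parallelize(model, device_mesh):
    # fully_shard(...) calls would go here; omitted for brevity
    return model

fabric = Fabric(devices=2, strategy=ModelParallelStrategy(parallelize_fn=parallelize))
fabric.launch()

model = fabric.setup(Model())
model.mark_forward_method("generate")

x, y = torch.randn(2, 8), torch.randn(2, 8)
model(x, y)                    # forward works
model.generate(x, y, cfg=1.0)  # raises: Model.forward() got an unexpected keyword argument 'cfg'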

What version are you seeing the problem on?

v2.5

Error messages and logs

        │
[rank0]: │   473 │   │   ):                                                                                                                                      │
[rank0]: │   474 │   │   │   self.callbacks.on_validation_step_start(self, batch_idx)                                                                            │
[rank0]: │   475 │   │   │                                                                                                                                       │
[rank0]: │ ❱ 476 │   │   │   result = self.validation_step(batch, batch_idx)                                                                                     │
[rank0]: │   477 │   │   │   self.callbacks.on_validation_step_end(self, result, batch_idx)                                                                      │
[rank0]: │   478 │   │                                                                                                                                           │
[rank0]: │   479 │   │   result = self.on_validation_epoch_end()                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/models/flow_matching/stage_1_train.py:112 in validation_step                                                              │
[rank0]: │                                                                                                                                                       │
[rank0]: │   109 │   │   B, _, T, H, W = samples.shape                                                                                                           │
[rank0]: │   110 │   │   ct, ch, cw = self.autoencoder.compression                                                                                               │
[rank0]: │   111 │   │                                                                                                                                           │
[rank0]: │ ❱ 112 │   │   samples = self.model.sample(                                                                                                            │
[rank0]: │   113 │   │   │   shape=(B, (T - 1) // ct + 1, H // ch, W // cw, self.autoencoder.latent_dim),                                                        │
[rank0]: │   114 │   │   │   text=text_embeds,                                                                                                                   │
[rank0]: │   115 │   │   │   sample_steps=self.config.sample_steps,                                                                                              │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:197 in call_forward_module                                │
[rank0]: │                                                                                                                                                       │
[rank0]: │   194 │   │   def call_forward_module(*args: Any, **kwargs: Any) -> Any:                                                                              │
[rank0]: │   195 │   │   │   # Patch the original_module's forward, so we can redirect the arguments back                                                        │
[rank0]: │   196 │   │   │   self._original_module.forward = wrapped_forward                                                                                     │
[rank0]: │ ❱ 197 │   │   │   return self.forward(*args, **kwargs)                                                                                                │
[rank0]: │   198 │   │                                                                                                                                           │
[rank0]: │   199 │   │   return call_forward_module                                                                                                              │
[rank0]: │   200                                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:136 in forward                                            │
[rank0]: │                                                                                                                                                       │
[rank0]: │   133 │   │   args, kwargs = precision.convert_input((args, kwargs))                                                                                  │
[rank0]: │   134 │   │                                                                                                                                           │
[rank0]: │   135 │   │   with precision.forward_context():                                                                                                       │
[rank0]: │ ❱ 136 │   │   │   output = self._forward_module(*args, **kwargs)                                                                                      │
[rank0]: │   137 │   │                                                                                                                                           │
[rank0]: │   138 │   │   output = precision.convert_output(output)                                                                                               │
[rank0]: │   139                                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl                                  │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1736 │   │   if self._compiled_call_impl is not None:                                                                                               │
[rank0]: │   1737 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                             │
[rank0]: │   1738 │   │   else:                                                                                                                                  │
[rank0]: │ ❱ 1739 │   │   │   return self._call_impl(*args, **kwargs)                                                                                            │
[rank0]: │   1740 │                                                                                                                                              │
[rank0]: │   1741 │   # torchrec tests the code consistency with the following code                                                                              │
[rank0]: │   1742 │   # fmt: off                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl                                          │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1747 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks                                                        │
[rank0]: │   1748 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                        │
[rank0]: │   1749 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                        │
[rank0]: │ ❱ 1750 │   │   │   return forward_call(*args, **kwargs)                                                                                               │
[rank0]: │   1751 │   │                                                                                                                                          │
[rank0]: │   1752 │   │   result = None                                                                                                                          │
[rank0]: │   1753 │   │   called_always_called_hooks = set()                                                                                                     │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:574 in _fn                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │    571 │   │   │   )                                                                                                                                  │
[rank0]: │    572 │   │   │                                                                                                                                      │
[rank0]: │    573 │   │   │   try:                                                                                                                               │
[rank0]: │ ❱  574 │   │   │   │   return fn(*args, **kwargs)                                                                                                     │
[rank0]: │    575 │   │   │   finally:                                                                                                                           │
[rank0]: │    576 │   │   │   │   # Restore the dynamic layer stack depth if necessary.                                                                          │
[rank0]: │    577 │   │   │   │   torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl                                  │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1736 │   │   if self._compiled_call_impl is not None:                                                                                               │
[rank0]: │   1737 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                             │
[rank0]: │   1738 │   │   else:                                                                                                                                  │
[rank0]: │ ❱ 1739 │   │   │   return self._call_impl(*args, **kwargs)                                                                                            │
[rank0]: │   1740 │                                                                                                                                              │
[rank0]: │   1741 │   # torchrec tests the code consistency with the following code                                                                              │
[rank0]: │   1742 │   # fmt: off                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl                                          │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1747 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks                                                        │
[rank0]: │   1748 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                        │
[rank0]: │   1749 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                        │
[rank0]: │ ❱ 1750 │   │   │   return forward_call(*args, **kwargs)                                                                                               │
[rank0]: │   1751 │   │                                                                                                                                          │
[rank0]: │   1752 │   │   result = None                                                                                                                          │
[rank0]: │   1753 │   │   called_always_called_hooks = set()                                                                                                     │
[rank0]: ╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
[rank0]: TypeError: Rem.forward() got an unexpected keyword argument 'shape'
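
Reading the frames above, the marked method appears to be routed through the wrapped forward path roughly like this (a paraphrased sketch based on the traceback, not the actual wrappers.py source):

def call_forward_module(*args, **kwargs):
    # Patch the original module's forward so the arguments can be redirected
    # back to the marked method (generate / sample in this case) ...
    self._original_module.forward = wrapped_forward
    # ... then go through the strategy/precision forward path, which ends in
    # self._forward_module(*args, **kwargs).
    return self.forward(*args, **kwargs)

With ModelParallelStrategy (FSDP2 plus torch.compile in this trace), the call that finally reaches the module does not seem to pick up the patched forward, so the marked method's extra keyword arguments (shape, cfg, ...) land on the original forward signature and raise the TypeError shown above.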

Environment

Current environment
#- PyTorch Lightning Version: 2.5.0.post
#- PyTorch Version: 2.6.0+cu124
#- Python version: 3.11
#- OS: Linux
#- CUDA/cuDNN version: 12.4
#- GPU models and configuration: 8xH100
#- How you installed Lightning (`conda`, `pip`, source): pip

More info

No response

cc @justusschock @lantiga

Activity

added labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers) on Apr 12, 2025
added label: distributed (Generic distributed-related topic) and removed: needs triage (Waiting to be triaged by maintainers) on Sep 5, 2025
deependujha (Collaborator) commented on Sep 12, 2025

Hi @tonyf, I wasn’t able to reproduce the issue.

Here’s the code I ran (working as expected). I may be missing something. Could you share a minimal repro that triggers the bug?
import torch
import lightning as L
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed._composable.fsdp import fully_shard
from lightning.fabric.strategies import ModelParallelStrategy

def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    # Here, it is 2-dimensional
    tp_mesh = device_mesh["tensor_parallel"]
    dp_mesh = device_mesh["data_parallel"]

    if tp_mesh.size() > 1:
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(model, tp_mesh, plan)

    if dp_mesh.size() > 1:
        # Use PyTorch's FSDP2 APIs to parallelize the model
        fully_shard(model.w1, mesh=dp_mesh)
        fully_shard(model.w2, mesh=dp_mesh)
        fully_shard(model.w3, mesh=dp_mesh)
        fully_shard(model, mesh=dp_mesh)

    return model


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

    def generate(self, device):
        sample = torch.randn(10).reshape(1,-1).to(device)
        return self(sample)


if __name__ == "__main__":
    strategy = ModelParallelStrategy(
        parallelize_fn=parallelize_feedforward,
        # Define the size of the 2D parallelism
        # Set these to "auto" (default) to apply TP intra-node and FSDP inter-node
        data_parallel_size=2,
        tensor_parallel_size=2,
    )
    fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()
    model = FeedForward(10, 1024)
    model = torch.compile(model)
    model = fabric.setup(model)

    # You must mark special forward methods explicitly:
    model.mark_forward_method("generate")

    output = model.generate(fabric.device)
    print(f"{output=}")
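
Note that in the report the marked method takes keyword arguments that forward does not accept; if that is the failing case, a variant of FeedForward.generate closer to it might look like this (untested sketch, cfg is just an illustrative extra kwarg):

    def generate(self, device, cfg: float = 0.5):
        # extra keyword argument that forward(x) does not take,
        # mirroring the signature mismatch described in the report
        sample = torch.randn(10).reshape(1, -1).to(device)
        return self(sample) * cfg

It would then be called as model.generate(fabric.device, cfg=1.0) after mark_forward_method("generate").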

Tested on a Lightning Studio T4 machine with 4 GPUs.
Thanks!

cc: @SkafteNicki


Metadata

Assignees

No one assigned

Labels

bug (Something isn't working), distributed (Generic distributed-related topic), strategy: fsdp (Fully Sharded Data Parallel), ver: 2.5.x, waiting on author (Waiting on user action, correction, or update)

Participants

@tonyf, @SkafteNicki, @deependujha
