torchtitan/distributed/pipeline_parallel.py (88 changes: 51 additions & 37 deletions)
@@ -39,6 +39,8 @@
     "pipeline_module_split",
 ]
 
+lib = torch.library.Library("aten", "IMPL")
+
 
 def _override_torch_ops_for_zero_bubble():
     class MmSeparateWeightGrad(torch.autograd.Function):
@@ -142,10 +144,10 @@ def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
         bias_1 = AddmmSeparateBiasGrad.apply(bias, beta)
         return AddmmPassThrough.apply(bias_1, mat1_1, mat2_1, beta, alpha)
 
-    # _fused_rms_norm operator: RMS normalization
-    class FusedRmsNormSeparateWeightGrad(torch.autograd.Function):
+    # rms_norm operator: RMS normalization
+    class RmsNormSeparateWeightGrad(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, weight, normalized_shape, eps):
+        def forward(ctx, input, normalized_shape, weight, eps):
             ctx.save_for_backward(input)
             ctx.normalized_shape = normalized_shape
             ctx.eps = eps
@@ -155,104 +157,116 @@ def forward(ctx, input, weight, normalized_shape, eps):
         def backward(ctx, grad_output):
             (input,) = ctx.saved_tensors
             # Compute normalized input for weight gradient
             if grad_output is None:
                 return None, None, None, None
             variance = input.pow(2).mean(-1, keepdim=True)
             rstd = torch.rsqrt(variance + ctx.eps)
             normalized = input * rstd
             # Gradient w.r.t. weight: sum over batch dimension
             grad_weight = (grad_output * normalized).sum(
                 dim=tuple(range(grad_output.ndim - 1))
             )
-            return None, grad_weight, None, None
+            return None, None, grad_weight, None

-    class FusedRmsNormSeparateInputGrad(torch.autograd.Function):
+    class RmsNormSeparateInputGrad(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, weight, normalized_shape, eps):
-            ctx.save_for_backward(weight)
+        def forward(ctx, input, normalized_shape, weight, eps):
+            (
+                ctx.save_for_backward(weight)
+                if weight is not None
+                else ctx.save_for_backward()
+            )
             ctx.normalized_shape = normalized_shape
             ctx.eps = eps
             return input

         @staticmethod
         def backward(ctx, grad_output):
             (weight,) = ctx.saved_tensors
             # This is a placeholder - the actual gradient computation happens in PassThrough
             # Here we just pass through the grad_output weighted by weight
             return grad_output, None, None, None

-    class FusedRmsNormPassThrough(torch.autograd.Function):
+    class RmsNormPassThrough(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, weight, normalized_shape, eps):
+        def forward(ctx, input, normalized_shape, weight, eps):
             with torch._C._AutoDispatchBelowAutograd():
-                return torch.ops.aten._fused_rms_norm(
-                    input, weight, normalized_shape, eps
-                )
+                return torch.rms_norm(input, normalized_shape, weight, eps)

         @staticmethod
         def backward(ctx, gO):
-            return gO, gO, None, None
+            return gO, None, gO, None

-    def split_fused_rms_norm(input, weight, normalized_shape, eps):
-        print("split fused_rms_norm")
-        weight_1 = FusedRmsNormSeparateWeightGrad.apply(
-            input.detach(), weight, normalized_shape, eps
+    def split_rms_norm(input, normalized_shape, weight=None, eps=None):
+        print("split rms_norm")
+        weight_1 = RmsNormSeparateWeightGrad.apply(
+            input.detach(), normalized_shape, weight, eps
         )
-        input_1 = FusedRmsNormSeparateInputGrad.apply(
-            input, weight.detach(), normalized_shape, eps
+        input_1 = RmsNormSeparateInputGrad.apply(
+            input,
+            normalized_shape,
+            weight.detach() if weight is not None else None,
+            eps,
         )
-        return FusedRmsNormPassThrough.apply(input_1, weight_1, normalized_shape, eps)
+        return RmsNormPassThrough.apply(input_1, normalized_shape, weight_1, eps)

     # _grouped_mm operator: Grouped matrix multiplication for MoE
     class GroupedMmSeparateMat2Grad(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, mat2):
+        def forward(ctx, input, mat2, offs, bias, out_dtype):
             ctx.save_for_backward(input)
+            ctx.offs = offs
             return mat2

         @staticmethod
         def backward(ctx, grad_output):
             (input,) = ctx.saved_tensors
             # Gradient w.r.t. mat2 for grouped mm
             # This is simplified - actual implementation may need group-wise computation
             grad_mat2 = torch.ops.aten._grouped_mm.default(
-                input.transpose(-1, -2), grad_output, reduce="sum"
+                input.transpose(-1, -2), grad_output, offs=ctx.offs
             )
-            return None, grad_mat2
+            return None, grad_mat2, None, None, None

     class GroupedMmSeparateInputGrad(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, mat2):
+        def forward(ctx, input, mat2, offs, bias, out_dtype):
             ctx.save_for_backward(mat2)
+            ctx.offs = offs
             return input

         @staticmethod
         def backward(ctx, grad_output):
             (mat2,) = ctx.saved_tensors
             # Gradient w.r.t. input for grouped mm
             grad_input = torch.ops.aten._grouped_mm.default(
-                grad_output, mat2.transpose(-1, -2), reduce="sum"
+                grad_output, mat2.transpose(-1, -2), offs=ctx.offs
             )
-            return grad_input, None
+            return grad_input, None, None, None, None

     class GroupedMmPassThrough(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, mat2, reduce="sum"):
+        def forward(ctx, input, mat2, offs, bias, out_dtype):
             with torch._C._AutoDispatchBelowAutograd():
-                return torch.ops.aten._grouped_mm.default(input, mat2, reduce=reduce)
+                return torch.ops.aten._grouped_mm.default(
+                    input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
+                )

         @staticmethod
         def backward(ctx, gO):
-            return gO, gO, None
+            return gO, gO, None, None, None

-    def split_grouped_mm(input, mat2, reduce="sum"):
+    def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
         print("split grouped_mm")
-        mat2_1 = GroupedMmSeparateMat2Grad.apply(input.detach(), mat2)
-        input_1 = GroupedMmSeparateInputGrad.apply(input, mat2.detach())
-        return GroupedMmPassThrough.apply(input_1, mat2_1, reduce)
+        mat2_1 = GroupedMmSeparateMat2Grad.apply(
+            input.detach(), mat2, offs, bias, out_dtype
+        )
+        input_1 = GroupedMmSeparateInputGrad.apply(
+            input, mat2.detach(), offs, bias, out_dtype
+        )
+        return GroupedMmPassThrough.apply(input_1, mat2_1, offs, bias, out_dtype)

-    lib = torch.library.Library("aten", "IMPL")
     lib.impl("mm", split_mm, "Autograd")
     lib.impl("addmm", split_addmm, "Autograd")
-    lib.impl("_fused_rms_norm", split_fused_rms_norm, "Autograd")
+    lib.impl("rms_norm", split_rms_norm, "Autograd")
     lib.impl("_grouped_mm", split_grouped_mm, "Autograd")

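For the RMSNorm override: with x_hat = x * rsqrt(mean(x**2, dim=-1) + eps) and y = weight * x_hat, the weight gradient is grad_output * x_hat summed over all leading dimensions, which is what RmsNormSeparateWeightGrad.backward computes above. A quick numerical check against autograd (illustrative only; rms_norm_ref is a name made up for this sketch, not an API):

import torch

def rms_norm_ref(x, weight, eps=1e-6):
    x_hat = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * x_hat

x = torch.randn(2, 3, 8, dtype=torch.float64)
w = torch.randn(8, dtype=torch.float64, requires_grad=True)
y = rms_norm_ref(x, w)
grad_out = torch.randn_like(y)
y.backward(grad_out)

# Closed-form weight gradient, mirroring RmsNormSeparateWeightGrad.backward.
x_hat = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
grad_w_manual = (grad_out * x_hat).sum(dim=tuple(range(grad_out.ndim - 1)))
print(torch.allclose(grad_w_manual, w.grad))  # True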

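For the grouped-mm override: the backward passes now forward offs to torch.ops.aten._grouped_mm.default instead of a reduce argument. Assuming the MoE layout used elsewhere in torchtitan (a 2D input of shape (tokens, K), a 3D mat2 of shape (groups, K, N), and offs holding each group's cumulative end row in input; this layout is an assumption of the sketch, not something stated in this diff), the quantities being computed are the usual per-group identities dX_g = dY_g @ W_g^T and dW_g = X_g^T @ dY_g. The reference below spells them out with plain matmuls and checks them against autograd; grouped_mm_ref and grouped_mm_grads_ref are made-up helper names, not the fused op.

import torch

def grouped_mm_ref(inp, mat2, offs):
    outs, start = [], 0
    for g, end in enumerate(offs.tolist()):
        outs.append(inp[start:end] @ mat2[g])  # one plain matmul per group
        start = end
    return torch.cat(outs, dim=0)

def grouped_mm_grads_ref(inp, mat2, grad_out, offs):
    grad_inp = torch.empty_like(inp)
    grad_mat2 = torch.empty_like(mat2)
    start = 0
    for g, end in enumerate(offs.tolist()):
        grad_inp[start:end] = grad_out[start:end] @ mat2[g].transpose(-1, -2)
        grad_mat2[g] = inp[start:end].transpose(-1, -2) @ grad_out[start:end]
        start = end
    return grad_inp, grad_mat2

inp = torch.randn(10, 4, dtype=torch.float64, requires_grad=True)
mat2 = torch.randn(3, 4, 6, dtype=torch.float64, requires_grad=True)
offs = torch.tensor([3, 7, 10])
out = grouped_mm_ref(inp, mat2, offs)
grad_out = torch.randn_like(out)
out.backward(grad_out)
gi, gm = grouped_mm_grads_ref(inp.detach(), mat2.detach(), grad_out, offs)
print(torch.allclose(gi, inp.grad), torch.allclose(gm, mat2.grad))  # True True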