Issue #10: Kernel Fusion using torch.jit #36

Open · wants to merge 21 commits into base: main
4 changes: 3 additions & 1 deletion pipegoose/nn/__init__.py
@@ -1,4 +1,6 @@
from pipegoose.nn.data_parallel.data_parallel import DataParallel
from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel
from pipegoose.nn.pipeline_parallel.pipeline_parallel import PipelineParallel
from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel
from pipegoose.nn.pipeline_parallel.pipeline_parallel import PipelineParallel
from pipegoose.nn.fusion import FusedLayer

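With the re-export above in place, the fused-layer base class becomes importable directly from pipegoose.nn alongside the existing parallelism wrappers (a small usage note, not part of the diff):

from pipegoose.nn import DataParallel, TensorParallel, ExpertParallel, FusedLayer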
201 changes: 201 additions & 0 deletions pipegoose/nn/fusion.py
@@ -0,0 +1,201 @@
import torch
from typing import Any, Type, Callable
from multimethod import overload
from torch import fx
from torch import Tensor
from torch.nn import functional as F

from torch.nn import GELU, Dropout, Module
from torch.nn.modules.dropout import _DropoutNd
from transformers.models.bloom.modeling_bloom import BloomGelu


class FusedLayer:
# Used to match layers in Parallel.module to their fused layer counterpart
represents: list[Type[Module]] = []
wraps: list[Callable] = []

# We pass the target_layer to give each fused layer the ability to copy its instantiation arguments
def __init__(self, target_layer: Module) -> None:
pass


def _parent_name(target: str) -> tuple[str, str]:
*parent, name = target.rsplit(".", 1)
return parent[0] if parent else "", name


def replace_node_module(node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module):
assert(isinstance(node.target, str))
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, new_module)


@torch.jit.script
def _fused_gelu_fwd(input):
return (
input
* 0.5
* (
1.0
+ torch.tanh(
0.7978845608028654 * (input + 0.044715 * input * input * input)
)
)
)


@torch.jit.script
def _fused_gelu_bwd(g, input):
tanh_out = torch.tanh(0.7978845608028654 * input * (1 + 0.044715 * input * input))
ff = 0.5 * input * (
(1 - tanh_out * tanh_out)
* (0.7978845608028654 + 0.1070322244089 * input * input)
) + 0.5 * (1 + tanh_out)
return ff * g


@torch.jit.script
def _fused_bias_gelu_fwd(input, bias):
x = input + bias
return _fused_gelu_fwd(x)


@torch.jit.script
def _fused_bias_gelu_bwd(g, input, bias):
x = input + bias
return _fused_gelu_bwd(g, x)


from torch import nn

BASE_MODEL = nn.Sequential(
nn.Linear(10, 10),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(10, 10),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(10, 10)
)


@torch.jit.script
def fused_bias_dropout(x, bias, p, training, inplace):
# type: (Tensor, Tensor, float, bool, bool) -> Tensor
return F.dropout(x + bias, p=p, training=training, inplace=inplace)


# Next best bet: wrap the fused gelu in another module class and call
# fused_gelu.apply, where fused_gelu is assumed to inherit from torch.autograd.Function.
# The input may be a tuple of Tensors rather than a single Tensor, so it is unpacked
# depending on whether a bias is present.

class _FusedBiasGeluFn(torch.autograd.Function):
@staticmethod
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return _fused_bias_gelu_fwd(input, bias)

@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
return (tmp := _fused_bias_gelu_bwd(grad_output, input, bias)), tmp

class _FusedGeluFn(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return _fused_gelu_fwd(input)

@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return _fused_gelu_bwd(grad_output, input)

class FusedBiasGelu(GELU, FusedLayer):
"""Fused gelu + bias function."""

represents = [GELU, BloomGelu]
approximate: str
wraps = [len]

@overload
def __init__(self, target_layer: GELU):
super().__init__()
self.approximate = target_layer.approximate

@overload
def __init__(self, target_layer): super().__init__()

@staticmethod
def forward(input):
return _FusedBiasGeluFn.apply(input)


class FusedGelu(GELU, FusedLayer):
represents = [GELU, BloomGelu]
approximate: str
wraps = [len]

@overload
def __init__(self, target_layer: GELU):
super().__init__()
self.approximate = target_layer.approximate

@overload
def __init__(self, target_layer): super().__init__()

@staticmethod
def forward(input):
return _FusedGeluFn.apply(input)


@torch.jit.script
def fused_bias_dropout(
input: Tensor,
bias: Tensor,
dropout_prob: float,
training: bool,
inplace: bool = False,
) -> Tensor:
# type: (Tensor, Tensor, float, bool, bool) -> Tensor
return F.dropout(input + bias, p=dropout_prob, training=training, inplace=inplace)


class _FusedDropoutFn(torch.autograd.Function):

@staticmethod
def forward(ctx, input, p, training, inplace):
ctx.save_for_backward(input)
return F.dropout(input, p, training, inplace)

class _FusedBiasDropoutFn(torch.autograd.Function):

@staticmethod
def forward(ctx, input, bias, p, training, inplace):
ctx.save_for_backward(input, bias)
ctx.p = p
ctx.training = training
ctx.inplace = inplace
return fused_bias_dropout(input, bias, p, training, inplace)

@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
return (tmp := _fused_bias_gelu_bwd(grad_output, input, bias)), tmp

class FusedDropout(_DropoutNd, FusedLayer):
"""
Fused dropout + bias function.
See: https://pytorch.org/docs/stable/_modules/torch/nn/modules/dropout.html#Dropout
"""

represents = [Dropout]

def __init__(self, target_layer: Dropout):
dropout_p = target_layer.p
inplace = target_layer.inplace
super().__init__(p=dropout_p, inplace=inplace)

def forward(self, input: Tensor):
return _FusedDropoutFn.apply(input, self.p, self.training, self.inplace)
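A quick numerical sanity check (my own sketch, not part of this diff): the constant 0.7978845608028654 above is sqrt(2/pi), so _fused_gelu_fwd implements the tanh approximation of GELU and should agree with F.gelu(..., approximate="tanh") up to floating-point error, and FusedDropout simply copies p and inplace from the layer it replaces. The import path assumes fusion.py lands at pipegoose/nn/fusion.py as in this diff.

import torch
from torch.nn import functional as F

from pipegoose.nn.fusion import _fused_gelu_fwd, FusedDropout

x = torch.randn(4, 10)

# The fused forward is the tanh approximation of GELU, so it should match
# PyTorch's built-in tanh-approximated GELU up to floating-point error.
assert torch.allclose(_fused_gelu_fwd(x), F.gelu(x, approximate="tanh"), atol=1e-6)

# FusedDropout copies its dropout probability and inplace flag from the
# Dropout layer it replaces.
dropout = torch.nn.Dropout(p=0.1)
fused = FusedDropout(dropout)
assert fused.p == dropout.p and fused.inplace == dropout.inplace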
54 changes: 52 additions & 2 deletions pipegoose/nn/parallel.py
@@ -1,14 +1,18 @@
from abc import abstractclassmethod
from dataclasses import dataclass
from functools import partial
from typing import cast
from typing import cast, List
from copy import deepcopy

import torch
from torch import nn
import torch.fx as fx

from pipegoose.nn.fusion import FusedLayer, replace_node_module
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

torch.fx.wrap('len')

@dataclass
class ParallelMetadata:
@@ -18,6 +22,9 @@ class ParallelMetadata:

class Parallel:
"""A base class for a parallelized module."""
def __init__(self, module: nn.Module, parallel_context: ParallelContext):
self.module = module
self.parallel_context = parallel_context

@abstractclassmethod
def parallelize(self):
@@ -55,6 +62,47 @@ def _get_device(parallel_context: ParallelContext) -> int:
setattr(module, "to", partial(_to_device, module))
setattr(module, "cuda", partial(_to_cuda, module))

def _fuse(self, module: nn.Module, fused_layers: List[FusedLayer]) -> nn.Module:
module = deepcopy(self.module)
for name, child in module.named_modules():
for fused_layer in fused_layers:
if any(isinstance(child, r) for r in fused_layer.represents):
module._modules[name] = fused_layer(child)

self.module = module
return module

def fuse(self, fused_layers: List[FusedLayer]) -> nn.Module:
"""
In-place fusion of the module's layers according to the list of fused layers defined in pipegoose.nn.fusion.
"""
return self._fuse(self.module, fused_layers)


def fuse_fx(self, fused_layers: List[FusedLayer]) -> nn.Module:
# Collect functions to wrap in the tracer
autowrap_fns = tuple(set.union(*map(lambda l: set(l.wraps), fused_layers)))
# The tracer arguments are configured from the union of the FusedLayers'
# `wraps` attributes, which list the operations used by the layers they
# represent that cannot be traced or scripted directly, such as `len` in
# BloomGelu.
graph = fx.Tracer(autowrap_functions=autowrap_fns).trace(self.module)
fx_model = fx.GraphModule(self.module, graph)
# Maps node.target to the module it represents
modules = dict(fx_model.named_modules())
new_graph = deepcopy(fx_model.graph)
for node in new_graph.nodes:
if node.op == "call_module":
for fused_layer in fused_layers:
if type(modules[node.target]) in fused_layer.represents:
if len(node.users) > 1: # Output used by other nodes
continue
original_layer = modules[node.target]
new_layer = fused_layer(original_layer)
replace_node_module(node, modules, new_layer)
node.replace_all_uses_with(node.target)

return fx.GraphModule(self.module, new_graph)

def _to_device(self, device: str):
"""Move a parallelized module to accelerators."""
@@ -71,7 +119,9 @@ def is_specific_device(device):
parallel_metadata = cast(ParallelMetadata, getattr(self, "parallel_metadata", None))

assert parallel_metadata is not None, "Module is not parallelized yet"
assert device in SUPPORTED_DEVICES, f"Device must be one of {SUPPORTED_DEVICES}, got {device}"
assert (
device in SUPPORTED_DEVICES
), f"Device must be one of {SUPPORTED_DEVICES}, got {device}"
assert not is_specific_device(
device
), f'Moving to a specific device {device} is not supported. pipegoose will handle device assignment automatically. Please use "cuda" instead'
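For reference, a minimal usage sketch of the two fusion entry points (my own example, not part of this diff). It assumes a Parallel instance can be built directly with parallel_context=None, since fuse() and fuse_fx() only read self.module, and it uses a flat nn.Sequential because _fuse indexes module._modules by the dotted names returned from named_modules(). Fused layers are passed as classes; _fuse instantiates fused_layer(child) for each matching child. Whether fuse_fx yields a working GraphModule depends on the graph rewrite above behaving as intended.

from torch import nn

from pipegoose.nn.fusion import FusedDropout, FusedGelu
from pipegoose.nn.parallel import Parallel

model = nn.Sequential(nn.Linear(10, 10), nn.GELU(), nn.Dropout(0.1))

# Eager replacement: every child whose type appears in a FusedLayer's
# `represents` list is swapped for its fused counterpart.
fused_model = Parallel(module=model, parallel_context=None).fuse([FusedGelu, FusedDropout])
print(type(fused_model[1]).__name__, type(fused_model[2]).__name__)  # FusedGelu FusedDropout

# FX-based alternative: the tracer auto-wraps the functions listed in each
# FusedLayer's `wraps` (e.g. `len` for BloomGelu) so symbolic tracing succeeds.
fused_graph = Parallel(module=model, parallel_context=None).fuse_fx([FusedGelu, FusedDropout])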
19 changes: 15 additions & 4 deletions pipegoose/nn/tensor_parallel/tensor_parallel.py
@@ -17,7 +17,12 @@
class TensorParallel(Parallel):
"""Turn a 🤗 transformers model into a tensor parallel model."""

PARALLELIZERS = [EmbeddingParallelizer, LinearParallelizer, LayerNormParallelizer, LMHeadParallelizer]
PARALLELIZERS = [
EmbeddingParallelizer,
LinearParallelizer,
LayerNormParallelizer,
LMHeadParallelizer,
]

def __init__(self, module: nn.Module, parallel_context: ParallelContext):
self.module = module
@@ -35,7 +40,9 @@ def parallelize(self) -> nn.Module:
for module_name, leaf_module in leaf_modules:
parallelizer = self._find_parallelizer(module_name, leaf_module)
if parallelizer is not None:
parallelizer(module_name, leaf_module, module, self.parallel_context).parallelize()
parallelizer(
module_name, leaf_module, module, self.parallel_context
).parallelize()

self._save_metadata(module, self.parallel_context)

@@ -50,7 +57,9 @@ def _get_leaf_modules(self, model: nn.Module) -> List[Tuple[str, nn.Module]]:

return leaf_modules

def _find_parallelizer(self, module_name: str, module: nn.Module) -> Optional[ModuleParallelizer]:
def _find_parallelizer(
self, module_name: str, module: nn.Module
) -> Optional[ModuleParallelizer]:
for parallelizer in self.PARALLELIZERS:
if parallelizer.is_parallelizable(module_name, module):
return parallelizer
@@ -59,4 +68,6 @@ def _find_parallelizer(self, module_name: str, module: nn.Module) -> Optional[Mo
@torch.no_grad()
def deparallelize(self) -> nn.Module:
for module_name, module in self.module.named_modules():
self.PARALLELIZERS[module].deparallelize(module_name, module, self.parallel_context)
self.PARALLELIZERS[module].deparallelize(
module_name, module, self.parallel_context
)