This repository was archived by the owner on Aug 5, 2025. It is now read-only.
Merged
130 changes: 130 additions & 0 deletions examples/inference/opt_generate.py
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import argparse
import os

import torch
import pippy
import pippy.fx
from pippy import run_pippy
from pippy.hf import PiPPyHFTracer, inject_pipeline_forward
from transformers import AutoTokenizer, OPTForCausalLM


pippy.fx.Tracer.proxy_buffer_attributes = True

gigabyte_size = 1024 ** 3


def format_to_gb(item, precision=4):
    """Format a number of bytes as gigabytes, rounded to (default) 4-digit precision."""
    metric_num = item / gigabyte_size
    metric_num = round(metric_num, ndigits=precision)
    return metric_num


def print_mem_usage():
    memory_reserved = format_to_gb(torch.cuda.memory_reserved())
    memory_allocated = format_to_gb(torch.cuda.memory_allocated())
    print(
        f"memory_reserved: {memory_reserved} GB, "
        f"memory_allocated: {memory_allocated} GB"
    )


def get_number_of_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def generate_input(args):
    bs = args.batch_size * args.chunks
    seq_length = args.seq_length
    model_config = args.model.config
    torch.manual_seed(args.rank)

    inp = torch.empty(bs, seq_length, dtype=torch.long, device=args.device).random_(model_config.vocab_size)
    model_input_dict = {
        "input_ids": inp,
    }

    return model_input_dict


def run_all(pp_ranks, args):
    model = args.model
    model.to(args.device)
    model.eval()
    model.config.use_cache = False  # don't output `past_key_values`
    num_ranks = len(pp_ranks)

    if args.rank == 0:
        print(model.config)
        print(f"model total number of params = {get_number_of_params(model) // 10 ** 6}M")

    split_policy = pippy.split_into_equal_size(num_ranks)

    model_input_dict = generate_input(args)
    # Use default values for kwargs other than those in `model_input_dict`
    concrete_args = pippy.create_default_args(
        model,
        except_keys=model_input_dict.keys(),
    )

    pipe_driver, stage_mod = pippy.all_compile(
        model,
        num_ranks,
        args.chunks,
        split_policy=split_policy,
        tracer=PiPPyHFTracer(),
        concrete_args=concrete_args,
    )

    params = get_number_of_params(stage_mod)
    print(f"submod_{args.rank} {params // 10 ** 6}M params")

    if args.rank != 0:
        return

    # Master continues
    print_mem_usage()

    # Inject the pipeline driver's forward function back into the original model to support HF's `generate()` method
    inject_pipeline_forward(model, pipe_driver)

    # OPT generate
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    prompt = "Hey, are you conscious? Can you talk to me?"
    input = tokenizer(prompt, return_tensors="pt")

    input_ids = input["input_ids"].to(args.device)
    outputs = model.generate(input_ids, max_length=30)
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    print(response)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
    parser.add_argument('--model_name', type=str, default='facebook/opt-350m')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--chunks', type=int, default=1)
    parser.add_argument('--seq_length', type=int, default=16)
    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
    parser.add_argument('--pp_group_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))

    args = parser.parse_args()

    assert args.world_size % args.pp_group_size == 0

    # Main process loads the model
    print(f"Loading model {args.model_name}")
    if 'opt' in args.model_name:
        model = OPTForCausalLM.from_pretrained(args.model_name, use_cache=False)
    else:
        raise ValueError(f"Unsupported model: {args.model_name}")
    args.model = model

    args.gspmd = 1
    run_pippy(run_all, args)
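The `create_default_args` call above pins every forward kwarg except the real inputs to a concrete default before tracing. As a rough, standalone illustration of that idea (a hypothetical helper written for this note, not pippy's implementation), the defaults could be collected from the model's signature like this:

import inspect

def sketch_default_args(model, except_keys):
    # Hypothetical helper: gather default values for all forward() kwargs
    # except the ones that will be passed at run time (e.g. "input_ids").
    sig = inspect.signature(model.forward)
    return {
        name: p.default
        for name, p in sig.parameters.items()
        if name not in except_keys and p.default is not inspect.Parameter.empty
    }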
16 changes: 15 additions & 1 deletion pippy/PipelineDriver.py
@@ -1697,8 +1697,22 @@ def __init__(
                    node.name, Parameter.POSITIONAL_OR_KEYWORD, default=default
                )
            )

        # We are building a safety net here in case the user passes in arguments beyond those
        # defined as variable arguments (i.e. non-concrete args) at the tracing phase
        # TODO: Remove this safety net
        traced_args = [p.name for p in parameters]
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in traced_args}
        if len(filtered_kwargs) != len(kwargs):
            extra_args = kwargs.keys() - filtered_kwargs.keys()
            warnings.warn(
                f"Received extra arguments: {extra_args}. "
                f"They might have already been given a concrete value during pipeline compilation via `concrete_args`. "
                f"We will ignore the current inputs and use the values given during compilation."
            )

        sig = Signature(parameters)
        bound_args = sig.bind(*args, **kwargs)
        bound_args = sig.bind(*args, **filtered_kwargs)
        bound_args.apply_defaults()
        self.args = bound_args.args
        self.args_iter = iter(self.args)
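A minimal standalone sketch of the safety net above, assuming its only job is to drop kwargs that were not traced as variable arguments before binding (the helper and names here are illustrative, not the driver's internals):

import warnings
from inspect import Parameter, Signature

def bind_traced_args(traced_names, *args, **kwargs):
    # Keep only kwargs that correspond to traced (non-concrete) arguments.
    filtered = {k: v for k, v in kwargs.items() if k in traced_names}
    if len(filtered) != len(kwargs):
        warnings.warn(f"Ignoring extra arguments: {kwargs.keys() - filtered.keys()}")
    params = [Parameter(n, Parameter.POSITIONAL_OR_KEYWORD) for n in traced_names]
    bound = Signature(params).bind(*args, **filtered)
    bound.apply_defaults()
    return bound.args

# e.g. an extra `attention_mask` is warned about and dropped:
print(bind_traced_args(["input_ids"], input_ids=[1, 2, 3], attention_mask=[1, 1, 1]))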
2 changes: 2 additions & 0 deletions pippy/hf/__init__.py
@@ -5,6 +5,7 @@
    PiPPySeq2SeqTrainingArguments,
    PiPPyTrainer,
    PiPPySeq2SeqTrainer,
    inject_pipeline_forward,
)

__all__ = [
@@ -13,4 +14,5 @@
    "PiPPySeq2SeqTrainingArguments",
    "PiPPyTrainer",
    "PiPPySeq2SeqTrainer",
    "inject_pipeline_forward",
]
44 changes: 44 additions & 0 deletions pippy/hf/utils.py
@@ -22,6 +22,8 @@
)
from transformers.utils import cached_property

from pippy.PipelineDriver import PipelineDriverBase


logger = logging.getLogger(__name__)

@@ -283,3 +285,45 @@ def trace(self, *args, **kwargs):
                elif getattr(node.target, "_orig", None) == torch.zeros:
                    node.target = torch_zeros_wrapper
        return graph


# The `DotDict` class adds dot-notation access to dictionary keys.
class DotDict(dict):
    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __delattr__(self, item):
        self.__delitem__(item)


# This is an experimental utility function that replaces the original model's forward method with PiPPy's
# PipelineDriver forward method. It is used to support HuggingFace's `generate()` method, which is defined in the
# `GenerationMixin` class that `PreTrainedModel` inherits from. We choose this replacement path instead of writing
# our own `generate()` method because `generate()` would call into many `GenerationMixin` APIs that may be
# implemented differently by each model.
def inject_pipeline_forward(
    model: torch.nn.Module,
    pipe_driver: PipelineDriverBase,
):
    logging.info(
        f"Inserting PiPPy pipeline forward into model {model._get_name()}"
    )
    # Inject the pipeline driver as a member object of the original model
    setattr(model, "pippy_pipeline_driver", pipe_driver)

    # Define a new forward method that uses PiPPy's pipeline driver
    def pippy_forward(self, *args, **kwargs):
        output = self.pippy_pipeline_driver(*args, **kwargs)
        if isinstance(output, dict):
            # Add dot access if the output is a dictionary. The output of a traced HF model is a plain dict which
            # has only [`key`] access. The wrapping is needed for Transformers versions >= 4.28, which access
            # attributes of the output via dot notation, such as `output.logits`. See for example the `generate()`
            # method and `modeling_outputs.py`.
            output = DotDict(output)
        return output

    # Replace the forward method in the original model
    setattr(model, "forward", types.MethodType(pippy_forward, model))
4 changes: 2 additions & 2 deletions pippy/microbatch.py
@@ -76,13 +76,13 @@ def shard_dict_of_args(
        sharded_arg_flat = []

        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is Replicate:
            if chunk_v is Replicate or not isinstance(v, torch.Tensor):
                sharded_arg_flat.append([v] * num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
                # If it's a collection type, split it as you would expect. Otherwise,
                # Throw an error
                assert isinstance(v, torch.Tensor)
                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"

                chunk_tensors = torch.tensor_split(
                    v, num_chunks, chunk_v.split_dim
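A standalone sketch of the sharding rule this change relaxes (assumed semantics, not the real `shard_dict_of_args`): tensors are split into microbatch chunks along a dimension, while non-tensor values are now simply replicated once per chunk instead of tripping the assert.

import torch

def shard_value(v, num_chunks, split_dim=0):
    if not isinstance(v, torch.Tensor):
        return [v] * num_chunks  # replicate non-tensors across chunks
    return list(torch.tensor_split(v, num_chunks, dim=split_dim))

print(shard_value(torch.arange(8), 4))  # four tensor chunks of two elements each
print(shard_value(True, 4))             # [True, True, True, True]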
2 changes: 2 additions & 0 deletions pippy/utils.py
@@ -4,6 +4,7 @@
import logging
from typing import List


# Pinning process to a separate GPU if not yet done by launch script
# Notes:
# 1. Previously this env was added to work around an issue that each RPC process creates an extra CUDA context on device
@@ -27,6 +28,7 @@
        f"Pinning local process {local_rank_str} to gpu {os.getenv('CUDA_VISIBLE_DEVICES')}"
    )


import torch
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc