diff --git a/examples/inference/opt_generate.py b/examples/inference/opt_generate.py
new file mode 100644
index 000000000..f49a2785a
--- /dev/null
+++ b/examples/inference/opt_generate.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import argparse
+import os
+
+import torch
+import pippy
+import pippy.fx
+from pippy import run_pippy
+from pippy.hf import PiPPyHFTracer, inject_pipeline_forward
+from transformers import AutoTokenizer, OPTForCausalLM
+
+
+pippy.fx.Tracer.proxy_buffer_attributes = True
+
+gigabyte_size = 1024 ** 3
+
+
+def format_to_gb(item, precision=4):
+    """Quick helper to convert numbers to gigabytes, rounded to (default) 4 digits of precision"""
+    metric_num = item / gigabyte_size
+    metric_num = round(metric_num, ndigits=precision)
+    return metric_num
+
+
+def print_mem_usage():
+    memory_reserved = format_to_gb(torch.cuda.memory_reserved())
+    memory_allocated = format_to_gb(torch.cuda.memory_allocated())
+    print(
+        f"memory_reserved: {memory_reserved} GB, "
+        f"memory_allocated: {memory_allocated} GB"
+    )
+
+
+def get_number_of_params(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def generate_input(args):
+    bs = args.batch_size * args.chunks
+    seq_length = args.seq_length
+    model_config = args.model.config
+    torch.manual_seed(args.rank)
+
+    inp = torch.empty(bs, seq_length, dtype=torch.long, device=args.device).random_(model_config.vocab_size)
+    model_input_dict = {
+        "input_ids": inp,
+    }
+
+    return model_input_dict
+
+
+def run_all(pp_ranks, args):
+    model = args.model
+    model.to(args.device)
+    model.eval()
+    model.config.use_cache = False  # don't output `past_key_values`
+    num_ranks = len(pp_ranks)
+
+    if args.rank == 0:
+        print(model.config)
+        print(f"model total number of params = {get_number_of_params(model) // 10 ** 6}M")
+
+    split_policy = pippy.split_into_equal_size(num_ranks)
+
+    model_input_dict = generate_input(args)
+    # Use default values for kwargs other than those in `model_input_dict`
+    concrete_args = pippy.create_default_args(
+        model,
+        except_keys=model_input_dict.keys(),
+    )
+
+    pipe_driver, stage_mod = pippy.all_compile(
+        model,
+        num_ranks,
+        args.chunks,
+        split_policy=split_policy,
+        tracer=PiPPyHFTracer(),
+        concrete_args=concrete_args,
+    )
+
+    params = get_number_of_params(stage_mod)
+    print(f"submod_{args.rank} {params // 10 ** 6}M params")
+
+    if args.rank != 0:
+        return
+
+    # Master continues
+    print_mem_usage()
+
+    # Inject the pipeline driver's forward function back into the original model to support HF's `generate()` method
+    inject_pipeline_forward(model, pipe_driver)
+
+    # OPT generate
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+    prompt = "Hey, are you conscious? Can you talk to me?"
+    input = tokenizer(prompt, return_tensors="pt")
+
+    input_ids = input["input_ids"].to(args.device)
+    outputs = model.generate(input_ids, max_length=30)
+    response = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    print(response)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
+    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
+    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
+    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
+    parser.add_argument('--model_name', type=str, default='facebook/opt-350m')
+    parser.add_argument('--batch_size', type=int, default=1)
+    parser.add_argument('--chunks', type=int, default=1)
+    parser.add_argument('--seq_length', type=int, default=16)
+    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
+    parser.add_argument('--pp_group_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
+
+    args = parser.parse_args()
+
+    assert args.world_size % args.pp_group_size == 0
+
+    # Main process loads the model
+    print(f"Loading model {args.model_name}")
+    if 'opt' in args.model_name:
+        model = OPTForCausalLM.from_pretrained(args.model_name, use_cache=False)
+    else:
+        raise ValueError(f"Unsupported model: {args.model_name}")
+    args.model = model
+
+    args.gspmd = 1
+    run_pippy(run_all, args)
diff --git a/pippy/PipelineDriver.py b/pippy/PipelineDriver.py
index a32716579..60ee5aa14 100644
--- a/pippy/PipelineDriver.py
+++ b/pippy/PipelineDriver.py
@@ -1697,8 +1697,22 @@ def __init__(
                     node.name, Parameter.POSITIONAL_OR_KEYWORD, default=default
                 )
             )
+
+        # Safety net in case the user passes in more arguments than those defined as variable
+        # arguments (i.e. non-concrete args) at the tracing phase
+        # TODO: Remove this safety net
+        traced_args = [p.name for p in parameters]
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in traced_args}
+        if len(filtered_kwargs) != len(kwargs):
+            extra_args = kwargs.keys() - filtered_kwargs.keys()
+            warnings.warn(
+                f"Received extra arguments: {extra_args}. "
+                f"They might have already been given a concrete value during pipeline compilation via `concrete_args`. "
+                f"We will ignore the current inputs and use the values given during compilation."
+            )
+
         sig = Signature(parameters)
-        bound_args = sig.bind(*args, **kwargs)
+        bound_args = sig.bind(*args, **filtered_kwargs)
         bound_args.apply_defaults()
         self.args = bound_args.args
         self.args_iter = iter(self.args)
diff --git a/pippy/hf/__init__.py b/pippy/hf/__init__.py
index 8f1df65f7..39e613001 100644
--- a/pippy/hf/__init__.py
+++ b/pippy/hf/__init__.py
@@ -5,6 +5,7 @@
     PiPPySeq2SeqTrainingArguments,
     PiPPyTrainer,
     PiPPySeq2SeqTrainer,
+    inject_pipeline_forward,
 )
 
 __all__ = [
@@ -13,4 +14,5 @@
     "PiPPySeq2SeqTrainingArguments",
     "PiPPyTrainer",
     "PiPPySeq2SeqTrainer",
+    "inject_pipeline_forward",
 ]
diff --git a/pippy/hf/utils.py b/pippy/hf/utils.py
index f1822fc8c..f05b7abb9 100644
--- a/pippy/hf/utils.py
+++ b/pippy/hf/utils.py
@@ -22,6 +22,8 @@
 )
 from transformers.utils import cached_property
 
+from pippy.PipelineDriver import PipelineDriverBase
+
 logger = logging.getLogger(__name__)
 
 
@@ -283,3 +285,45 @@ def trace(self, *args, **kwargs):
         elif getattr(node.target, "_orig", None) == torch.zeros:
             node.target = torch_zeros_wrapper
     return graph
+
+
+# The `DotDict` class adds dot-notation access to dictionary attributes.
+class DotDict(dict):
+    def __getattr__(self, attr):
+        return self.get(attr)
+
+    def __setattr__(self, key, value):
+        self.__setitem__(key, value)
+
+    def __delattr__(self, item):
+        self.__delitem__(item)
+
+
+# This is an experimental utility that replaces the original model's forward method with the forward method of
+# PiPPy's PipelineDriver. It is used to support HuggingFace's `generate()` method, which is defined in the
+# `GenerationMixin` class that `PreTrainedModel` inherits from. We choose this replacement path instead of
+# writing our own `generate()` method, because `generate()` calls into many `GenerationMixin` APIs that may be
+# implemented differently by each model.
+def inject_pipeline_forward(
+    model: torch.nn.Module,
+    pipe_driver: PipelineDriverBase,
+):
+    logging.info(
+        f"Inserting PiPPy pipeline forward into model {model._get_name()}"
+    )
+    # Inject the pipeline driver as a member object of the original model
+    setattr(model, "pippy_pipeline_driver", pipe_driver)
+
+    # Define a new forward method that uses PiPPy's pipeline driver
+    def pippy_forward(self, *args, **kwargs):
+        output = self.pippy_pipeline_driver(*args, **kwargs)
+        if isinstance(output, dict):
+            # Add dot access if the output is a dictionary. The output of a traced HF model is a plain dict that
+            # supports only [`key`] access. The wrapping is needed for Transformers versions >= 4.28, which access
+            # attributes of the output via dot notation, such as `output.logits`. See for example the `generate()`
+            # method and `modeling_outputs.py`.
+            output = DotDict(output)
+        return output
+
+    # Replace the forward method of the original model
+    setattr(model, "forward", types.MethodType(pippy_forward, model))
diff --git a/pippy/microbatch.py b/pippy/microbatch.py
index ff69a7232..bfd018a90 100644
--- a/pippy/microbatch.py
+++ b/pippy/microbatch.py
@@ -76,13 +76,13 @@ def shard_dict_of_args(
         sharded_arg_flat = []
 
         for v, chunk_v in zip(flat, chunk_spec_flat):
-            if chunk_v is Replicate:
+            if chunk_v is Replicate or not isinstance(v, torch.Tensor):
                 sharded_arg_flat.append([v] * num_chunks)
             elif isinstance(chunk_v, TensorChunkSpec):
                 # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
                 # If it's a collection type, split it as you would expect. Otherwise,
                # Throw an error
-                assert isinstance(v, torch.Tensor)
+                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"
 
                 chunk_tensors = torch.tensor_split(
                     v, num_chunks, chunk_v.split_dim
diff --git a/pippy/utils.py b/pippy/utils.py
index bda89eadf..c9db5ad14 100644
--- a/pippy/utils.py
+++ b/pippy/utils.py
@@ -4,6 +4,7 @@
 import logging
 from typing import List
 
+
 # Pinning process to a separate GPU if not yet done by launch script
 # Notes:
 # 1. Previously this env was added to work around an issue that each RPC process creates an extra CUDA context on device
@@ -27,6 +28,7 @@
         f"Pinning local process {local_rank_str} to gpu {os.getenv('CUDA_VISIBLE_DEVICES')}"
     )
 
+
 import torch
 import torch.multiprocessing as mp
 import torch.distributed.rpc as rpc
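
As a quick illustration of the dot-access wrapping that the new `pippy_forward` applies to dict outputs, here is a minimal standalone sketch (not part of the patch; the key name and tensor shape are invented for the example):

import torch

# Standalone copy of the DotDict idea from pippy/hf/utils.py, shown here only for illustration
class DotDict(dict):
    def __getattr__(self, attr):
        return self.get(attr)

# A traced HF model returns a plain dict, e.g. {"logits": <tensor>}; the shape below is made up
plain_output = {"logits": torch.randn(1, 8, 50272)}
wrapped = DotDict(plain_output)

# Both access styles now work; newer Transformers `generate()` code paths rely on the dot form
assert torch.equal(wrapped["logits"], wrapped.logits)
print(wrapped.logits.shape)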
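
The kwargs safety net added to PipelineDriver can be pictured in isolation with the standard-library `inspect` signature binding it builds on. This is a hedged sketch with invented parameter names and values, not the driver's actual code path:

import warnings
from inspect import Parameter, Signature

# Hypothetical traced-in parameters; in the real driver these come from the placeholder
# nodes of the traced pipeline graph
parameters = [Parameter("input_ids", Parameter.POSITIONAL_OR_KEYWORD)]
traced_args = [p.name for p in parameters]

# The caller passes an extra kwarg that was already made concrete at trace time
kwargs = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}

filtered_kwargs = {k: v for k, v in kwargs.items() if k in traced_args}
if len(filtered_kwargs) != len(kwargs):
    extra_args = kwargs.keys() - filtered_kwargs.keys()
    warnings.warn(f"Received extra arguments: {extra_args}. Ignoring them.")

sig = Signature(parameters)
bound_args = sig.bind(**filtered_kwargs)  # binding the unfiltered kwargs would raise TypeError here
bound_args.apply_defaults()
print(bound_args.args)  # ([[1, 2, 3]],)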