diff --git a/scripts/inference/README.md b/scripts/inference/README.md
index 1a958c28b..44e98f9fb 100644
--- a/scripts/inference/README.md
+++ b/scripts/inference/README.md
@@ -1 +1,195 @@
 # Inference scripts for BLOOM
+
+## BLOOM Inference solutions
+
+Here are some stats on JeanZay's 8x80GB A100 node w/ 512GB of CPU memory:
+
+All benchmarks perform greedy generation of 100-token outputs:
+```
+Generate args {'min_length': 100, 'max_length': 100, 'do_sample': False}
+```
+The inputs are just a few tokens.
+
+Throughput per generated token in msecs (lower is better):
+
+| project \ bs | 1      | 8     | 16    | 32    | 64   | 128  |
+| :----------- | :----- | :---- | :---- | :---- | :--- | :--- |
+| accelerate   | 230.38 | 31.78 | 17.84 | 10.89 | oom  | oom  |
+| ds-inference | 40.57  | 5.23  |       |       | 2.77 | 0.66 |
+| ds-zero      | 283    | 34.88 | oom   | oom   | oom  | oom  |
+
+
+Start to ready to generate in secs:
+
+| project \ bs | 1   | 8   | 16  | 32  | 64  | 128 |
+| :----------- | :-- | :-- | :-- | :-- | :-- | :-- |
+| accelerate   | 121 | 120 | 113 | 118 |     |     |
+| ds-inference | 662 | 673 |     |     | 685 | 654 |
+| ds-zero      | 462 | 463 |     |     |     |     |
+
+
+DS-Inference load time (start to ready to generate) will become much faster soon, once we stop relying on ds-zero to instantiate the model on GPU. The plan is to pre-shard the weights TP-wise for 8x and 16x GPUs and load them directly on each GPU, which should bring the load time to under 1 minute.
+
+
+## Deepspeed-Inference
+
+Uses Tensor-Parallelism and efficient fused CUDA kernels:
+https://www.deepspeed.ai/tutorials/inference-tutorial/
+
+### Setup
+
+```
+git clone https://github.com/microsoft/DeepSpeed
+cd DeepSpeed
+pip install .
+```
+
+### Run
+
+```
+deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom
+```
+
+Performance on a single node of 8x80GB A100 w/ 512GB CPU RAM (JeanZay), using just a batch of 1 (it would be more efficient to run a larger batch).
+
+Add `--benchmark` to activate the benchmarks.
+
+
+BS=1
+```
+$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-inference_bs=1.txt
+[...]
+
+```
+
+Memory usage per process while processing:
+
+- GPU: ~50GB
+- CPU: ~10GB
+
+
+BS=8
+```
+$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-inference_bs=8.txt
+[...]
+*** Performance stats:
+Throughput per token including tokenize: 5.23 msecs
+Start to ready to generate: 683.397 secs
+Tokenize and generate 800 (bs=8) tokens: 4.241 secs
+Start to finish: 687.638 secs
+```
+
+BS=64
+
+```
+$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 64 --benchmark 2>&1 | tee bloom-ds-inference_bs=64.txt
+
+```
+
+BS=128
+
+```
+$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 128 --benchmark 2>&1 | tee bloom-ds-inference_bs=128.txt
+
+```
+
+## Deepspeed ZeRO-Inference
+
+https://www.deepspeed.ai/tutorials/zero/
+
+### Setup
+
+```
+pip install deepspeed
+```
+
+
+### Run
+
+Note that the script currently runs the same inputs on all GPUs, but you can run a different stream on each GPU and get `n_gpu` times higher throughput. You can't do that with Deepspeed-Inference. A sketch of the idea follows.
+
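+For example, each rank could pick its own slice of a larger prompt pool instead of the shared `input_sentences[:args.batch_size]`. A minimal sketch in Python (the `prompt_pool` argument and `pick_prompts_for_rank` helper are illustrative names, not part of the current script):
+
+```
+def pick_prompts_for_rank(prompt_pool, batch_size, rank):
+    # give each rank its own contiguous slice of the pool, wrapping around
+    # when the pool has fewer than world_size * batch_size prompts
+    start = (rank * batch_size) % len(prompt_pool)
+    return [prompt_pool[(start + i) % len(prompt_pool)] for i in range(batch_size)]
+
+# in bloom-ds-zero-inference.py, something like:
+# inputs = pick_prompts_for_rank(input_sentences, args.batch_size, rank)
+```
+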
+BS=1
+
+```
+$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=1.txt
+[...]
+
+*** Performance stats:
+Throughput per token including tokenize: 282.93 msecs
+Start to ready to generate: 501.871 secs
+Tokenize and generate 800 (bs=1) tokens: 226.188 secs
+Start to finish: 728.060 secs
+```
+
+
+BS=8
+
+```
+$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=8.txt
+[...]
+
+*** Performance stats:
+Throughput per token including tokenize: 34.57 msecs
+Start to ready to generate: 482.132 secs
+Tokenize and generate 6400 (bs=8) tokens: 221.236 secs
+Start to finish: 703.368 secs
+```
+
+BS=16 and higher OOMs
+
+```
+$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=16.txt
+[...]
+OOM
+```
+
+
+## HF Accelerate
+
+https://github.com/huggingface/accelerate
+
+### Setup
+
+```
+pip install transformers accelerate
+```
+
+### Run
+
+BS=1
+```
+$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=1.txt
+[...]
+
+```
+
+BS=8
+```
+$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=8.txt
+[...]
+
+```
+
+BS=16
+```
+$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=16.txt
+[...]
+
+```
diff --git a/scripts/inference/bloom-accelerate-inference.py b/scripts/inference/bloom-accelerate-inference.py
new file mode 100644
index 000000000..415b2f765
--- /dev/null
+++ b/scripts/inference/bloom-accelerate-inference.py
@@ -0,0 +1,186 @@
+import argparse
+import time
+import os
+import gc
+import torch
+import math
+from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
+    parser.add_argument("--name", type=str, help="model name or path", required=True)
+    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
+    parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
+    parser.add_argument("--greedy", action="store_true")
+    parser.add_argument("--top-k", type=int, default=0)
+    parser.add_argument("--top-p", type=float, default=0.)
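+
+    # Note: --greedy/--top-k/--top-p are accepted but not yet wired into generate() below,
+    # which currently always decodes greedily. A hypothetical way to fold them into the
+    # generate kwargs later on (sketch only; `generate_kwargs` is built further down):
+    #
+    #   do_sample = not args.greedy and (args.top_k > 0 or args.top_p > 0.)
+    #   generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=do_sample,
+    #                          top_k=args.top_k, top_p=args.top_p)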
+ + return parser.parse_args() + +def get_max_memory_per_gpu_dict(dtype, model_name): + """ try to generate the memory map based on what we know about the model and the available hardware """ + + # figure out the memory map - the minimum per gpu required to load the model + n_gpus = torch.cuda.device_count() + + if model_name == "bigscience/bloom" and n_gpus == 8 and torch.cuda.get_device_properties(0).total_memory > 79*2**30: + # hand crafted optimized memory map for 8x80 setup over BLOOM + # this works with bs=40 + return {0: '0GIB', 1: '51GIB', 2: '51GIB', 3: '51GIB', 4: '51GIB', 5: '51GIB', 6: '51GIB', 7: '51GIB'} + + try: + # model_params calculation, as we don't have a model yet to do: + #model_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) + + config = AutoConfig.from_pretrained(model_name) + h = config.n_embed + l = config.n_layer + v = config.vocab_size + # from https://github.com/bigscience-workshop/bigscience/tree/6917a3b5fefcf439d3485ca184b4d9f6ab605150/math#model-sizing + model_params = l*(12*h**2 + 13*h) + v*h + 4*h + except: + print(f"The model {model_name} has a broken config file. Please notify the owner") + raise + + bytes = torch.finfo(dtype).bits / 8 + param_memory_total_in_bytes = model_params * bytes + # add 5% since weight sizes aren't the same and some GPU may need more memory + param_memory_per_gpu_in_bytes = int(param_memory_total_in_bytes / n_gpus * 1.05) + print(f"Estimating {param_memory_per_gpu_in_bytes/2**30:0.2f}GB per gpu for weights") + + # check the real available memory + # load cuda kernels first and only measure the real free memory after loading (shorter by ~2GB) + torch.ones(1).cuda() + max_memory_per_gpu_in_bytes = torch.cuda.mem_get_info(0)[0] + if max_memory_per_gpu_in_bytes < param_memory_per_gpu_in_bytes: + raise ValueError(f"Unable to generate the memory map automatically as the needed estimated memory per gpu ({param_memory_per_gpu_in_bytes/2**30:0.2f}GB) is bigger than the available per gpu memory ({max_memory_per_gpu_in_bytes/2**30:0.2f}GB)") + + return {i: param_memory_per_gpu_in_bytes for i in range(torch.cuda.device_count())} + +t_start = time.time() + +num_tokens = 100 + +args = get_args() + +local_rank = int(os.getenv('LOCAL_RANK', '0')) +world_size = int(os.getenv('WORLD_SIZE', '1')) + +rank = local_rank + +model_name = args.name +if rank == 0: + print(f"Loading model {model_name}") + + +tokenizer = AutoTokenizer.from_pretrained(model_name) + +# XXX: can't automatically derive dtype via config's `from_pretrained` +dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16 + +#print(get_max_memory_per_gpu_dict()) + + +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + max_memory=get_max_memory_per_gpu_dict(dtype, model_name), + torch_dtype=dtype, +) + + +if args.benchmark: + t_ready = time.time() + + + +### Generate + +if rank == 0: + print(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}") + +input_sentences = [ + "DeepSpeed is a machine learning framework", + "He is working on", + "He has a", + "He got all", + "Everyone is happy and I can", + "The new movie that got Oscar this year", + "In the far far distance from our galaxy,", + "Peace is the only way" +] + +if args.batch_size > len(input_sentences): + # dynamically extend to support larger bs by repetition + input_sentences *= math.ceil(args.batch_size / len(input_sentences)) + +generate_kwargs = dict(max_new_tokens=num_tokens, 
do_sample=False) +#generate_kwargs = dict(max_new_tokens=num_tokens, use_cache=False, do_sample=False) +#generate_kwargs = dict(min_length=num_tokens, max_length=num_tokens, do_sample=False) + +if rank == 0: + print(f"Generate args {generate_kwargs}") +inputs = input_sentences[:args.batch_size] +def generate(): + """ returns a list of zipped inputs, outputs and number of new tokens """ + + input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to("cuda:0") + + outputs = model.generate(**input_tokens, **generate_kwargs) + + input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids] + output_tokens_lengths = [x.shape[0] for x in outputs] + + total_new_tokens = [o-i for i,o in zip(input_tokens_lengths, output_tokens_lengths)] + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return zip(inputs, outputs, total_new_tokens) + +# warmup is a must if measuring speed as it's when all the optimizations are performed +# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs +_ = generate() + +t_generate_start = time.time() +generated = generate() +t_generate_span = time.time() - t_generate_start +if rank == 0: + for i,o,_ in generated: + print(f"{'-'*60}\nin={i}\nout={o}\n") + + +if args.benchmark: + torch.cuda.empty_cache() + gc.collect() + +### Benchmark + +if args.benchmark: + if rank == 0: + print(f"*** Running benchmark") + + # warm up + for i in range(1): + _ = generate() + torch.cuda.synchronize() + + # benchmark + t0 = time.time() + cycles = 5 + total_new_tokens_generated = 0 + for i in range(cycles): + generated = generate() + total_new_tokens_generated += sum(new_tokens for _,_,new_tokens in generated) + torch.cuda.synchronize() + if rank == 0: + througput = (time.time() - t0)/(total_new_tokens_generated) + print(f""" +*** Performance stats: +Throughput per token including tokenize: {througput*1000:.2f} msecs +Start to ready to generate: {t_ready - t_start:.3f} secs +Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs +Start to finish: {t_ready - t_start + t_generate_span:.3f} secs +""") diff --git a/scripts/inference/bloom-ds-inference.py b/scripts/inference/bloom-ds-inference.py new file mode 100644 index 000000000..c21dfeb96 --- /dev/null +++ b/scripts/inference/bloom-ds-inference.py @@ -0,0 +1,299 @@ +# usage: +# deepspeed --num_gpus 8 bloom-ds-inference.py --name bigscience/bloom +# +# to run benchmarks: +# deepspeed --num_gpus 8 bloom-ds-inference.py --name bigscience/bloom --benchmark +# + + +# This is going to improve, but at the moment, the process is a bit cumbersome - we first use +# 1. use Deepspeed-ZeRO to instantiate the model on GPUs, w/o loading the checkpoints, +# 2. free the allocated storage +# 3. start Deepspeed-Inference and only now load the checkpoint +# 4. run generate +# Done. 
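+#
+# For reference, step 3 hands Deepspeed-Inference a small checkpoints.json descriptor,
+# which write_checkponts_json() below generates from the cached checkpoint shards.
+# It looks roughly like this (the paths are illustrative):
+#
+#   {
+#     "type": "BLOOM-176B",
+#     "checkpoints": ["/path/to/cache/pytorch_model-00001-of-000NN.bin", "..."],
+#     "version": 1.0
+#   }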
+# + + +from argparse import ArgumentParser +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers.deepspeed import HfDeepSpeedConfig +from transformers.models.bloom.modeling_bloom import BloomBlock as BloomBlock +import deepspeed +import gc +import glob +import io +import json +import math +import os +import sys +import time +import torch +import torch.distributed as dist + +t_start = time.time() + +num_tokens = 100 + +parser = ArgumentParser() + +parser.add_argument("--name", required=True, type=str, help="model_name") +parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers") +parser.add_argument("--batch_size", default=1, type=int, help="batch size") +parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark") +args = parser.parse_args() + +local_rank = int(os.getenv('LOCAL_RANK', '0')) +world_size = int(os.getenv('WORLD_SIZE', '1')) + +deepspeed.init_distributed('nccl') +rank = dist.get_rank() + + +### Model loading and instantiating on GPUs + +def get_checkpoint_files(pretrained_model_name_or_path): + # XXX: I just hacked this one together to automatically handle the fetching of the model file or + # shards into cache and returning the cached entries - note that I removed most arguments + + from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, cached_path, hf_bucket_url, is_offline_mode + from transformers.utils.hub import EntryNotFoundError + from transformers.modeling_utils import get_checkpoint_shard_files + + cache_dir = None + is_sharded = False + + # XXX: preparation for revision branches if needed + revision = None + #revision = "sharded" + + # this supports nodes with no network (so you need to pre-cache the model and the tokenizer with + # python -c "from transformers import AutoModel; AutoModel.from_pretrained('bigscience/bloom')" + if is_offline_mode(): + print("Offline mode: forcing local_files_only=True") + local_files_only = True + else: + local_files_only = False + + filename = WEIGHTS_NAME + archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=filename, revision=revision) + + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, local_files_only=local_files_only,) + return [resolved_archive_file] + + except (EntryNotFoundError, FileNotFoundError): + if filename == WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=WEIGHTS_INDEX_NAME, + revision=revision, + ) + resolved_archive_file = cached_path( + archive_file, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + is_sharded = True + + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. 
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + revision=revision + ) + + return resolved_archive_file + +model_name = args.name + +#print(get_checkpoint_files(model_name)) + +if rank == 0: + print(f"*** Loading the model {model_name}") + +tokenizer = AutoTokenizer.from_pretrained(model_name) +config = AutoConfig.from_pretrained(model_name) + +# XXX: can't automatically derive dtype via config's `from_pretrained` +#dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16 + + +# use one of these args to `init_inference` +# 1. injection_policy is the slower version, but it's plain pytorch so it'll always work +# 2. replace_with_kernel_inject is the faster one (fast fused kernels) +kernel_inject = True +#kernel_inject = False + +if kernel_inject: + # XXX: for now ds-inference only works with fp16 + dtype = torch.float16 +else: + dtype = torch.bfloat16 + +if args.benchmark: + torch.cuda.empty_cache() + gc.collect() + deepspeed.runtime.utils.see_memory_usage('pre-from-pretrained', force=True) + +# Construct model with fake meta tensors, later will be replaced during ds-inference ckpt load +with deepspeed.OnDevice(dtype=dtype, device='meta'): + model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) + +if args.benchmark: + deepspeed.runtime.utils.see_memory_usage('post-from-pretrained', force=True) + +model = model.eval() + + +if args.benchmark: + torch.cuda.empty_cache() + gc.collect() + deepspeed.runtime.utils.see_memory_usage('post-init-ds-zero-init', force=True) + +### Deepspeed-Inference Loading + +checkpoints_json = "checkpoints.json" +def write_checkponts_json(): + + with io.open(checkpoints_json, 'w', encoding='utf-8') as f: + + #checkpoint_dir = "/gpfsscratch/rech/six/commun/uan68tv-model-conversion/bloom" + #checkpoint_files = glob.glob(f"{checkpoint_dir}/*bin") + checkpoint_files = get_checkpoint_files(model_name) + + #print("Checkpoint files:", checkpoint_files) + + data = { + "type": "BLOOM-176B", + "checkpoints": checkpoint_files, + "version": 1.0 + } + json.dump(data, f) + +if rank == 0: + write_checkponts_json() +dist.barrier() + +if args.benchmark: + torch.cuda.empty_cache() + gc.collect() + deepspeed.runtime.utils.see_memory_usage('pre-ds-inference-init', force=True) + +if kernel_inject: + kwargs = dict(replace_with_kernel_inject=True) +else: + kwargs = dict(injection_policy={BloomBlock: ('self_attention.dense', 'mlp.dense_4h_to_h')}) + +#checkpoints_json=None +model = deepspeed.init_inference(model, + mp_size=world_size, + dtype=torch.half, + checkpoint=checkpoints_json, + **kwargs, + ) + +if args.benchmark: + torch.cuda.empty_cache() + gc.collect() + deepspeed.runtime.utils.see_memory_usage('post-ds-inference-init', force=True) + + +model = model.module + +if args.benchmark: + t_ready = time.time() + + +### Generate + +if rank == 0: + print(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}") + +input_sentences = [ + "DeepSpeed is a machine learning framework", + "He is working on", + "He has a", + "He got all", + "Everyone is happy and I can", + "The new movie that got Oscar this year", + "In the far far distance from our galaxy,", + "Peace is the only way" +] + +if args.batch_size > len(input_sentences): + # dynamically extend to support larger bs by repetition + input_sentences *= math.ceil(args.batch_size / len(input_sentences)) + +generate_kwargs = 
dict(max_new_tokens=num_tokens, do_sample=False)
+
+if rank == 0:
+    print(f"Generate args {generate_kwargs}")
+
+inputs = input_sentences[:args.batch_size]
+
+def generate():
+    """ returns a list of zipped inputs, outputs and number of new tokens """
+
+    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
+    for t in input_tokens:
+        if torch.is_tensor(input_tokens[t]):
+            input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
+
+    outputs = model.generate(**input_tokens, **generate_kwargs)
+
+    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
+    output_tokens_lengths = [x.shape[0] for x in outputs]
+
+    total_new_tokens = [o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)]
+    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+    return zip(inputs, outputs, total_new_tokens)
+
+
+# warmup is a must if measuring speed as it's when all the optimizations are performed
+# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
+_ = generate()
+
+t_generate_start = time.time()
+generated = generate()
+t_generate_span = time.time() - t_generate_start
+if rank == 0:
+    for i, o, _ in generated:
+        print(f"{'-'*60}\nin={i}\nout={o}\n")
+
+if args.benchmark:
+    torch.cuda.empty_cache()
+    gc.collect()
+    deepspeed.runtime.utils.see_memory_usage('end-of-run', force=True)
+
+### Benchmark
+
+if args.benchmark:
+    if rank == 0:
+        print(f"*** Running benchmark")
+
+    # warm up
+    for i in range(1):
+        _ = generate()
+    torch.cuda.synchronize()
+
+    # benchmark
+    t0 = time.time()
+    cycles = 5
+    total_new_tokens_generated = 0
+    for i in range(cycles):
+        generated = generate()
+        total_new_tokens_generated += sum(new_tokens for _, _, new_tokens in generated)
+    torch.cuda.synchronize()
+    if rank == 0:
+        throughput = (time.time() - t0) / total_new_tokens_generated
+        print(f"""
+*** Performance stats:
+Throughput per token including tokenize: {throughput*1000:.2f} msecs
+Start to ready to generate: {t_ready - t_start:.3f} secs
+Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
+Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
+""")
diff --git a/scripts/inference/bloom-ds-zero-inference.py b/scripts/inference/bloom-ds-zero-inference.py
new file mode 100644
index 000000000..043b4967f
--- /dev/null
+++ b/scripts/inference/bloom-ds-zero-inference.py
@@ -0,0 +1,211 @@
+# usage:
+# deepspeed --num_gpus 8 bloom-ds-zero-inference.py --name bigscience/bloom
+#
+# to run benchmarks:
+# deepspeed --num_gpus 8 bloom-ds-zero-inference.py --name bigscience/bloom --benchmark
+#
+
+
+# This script runs inference via Deepspeed ZeRO stage 3, which shards the model across all
+# participating GPUs:
+# 1. build a ZeRO-3 inference config (optionally with CPU offload of the parameters)
+# 2. let `from_pretrained` instantiate the model directly on the GPUs via HfDeepSpeedConfig
+# 3. wrap the model with deepspeed.initialize and run generate through the ZeRO engine
+# Done.
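+#
+# Tip: to additionally offload the weights to CPU (uses far less GPU memory, typically at the
+# cost of speed; this toggles the ZeRO "offload_param" config entry below):
+# deepspeed --num_gpus 8 bloom-ds-zero-inference.py --name bigscience/bloom --cpu_offload --benchmark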
+# + + +from argparse import ArgumentParser +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers.deepspeed import HfDeepSpeedConfig +from transformers.models.bloom.modeling_bloom import BloomBlock as BloomBlock +import deepspeed +import gc +import glob +import io +import json +import math +import os +import sys +import time +import torch +import torch.distributed as dist + +t_start = time.time() + +num_tokens = 100 + +parser = ArgumentParser() + +parser.add_argument("--name", required=True, type=str, help="model_name") +parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers") +parser.add_argument("--batch_size", default=1, type=int, help="batch size") +parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark") +parser.add_argument("--cpu_offload", action="store_true", help="whether to activate CPU offload") +args = parser.parse_args() + +local_rank = int(os.getenv('LOCAL_RANK', '0')) +world_size = int(os.getenv('WORLD_SIZE', '1')) + + +### Model loading and instantiating on GPU (via ZeRO) + +model_name = args.name + +if local_rank == 0: + print(f"*** Loading the model {model_name}") + +tokenizer = AutoTokenizer.from_pretrained(model_name) +config = AutoConfig.from_pretrained(model_name) + +# XXX: can't automatically derive dtype via config's `from_pretrained` +dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16 + +model_hidden_size = config.hidden_size +train_batch_size = 1 * world_size + +ds_config = { + "fp16": { + "enabled": dtype == torch.float16, + }, + "bf16": { + "enabled": dtype == torch.bfloat16, + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": True, + "contiguous_gradients": True, + "reduce_bucket_size": model_hidden_size * model_hidden_size, + "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size, + "stage3_param_persistence_threshold": 0 + }, + "steps_per_print": 2000, + "train_batch_size": train_batch_size, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": False +} + +if args.cpu_offload: + ds_config["zero_optimization"]["offload_param"] = dict(device="cpu", pin_memory=True) + +dschf = HfDeepSpeedConfig(ds_config) # this tells from_pretrained to instantiate directly on gpus + +if args.benchmark: + torch.cuda.empty_cache() + gc.collect() + deepspeed.runtime.utils.see_memory_usage('pre-from-pretrained', force=True) + +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) + +if args.benchmark: + deepspeed.runtime.utils.see_memory_usage('post-from-pretrained', force=True) + +model = model.eval() + +rank = dist.get_rank() + +if rank == 0: + print(ds_config) + +ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0] +ds_engine.module.eval() +model = ds_engine.module + +if args.benchmark: + t_ready = time.time() + + +### Generate + +if rank == 0: + print(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}") + +input_sentences = [ + "DeepSpeed is a machine learning framework", + "He is working on", + "He has a", + "He got all", + "Everyone is happy and I can", + "The new movie that got Oscar this year", + "In the far far distance from our galaxy,", + "Peace is the only way" +] + +if args.batch_size > len(input_sentences): + # dynamically extend to support larger bs by repetition + input_sentences *= math.ceil(args.batch_size / len(input_sentences)) + +generate_kwargs = 
dict(max_new_tokens=num_tokens, do_sample=False) + +if rank == 0: + print(f"Generate args {generate_kwargs}") +inputs = input_sentences[:args.batch_size] +def generate(): + """ returns a list of zipped inputs, outputs and number of new tokens """ + + input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + + outputs = model.generate(**input_tokens, **generate_kwargs) + + input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids] + output_tokens_lengths = [x.shape[0] for x in outputs] + + total_new_tokens = [o-i for i,o in zip(input_tokens_lengths, output_tokens_lengths)] + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return zip(inputs, outputs, total_new_tokens) + +# XXX: this is currently doing world_size streams on world_size gpus, so we can feed it different inputs on each! and hence the time can be divided by world_size + +# warmup is a must if measuring speed as it's when all the optimizations are performed +# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs +_ = generate() + +t_generate_start = time.time() +pairs = generate() +t_generate_span = time.time() - t_generate_start +if rank == 0: + for i,o,_ in pairs: + print(f"{'-'*60}\nin={i}\nout={o}\n") + + +if args.benchmark: + torch.cuda.empty_cache() + gc.collect() + deepspeed.runtime.utils.see_memory_usage('end-of-run', force=True) + +### Benchmark + +if args.benchmark: + if rank == 0: + print(f"*** Running benchmark") + + # warm up + for i in range(1): + _ = generate() + torch.cuda.synchronize() + + # benchmark + t0 = time.time() + cycles = 5 + total_new_tokens_generated = 0 + for i in range(cycles): + generated = generate() + total_new_tokens_generated += sum(new_tokens for _,_,new_tokens in generated) + + torch.cuda.synchronize() + if rank == 0: + # note that we actually generate world_size unique streams (though the benchmark feeds the same inputs) + total_new_tokens_generated *= world_size + througput = (time.time() - t0)/(total_new_tokens_generated) + print(f""" +*** Performance stats: +Throughput per token including tokenize: {througput*1000:.2f} msecs +Start to ready to generate: {t_ready - t_start:.3f} secs +Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs +Start to finish: {t_ready - t_start + t_generate_span:.3f} secs +""") + diff --git a/scripts/inference/bloom-inference.py b/scripts/inference/bloom-inference.py deleted file mode 100644 index 17da46795..000000000 --- a/scripts/inference/bloom-inference.py +++ /dev/null @@ -1,153 +0,0 @@ - -# usage: -# deepspeed --num_gpus 1 bloom-inference.py --name bigscience/bloom-350m -# - -#import glob -from argparse import ArgumentParser -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig -from transformers.deepspeed import HfDeepSpeedConfig -from transformers.models.bloom.modeling_bloom import BloomBlock as BloomBlock -import deepspeed -import io -import json -import os -import torch -import torch.distributed as dist - -parser = ArgumentParser() - -parser.add_argument("--name", required=True, type=str) -parser.add_argument("--local_rank", required=False, type=int) -parser.add_argument("--deepspeed", action="store_true") -args = parser.parse_args() - -local_rank = int(os.getenv('LOCAL_RANK', '0')) -world_size = int(os.getenv('WORLD_SIZE', '1')) - -def 
get_checkpoint_files(pretrained_model_name_or_path): - # XXX: I just hacked this one together to automatically handle the fetching of the model file or - # shards into cache and returning the cached entries - note that I removed most arguments - - from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, cached_path, hf_bucket_url - - cache_dir = None - is_sharded = False - filename = WEIGHTS_NAME - archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=filename) - - try: - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - return [resolved_archive_file] - - except EntryNotFoundError: - if filename == WEIGHTS_NAME: - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - archive_file = hf_bucket_url( - pretrained_model_name_or_path, - filename=WEIGHTS_INDEX_NAME, - ) - resolved_archive_file = cached_path( - archive_file, - cache_dir=cache_dir, - ) - is_sharded = True - - if is_sharded: - # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path, - resolved_archive_file, - cache_dir=cache_dir, - ) - - return resolved_archive_file - - -model_name = args.name - -tokenizer = AutoTokenizer.from_pretrained(model_name) -config = AutoConfig.from_pretrained(model_name) -model_hidden_size = config.hidden_size -train_batch_size = 1 * world_size -model = AutoModelForCausalLM.from_config(config) - -ds_config = { - "fp16": { - "enabled": model.dtype == torch.float16, - }, - "bf16": { - "enabled": model.dtype == torch.bfloat16, - }, - "zero_optimization": { - "stage": 3, - "offload_param": { - "device": "cpu", - "pin_memory": True - }, - "overlap_comm": True, - "contiguous_gradients": True, - "reduce_bucket_size": model_hidden_size * model_hidden_size, - "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size, - "stage3_param_persistence_threshold": 0 - }, - "steps_per_print": 2000, - "train_batch_size": train_batch_size, - "train_micro_batch_size_per_gpu": 1, - "wall_clock_breakdown": False -} - -dschf = HfDeepSpeedConfig(ds_config) - -model = model.eval() -ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0] -ds_engine.module.eval() -model = ds_engine.module - - - -checkpoints_json = "checkpoints.json" -with io.open(checkpoints_json, 'w', encoding='utf-8') as f: - - #checkpoint_files = glob.glob(f"args.checkpoint_dir/*bin") - checkpoint_files = get_checkpoint_files(model_name) - - print("Checkpoint files:", checkpoint_files) - - data = { - "type": "BLOOM-176B", - "checkpoints": checkpoint_files, - "version": 1.0 - } - json.dump(data, f) - - -model = deepspeed.init_inference(model, - mp_size=1, - dtype=torch.half, - checkpoint=checkpoints_json, - #injection_policy={BloomBlock: ('self_attention.dense', 'mlp.dense_4h_to_h')} - replace_with_kernel_inject=True - ) -model = model.module - -text_in = 'DeepSpeed is' - -tokens = tokenizer(text_in, return_tensors="pt") - -for t in tokens: - if torch.is_tensor(tokens[t]): - tokens[t] = tokens[t].to(torch.cuda.current_device()) - -with torch.no_grad(): - gen_tokens = model.generate( - **tokens, - min_length=50, - max_length=50, - do_sample=False, - ) - - -text_out = tokenizer.batch_decode(gen_tokens)[0] - -print(f"in={text_in}\nout={text_out}")