BLOOM Inference via DeepSpeed-Inference, Accelerate and DeepSpeed-ZeRO (bigscience-workshop#308)

* hardcode the dtype depending on the model

* change the mp based on the world_size

* remove hardcoded world_size

* add bigscience/bigscience-small-testing

* fixes

* add zero-inference script

* fixes

* fix

* working script

* renames

* fixes

* fix for offline use

* add benchmark

* add benchmark

* update

* cleanup

* update

* msecs

* cleanup

* improve

* fix benchmark, add warmup

* update

* fix; thanks Michael Wyatt

* clarify

* add bloom batch-inference script

* removed the names :-)

* fold the bs functionality from the other script

* fix

* restore do_sample

* dump generate args

* fix

* fix

* support any batchsize

* div by bs

* mul by bs

* add cpu_offload; sync scripts

* wip

* improvements

* fixes

* fixes

* add accelerate script

* fix

* wip

* wip

* stats

* add OnDevice and remove zero-inference (bigscience-workshop#316)

* wip

* rework generate + benchmark

* figure out the memory map dynamically

* bug fix

* fix ds-zero-inference wrt device

* bug fix

* update

* update

* fix

Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
3 people authored and younesbelkada committed Sep 28, 2022
1 parent c9f196e commit e52d34c
Showing 5 changed files with 890 additions and 153 deletions.
194 changes: 194 additions & 0 deletions scripts/inference/README.md
@@ -1 +1,195 @@
# Inference scripts for BLOOM

## BLOOM Inference solutions

Here are some stats on JeanZay's 8x80GB A100 node w/ 512GB of CPU memory:

All benchmarks perform greedy generation of 100-token outputs:
```
Generate args {'min_length': 100, 'max_length': 100, 'do_sample': False}
```
The inputs are just a few tokens.
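
As a rough sketch (assuming `model` and `input_tokens` are already set up as in the scripts below), the call being timed is essentially:

```
outputs = model.generate(**input_tokens, min_length=100, max_length=100, do_sample=False)
```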

Throughput per token in msecs:

| project \ bs | 1 | 8 | 16 | 32 | 64 | 128 |
| :----------- | :---- | :---- | :---- | :---- | :---- | :--- |
| accelerate   | 230.38 | 31.78 | 17.84 | 10.89 | oom | oom |
| ds-inference | 40.57 | 5.23 | | | 2.77 | 0.66 |
| ds-zero | 283 | 34.88 | oom | oom | oom | oom |
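
These numbers are the total generation time divided by the total number of new tokens across the whole batch. For example, ds-inference's 0.66 msecs/token at bs=128 works out to roughly 1/0.00066 ≈ 1500 tokens generated per second in aggregate.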


Start to ready to generate in secs:

| project \ bs | 1 | 8 | 16 | 32 | 64 | 128 |
| :----------- | :--- | :--- | :--- | :--- | :--- | :--- |
| accelerate | 121 | 120 | 113 | 118 | | |
| ds-inference | 662 | 673 | | | 685 | 654 |
| ds-zero | 462 | 463 | | | | |


DS-Inference load time (start to ready to generate) will become much faster soon, once we stop relying on ds-zero to instantiate the model on the GPUs. The plan is to pre-shard the weights TP-wise for 8x and 16x GPUs and load them directly on each GPU, which will probably bring the load time under 1 min.


## Deepspeed-Inference

Tensor-Parallelism and efficient fused CUDA kernels:
https://www.deepspeed.ai/tutorials/inference-tutorial/

### Setup

```
git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
pip install .
```

### Run

```
deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom
```

Performance on a single node of 8x80GB A100s w/ 512GB CPU RAM (JeanZay), using just a batch size of 1 (it would be more efficient to run a larger batch).

Add `--benchmark` to activate the benchmark reporting.


BS=1
```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-inference_bs=1.txt
[...]
```

Memory usage per process while processing:

- GPU: ~50GB
- CPU: ~10GB


BS=8
```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-inference_bs=8.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 5.23 msecs
Start to ready to generate: 683.397 secs
Tokenize and generate 800 (bs=8) tokens: 4.241 secs
Start to finish: 687.638 secs
```

BS=64

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 64 --benchmark 2>&1 | tee bloom-ds-inference_bs=64.txt
```

BS=128

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 128 --benchmark 2>&1 | tee bloom-ds-inference_bs=128.txt
```

## Deepspeed ZeRO-Inference

https://www.deepspeed.ai/tutorials/zero/

ZeRO-Inference shards the model weights across all participating GPUs (ZeRO stage 3) and gathers them layer by layer during the forward pass, so no single GPU needs to hold the full model.

### Setup

```
pip install deepspeed
```


### Run

Note that the script currently runs the same inputs on all GPUs. However, you could feed a different stream of inputs to each GPU and get `n_gpu` times the aggregate throughput, something that isn't possible with Deepspeed-Inference. A minimal sketch of the idea follows.
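
This is only a sketch (not what the shipped script does), relying on the `LOCAL_RANK`/`WORLD_SIZE` environment variables that the `deepspeed` launcher already sets:

```
import os

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

# give each rank its own round-robin slice of the prompts,
# so aggregate throughput scales roughly with the number of GPUs
all_inputs = ["prompt 1", "prompt 2", "prompt 3", "prompt 4"]
my_inputs = all_inputs[local_rank::world_size]
```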


BS=1

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=1.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 282.93 msecs
Start to ready to generate: 501.871 secs
Tokenize and generate 800 (bs=1) tokens: 226.188 secs
Start to finish: 728.060 secs
```


BS=8

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=8.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 34.57 msecs
Start to ready to generate: 482.132 secs
Tokenize and generate 6400 (bs=8) tokens: 221.236 secs
Start to finish: 703.368 secs
```

BS=16 and higher OOMs

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=16.txt
[...]
OOM
```



## HF Accelerate

https://github.com/huggingface/accelerate

Accelerate's `device_map="auto"` support splits the model's layers across all available GPUs (and CPU, if needed) and runs them as a naive pipeline.

### Setup

```
pip install transformers accelerate
```



### Run




BS=1
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=1.txt
[...]
```

BS=8
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=8.txt
[...]
```

BS=16
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=16.txt
[...]
```
186 changes: 186 additions & 0 deletions scripts/inference/bloom-accelerate-inference.py
@@ -0,0 +1,186 @@
import argparse
import time
import os
import gc
import torch
import math
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
    parser.add_argument("--name", type=str, help="Name path", required=True)
    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
    parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--top-p", type=float, default=0.)

    return parser.parse_args()

def get_max_memory_per_gpu_dict(dtype, model_name):
    """ try to generate the memory map based on what we know about the model and the available hardware """

    # figure out the memory map - the minimum per gpu required to load the model
    n_gpus = torch.cuda.device_count()

    if model_name == "bigscience/bloom" and n_gpus == 8 and torch.cuda.get_device_properties(0).total_memory > 79*2**30:
        # hand crafted optimized memory map for 8x80 setup over BLOOM
        # this works with bs=40
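        # gpu0 gets 0GIB of weights on purpose, presumably to leave it room for activations and the generated outputs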
        return {0: '0GIB', 1: '51GIB', 2: '51GIB', 3: '51GIB', 4: '51GIB', 5: '51GIB', 6: '51GIB', 7: '51GIB'}

    try:
        # model_params calculation, as we don't have a model yet to do:
        #model_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

        config = AutoConfig.from_pretrained(model_name)
        h = config.n_embed
        l = config.n_layer
        v = config.vocab_size
        # from https://github.com/bigscience-workshop/bigscience/tree/6917a3b5fefcf439d3485ca184b4d9f6ab605150/math#model-sizing
        model_params = l*(12*h**2 + 13*h) + v*h + 4*h
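        # e.g. for bigscience/bloom (l=70, h=14336, v=250880) this comes out to ~176B params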
    except:
        print(f"The model {model_name} has a broken config file. Please notify the owner")
        raise

    bytes = torch.finfo(dtype).bits / 8
    param_memory_total_in_bytes = model_params * bytes
    # add 5% since weight sizes aren't the same and some GPU may need more memory
    param_memory_per_gpu_in_bytes = int(param_memory_total_in_bytes / n_gpus * 1.05)
    print(f"Estimating {param_memory_per_gpu_in_bytes/2**30:0.2f}GB per gpu for weights")

    # check the real available memory
    # load cuda kernels first and only measure the real free memory after loading (shorter by ~2GB)
    torch.ones(1).cuda()
    max_memory_per_gpu_in_bytes = torch.cuda.mem_get_info(0)[0]
    if max_memory_per_gpu_in_bytes < param_memory_per_gpu_in_bytes:
        raise ValueError(f"Unable to generate the memory map automatically as the needed estimated memory per gpu ({param_memory_per_gpu_in_bytes/2**30:0.2f}GB) is bigger than the available per gpu memory ({max_memory_per_gpu_in_bytes/2**30:0.2f}GB)")

    return {i: param_memory_per_gpu_in_bytes for i in range(torch.cuda.device_count())}

t_start = time.time()

num_tokens = 100

args = get_args()

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

rank = local_rank

model_name = args.name
if rank == 0:
    print(f"Loading model {model_name}")


tokenizer = AutoTokenizer.from_pretrained(model_name)

# XXX: can't automatically derive dtype via config's `from_pretrained`
dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16
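# the bloom checkpoints listed here are kept in bfloat16; anything else falls back to float16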

#print(get_max_memory_per_gpu_dict())


model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    max_memory=get_max_memory_per_gpu_dict(dtype, model_name),
    torch_dtype=dtype,
)


if args.benchmark:
    t_ready = time.time()



### Generate

if rank == 0:
    print(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")

input_sentences = [
    "DeepSpeed is a machine learning framework",
    "He is working on",
    "He has a",
    "He got all",
    "Everyone is happy and I can",
    "The new movie that got Oscar this year",
    "In the far far distance from our galaxy,",
    "Peace is the only way"
]

if args.batch_size > len(input_sentences):
    # dynamically extend to support larger bs by repetition
    input_sentences *= math.ceil(args.batch_size / len(input_sentences))

generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)
#generate_kwargs = dict(max_new_tokens=num_tokens, use_cache=False, do_sample=False)
#generate_kwargs = dict(min_length=num_tokens, max_length=num_tokens, do_sample=False)

if rank == 0:
    print(f"Generate args {generate_kwargs}")
inputs = input_sentences[:args.batch_size]
def generate():
    """ returns a list of zipped inputs, outputs and number of new tokens """

    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to("cuda:0")

    outputs = model.generate(**input_tokens, **generate_kwargs)

    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]
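    # note: with padding=True every row of input_ids shares the same padded length, so these lengths are per padded sequence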

    total_new_tokens = [o-i for i,o in zip(input_tokens_lengths, output_tokens_lengths)]
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return zip(inputs, outputs, total_new_tokens)

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
_ = generate()

t_generate_start = time.time()
generated = generate()
t_generate_span = time.time() - t_generate_start
if rank == 0:
    for i,o,_ in generated:
        print(f"{'-'*60}\nin={i}\nout={o}\n")


if args.benchmark:
    torch.cuda.empty_cache()
    gc.collect()

### Benchmark

if args.benchmark:
    if rank == 0:
        print(f"*** Running benchmark")

    # warm up
    for i in range(1):
        _ = generate()
    torch.cuda.synchronize()

    # benchmark
    t0 = time.time()
    cycles = 5
    total_new_tokens_generated = 0
    for i in range(cycles):
        generated = generate()
        total_new_tokens_generated += sum(new_tokens for _,_,new_tokens in generated)
    torch.cuda.synchronize()
    if rank == 0:
        throughput = (time.time() - t0)/(total_new_tokens_generated)
        print(f"""
*** Performance stats:
Throughput per token including tokenize: {throughput*1000:.2f} msecs
Start to ready to generate: {t_ready - t_start:.3f} secs
Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
""")
