diff --git a/bloom-inference-scripts/bloom-ds-inference.py b/bloom-inference-scripts/bloom-ds-inference.py
index 4bed6a2..3031426 100644
--- a/bloom-inference-scripts/bloom-ds-inference.py
+++ b/bloom-inference-scripts/bloom-ds-inference.py
@@ -44,7 +44,7 @@
 parser = ArgumentParser()
 
 parser.add_argument("--name", required=True, type=str, help="model_name")
-parser.add_argument("--dtype", type=str, help="float16 or int8", choices=["int8", "float16"], default="float16")
+parser.add_argument("--dtype", type=str, help="float16 or int8 or int4", choices=["int8", "float16", "int4"], default="float16")
 parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
 parser.add_argument("--batch_size", default=1, type=int, help="batch size")
 parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
@@ -100,7 +100,7 @@ def get_checkpoint_files(model_name_or_path):
 
 model_name = args.name
 
-infer_dtype = args.dtype
+infer_dtype = args.dtype if args.dtype != 'int4' else 'int8'
 
 tp_presharded_mode = True if model_name in tp_presharded_models else False
 
@@ -171,7 +171,12 @@ def write_checkponts_json():
 deepspeed.runtime.utils.see_memory_usage("pre-ds-inference-init", force=True)
 
 if kernel_inject:
-    kwargs = dict(replace_with_kernel_inject=True)
+    ds_config = {"replace_with_kernel_inject": True}
+    # Weight quantization only applies to the int dtypes: int8 -> 8 bits,
+    # int4 -> 4 bits.  float16 skips the quant section entirely (otherwise
+    # `bits` would be unbound and num_bits would raise a NameError).
+    if args.dtype in ("int8", "int4"):
+        bits = 8 if args.dtype == "int8" else 4
+        ds_config["quant"] = {"enabled": True, "weight": {"num_bits": bits}}
 else:
-    kwargs = dict(injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")})
+    # config=ds_config is passed unconditionally below, so this branch must
+    # define ds_config as well (the old `kwargs` dict is no longer consumed).
+    ds_config = {"injection_policy": {BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}}
 
@@ -188,6 +193,7 @@ def write_checkponts_json():
 # checkpoints_json=None
 model = deepspeed.init_inference(
     model,
+    config=ds_config,
     mp_size=world_size,
     base_dir=repo_root,
     dtype=getattr(torch, infer_dtype),