diff --git a/bloom-inference-scripts/bloom-ds-inference.py b/bloom-inference-scripts/bloom-ds-inference.py
index 4bed6a2..3031426 100644
--- a/bloom-inference-scripts/bloom-ds-inference.py
+++ b/bloom-inference-scripts/bloom-ds-inference.py
@@ -44,7 +44,7 @@
 parser = ArgumentParser()
 
 parser.add_argument("--name", required=True, type=str, help="model_name")
-parser.add_argument("--dtype", type=str, help="float16 or int8", choices=["int8", "float16"], default="float16")
+parser.add_argument("--dtype", type=str, help="float16 or int8 or int4", choices=["int8", "float16", "int4"], default="float16")
 parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
 parser.add_argument("--batch_size", default=1, type=int, help="batch size")
 parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
@@ -100,7 +100,7 @@ def get_checkpoint_files(model_name_or_path):
 
 
 model_name = args.name
-infer_dtype = args.dtype
+infer_dtype = args.dtype if args.dtype != 'int4' else 'int8'
 
 tp_presharded_mode = True if model_name in tp_presharded_models else False
 
@@ -171,7 +171,19 @@ def write_checkponts_json():
     deepspeed.runtime.utils.see_memory_usage("pre-ds-inference-init", force=True)
 
 if kernel_inject:
-    kwargs = dict(replace_with_kernel_inject=True)
+    if args.dtype == 'int8':
+        bits = 8
+    else:
+        bits = 4 if args.dtype == 'int4' else 16
+    ds_config = {
+        "replace_with_kernel_inject" : True,
+        "quant" : {
+            "enabled" : True,
+            "weight" : {
+                "num_bits" : bits
+            }
+        }
+    }
 else:
-    kwargs = dict(injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")})
+    ds_config = {"injection_policy": {BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}}
 
@@ -188,6 +200,7 @@ def write_checkponts_json():
 # checkpoints_json=None
 model = deepspeed.init_inference(
     model,
+    config=ds_config,
     mp_size=world_size,
     base_dir=repo_root,
     dtype=getattr(torch, infer_dtype),