diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 06e11dd6b84..cdcd5f89635 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -322,8 +322,8 @@ def build_args_parser() -> argparse.ArgumentParser:
         default="fp32",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        help="Override the dtype of the model (default is the checkpoint dtype)."
-        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
+        help="Provide the dtype of the model. This must match up with the supported dtypes of the backends that you are using. "
+        "Please be aware that only some backends support fp16 and bf16.",
     )
 
     parser.add_argument(
@@ -565,43 +565,42 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
-        dtype_override = DType["fp16"]
-    else:
-        dtype_override = None
+    # Convert dtype override string arg to actual type.
+    dtype_override = DType[args.dtype_override]
+
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
+    )
 
-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    # At this point, the model is loaded in the default fp32.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(args.model, dtype_override, args)
     )
 
+    return edge_manager
+
 
 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
@@ -1006,6 +1005,8 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
+
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
@@ -1022,41 +1023,16 @@ def _load_llama_model(
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
+            dtype=torch_dtype,
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index bc4fd6ccb11..833473167c2 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -122,9 +122,6 @@ def __init__(self, **kwargs):
                 """
             )
 
-        # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
@@ -171,7 +168,9 @@ def __init__(self, **kwargs):
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
+            # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
+            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -241,6 +240,10 @@ def __init__(self, **kwargs):
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
         missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
@@ -277,14 +280,7 @@ def __init__(self, **kwargs):
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            # convert to the type of the provided checkpoint
-            # input and output are torch.long, so signature unchanged
-            return self.model_.to(self.dtype)
-        else:
-            # int8 quantization code has some bf16,
-            # switch all to FP32
-            return self.model_.to(torch.float32)
+        return self.model_
 
     def get_example_inputs(self):
         if self.use_kv_cache:
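Note: the snippet below is a minimal, standalone sketch (plain PyTorch, not the ExecuTorch code itself) of the dtype flow this patch switches to: construct the model in the default fp32, load the checkpoint into it (fp16/bf16 values up-cast losslessly), then apply the user-selected --dtype-override with a single explicit .to() call, mirroring edge_manager.model.to(...) above. The to_torch_dtype helper and the nn.Linear stand-in model are illustrative assumptions, not the real DType class or Transformer.

# Standalone sketch of the new dtype handling; names here are illustrative only.
import torch
import torch.nn as nn

# Hypothetical stand-in for DType[...].to_torch_dtype() used in export_llama_lib.py.
_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


def to_torch_dtype(dtype_override_str: str) -> torch.dtype:
    if dtype_override_str not in _DTYPE_MAP:
        raise ValueError(f"Unsupported dtype {dtype_override_str}")
    return _DTYPE_MAP[dtype_override_str]


if __name__ == "__main__":
    # e.g. --dtype-override bf16 on the export command line.
    dtype_override = to_torch_dtype("bf16")

    # The model is constructed in the framework default, fp32.
    model = nn.Linear(8, 8)
    assert model.weight.dtype == torch.float32

    # Simulate a checkpoint saved in a narrower dtype; copying it into an fp32
    # module up-casts the values losslessly (fp16/bf16 -> fp32 is exact).
    checkpoint = {k: v.to(torch.bfloat16) for k, v in model.state_dict().items()}
    model.load_state_dict(checkpoint)
    assert model.weight.dtype == torch.float32

    # Single explicit cast to the requested dtype, the analogue of
    # edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype()).
    model = model.to(dtype=dtype_override)
    assert model.weight.dtype == torch.bfloat16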