102 changes: 39 additions & 63 deletions examples/models/llama/export_llama_lib.py
@@ -322,8 +322,8 @@ def build_args_parser() -> argparse.ArgumentParser:
default="fp32",
type=str,
choices=["fp32", "fp16", "bf16"],
help="Override the dtype of the model (default is the checkpoint dtype)."
"Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
help="Provide the dtype of the model. This must match the supported dtypes of the backends you are using. "
"Please be aware that only some backends support fp16 and bf16.",
)

parser.add_argument(
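
For context, the fp32/fp16/bf16 strings accepted here correspond to torch dtypes through the DType enum used later in this file (e.g. DType[args.dtype_override]). A minimal sketch of what that mapping looks like, illustrative rather than the enum's actual definition:

from enum import Enum

import torch


class DType(Enum):
    # Mirrors the fp32/fp16/bf16 choices accepted by the argument above.
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        # Map each enum member to the corresponding torch dtype.
        return {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,
        }[self]


# DType["bf16"] looks a member up by name, so the argparse string plugs in directly.
assert DType["bf16"].to_torch_dtype() == torch.bfloat16
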
@@ -565,43 +565,42 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
output_dir_path = canonical_path(args.output_dir, dir=True)
weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

# dtype override
if args.dtype_override is not None:
dtype_override = DType[args.dtype_override]
elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
dtype_override = DType["fp16"]
else:
dtype_override = None
# Convert dtype override string arg to actual type.
dtype_override = DType[args.dtype_override]

edge_manager = _load_llama_model(
args.model,
checkpoint=checkpoint_path,
checkpoint_dir=checkpoint_dir,
params_path=params_path,
use_kv_cache=args.use_kv_cache,
use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
generate_full_logits=args.generate_full_logits,
weight_type=weight_type,
enable_dynamic_shape=args.enable_dynamic_shape,
calibration_tasks=args.calibration_tasks,
calibration_limit=args.calibration_limit,
calibration_seq_length=args.calibration_seq_length,
calibration_data=args.calibration_data,
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
max_context_len=args.max_context_length,
input_prune_map_path=args.input_prune_map,
output_prune_map_path=args.output_prune_map,
metadata_str=args.metadata,
dtype_override=dtype_override,
args=args,
)

return (
_load_llama_model(
args.model,
checkpoint=checkpoint_path,
checkpoint_dir=checkpoint_dir,
params_path=params_path,
use_kv_cache=args.use_kv_cache,
use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
generate_full_logits=args.generate_full_logits,
weight_type=weight_type,
enable_dynamic_shape=args.enable_dynamic_shape,
calibration_tasks=args.calibration_tasks,
calibration_limit=args.calibration_limit,
calibration_seq_length=args.calibration_seq_length,
calibration_data=args.calibration_data,
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
max_context_len=args.max_context_length,
input_prune_map_path=args.input_prune_map,
output_prune_map_path=args.output_prune_map,
metadata_str=args.metadata,
dtype_override=dtype_override,
args=args,
)
.set_output_dir(output_dir_path)
.source_transform(_get_source_transforms(args.model, dtype_override, args))
# At this point, the model is loaded in the default fp32.
edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
edge_manager.set_output_dir(output_dir_path).source_transform(
_get_source_transforms(args.model, dtype_override, args)
)

return edge_manager


def get_quantizer_and_quant_params(args):
pt2e_quant_params = get_pt2e_quantization_params(
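
The rewritten flow above always loads the model in the default fp32 and then casts it once to the requested dtype. The same pattern in isolation, with a toy module standing in for the Llama Transformer:

import torch
import torch.nn as nn

# Toy stand-in for edge_manager.model; parameters are created in the default fp32.
model = nn.Linear(8, 8)
assert all(p.dtype == torch.float32 for p in model.parameters())

# Cast every parameter and buffer to the override dtype in a single call,
# the same way edge_manager.model is cast above.
target = torch.bfloat16  # e.g. DType.bf16.to_torch_dtype()
model = model.to(dtype=target)
assert all(p.dtype == target for p in model.parameters())
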
@@ -1006,6 +1005,8 @@ def _load_llama_model(
else:
raise ValueError(f"{modelname} is not a valid Llama model.")

torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None

model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
EagerModelFactory.create_model(
module_name,
@@ -1022,41 +1023,16 @@
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
dtype=torch_dtype,
args=args,
)
)
if dtype_override:
assert isinstance(
dtype_override, DType
), "Override dtype needs to be of type <DType>"
torch_dtype = dtype_override.to_torch_dtype()
logging.info(f"model.to {torch_dtype}")
model = model.to(dtype=torch_dtype)
dtype = dtype_override
else:
state_dict = model.state_dict()
dtype = state_dict[next(iter(state_dict))].dtype
assert dtype in [
torch.bfloat16,
torch.float16,
torch.float32,
], f"Only support bfloat16, fp16 or fp32 got {dtype}"
logging.info(f"Loaded model with dtype={dtype}")

if dtype == torch.bfloat16:
dtype = DType.bf16
elif dtype == torch.float16:
dtype = DType.fp16
elif dtype == torch.float32:
dtype = DType.fp32
else:
raise ValueError(f"Unsupported dtype {dtype}")

return LLMEdgeManager(
model=model,
modelname=modelname,
max_seq_len=model.max_seq_len,
dtype=dtype,
dtype=dtype_override,
use_kv_cache=use_kv_cache,
generate_full_logits=generate_full_logits,
example_inputs=example_inputs,
18 changes: 7 additions & 11 deletions examples/models/llama/model.py
@@ -122,9 +122,6 @@ def __init__(self, **kwargs):
"""
)

# Get checkpoint dtype.
self.dtype = get_checkpoint_dtype(checkpoint)

with open(params_path, "r") as f:
params = json.loads(f.read())
output_prune_map = None
@@ -171,7 +168,9 @@ def __init__(self, **kwargs):
# Within the device="meta" context, tensors that are created do not carry data.
# They possess all other metadata a tensor carries such as size, stride, requires_grad.
with torch.device("meta"):
# The model itself is instantiated in the default dtype, fp32.
self.model_ = Transformer(model_args)
self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)

if "int8" in str(checkpoint_path):
print("Using int8 weight-only quantization!")
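
The device="meta" behavior described in the comment above can be reproduced on its own; a small sketch with a toy module rather than the Llama Transformer:

import torch
import torch.nn as nn

# Under device="meta", tensors carry shape, stride, dtype and requires_grad,
# but no storage, so even a very large module costs almost nothing to build.
with torch.device("meta"):
    m = nn.Linear(4096, 4096)

print(m.weight.device)  # meta
print(m.weight.shape)   # torch.Size([4096, 4096])
print(m.weight.dtype)   # torch.float32 -- the default dtype
# The weights hold no data yet; they are materialized later from the
# checkpoint via load_state_dict(..., assign=True).
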
@@ -241,6 +240,10 @@ def __init__(self, **kwargs):
# assign=True: load params/buffers by assignment instead of performing an in-place copy.
# Because we are using device="meta", tensors do not have memory associated with them
# and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.

# Also, the checkpoint is loaded and its dtype promoted to the transformer's dtype, which
# is initialized to fp32 by default. This is fine because every other supported dtype
# converts losslessly to fp32, so we don't lose precision here.
missing, unexpected = self.model_.load_state_dict(
checkpoint,
strict=False,
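
For the assign=True path referenced in the comment above, a continuation of the earlier meta-device sketch; real checkpoints come from torch.load, the dict below is only a stand-in:

import torch
import torch.nn as nn

with torch.device("meta"):
    m = nn.Linear(4, 4)

# Stand-in for a real checkpoint loaded via torch.load(...).
checkpoint = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# assign=True swaps the data-less meta parameters for the checkpoint tensors
# instead of copying into them (an in-place copy onto meta tensors is a no-op).
missing, unexpected = m.load_state_dict(checkpoint, strict=False, assign=True)

print(m.weight.device)  # cpu -- the parameters now hold the checkpoint data
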
@@ -277,14 +280,7 @@ def __init__(self, **kwargs):
self.model_ = prune_output_vocab(self.model_, output_prune_map)

def get_eager_model(self) -> torch.nn.Module:
if self.dtype:
# convert to the type of the provided checkpoint
# input and output are torch.long, so signature unchanged
return self.model_.to(self.dtype)
else:
# int8 quantization code has some bf16,
# switch all to FP32
return self.model_.to(torch.float32)
return self.model_

def get_example_inputs(self):
if self.use_kv_cache: