From 87b9f32f7f77fefaf8df716648891a93b6df32bd Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 26 Mar 2025 14:22:06 -0700
Subject: [PATCH] Remove exception fallback on checkpoint loading

---
 .ci/scripts/test_model.sh      |  4 ++--
 examples/models/llama/model.py | 43 ++++++++++++++--------------------
 2 files changed, 19 insertions(+), 28 deletions(-)

diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh
index cd543ff1424..7704ddac809 100755
--- a/.ci/scripts/test_model.sh
+++ b/.ci/scripts/test_model.sh
@@ -96,7 +96,7 @@ test_model() {
     bash examples/models/llama/install_requirements.sh
     # Test export_llama script: python3 -m examples.models.llama.export_llama.
     # Use Llama random checkpoint with Qwen 2.5 1.5b model configuration.
-    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/qwen2_5/1_5b_config.json
+    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -p examples/models/qwen2_5/1_5b_config.json
     rm "./${MODEL_NAME}.pte"
     return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears.
   fi
@@ -104,7 +104,7 @@ test_model() {
     # Install requirements for export_llama
     bash examples/models/llama/install_requirements.sh
     # Test export_llama script: python3 -m examples.models.llama.export_llama.
-    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi_4_mini/config.json
+    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -p examples/models/phi_4_mini/config.json
     run_portable_executor_runner
     rm "./${MODEL_NAME}.pte"
     return
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index f90ceb8c415..19829576482 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -244,33 +244,24 @@ def __init__(self, **kwargs):
         )
 
         missing, unexpected = None, None
-        try:
-            # assign=True: load params/buffers by assignment instead of performing an in-place copy.
-            # Because we are using device="meta", tensors do not have memory associated with them
-            # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-
-            # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
-            # by default initialized to fp32. This is fine because every other supported type
-            # losslessly converts to fp32, so we don't lose precision here.
-            if checkpoint:
-                missing, unexpected = self.model_.load_state_dict(
-                    checkpoint,
-                    strict=False,
-                    assign=True,
-                )  # self.model_ = Transformer(gptconf)
-            else:
-                print("Checkpoint not provided, defaulting weights to zeros.")
-                self.model_.to_empty(device="cpu")
-                for p in self.model_.parameters():
-                    p.data.fill_(0)
-                for b in self.model_.buffers():
-                    b.data.fill_(0)
-        except RuntimeError as e:
-            print(
-                f"Could not load checkpoint into mode and will defaulting weights to zeros due to error: {e}."
-            )
-            # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
+        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
+        # Because we are using device="meta", tensors do not have memory associated with them
+        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
+        if checkpoint:
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )  # self.model_ = Transformer(gptconf)
+        else:
+            print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
+            # Need to provide concrete values for meta-initialized tensors for quantization,
+            # otherwise they are just filled with NaNs.
             for p in self.model_.parameters():
                 p.data.fill_(0)
             for b in self.model_.buffers():
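For context (not part of the patch), here is a minimal standalone sketch of the meta-device loading pattern the new code relies on. TinyModel, the layer sizes, and the checkpoint handling below are illustrative placeholders, not names from the repository. It shows why load_state_dict(..., assign=True) is needed for a module built under device="meta", and why the no-checkpoint path materializes storage with to_empty() and zero-fills it so later stages such as quantization see concrete values:

import torch
import torch.nn as nn

# Illustrative stand-in for the Llama Transformer; shapes are arbitrary.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Constructing under device="meta" gives parameters their shapes and dtypes
# but no backing storage.
with torch.device("meta"):
    model = TinyModel()

checkpoint = None  # or a state dict from torch.load(checkpoint_path, map_location="cpu")

if checkpoint:
    # assign=True swaps the meta tensors for the checkpoint tensors outright;
    # an in-place copy would be a no-op because meta tensors own no memory.
    missing, unexpected = model.load_state_dict(checkpoint, strict=False, assign=True)
else:
    # No checkpoint: allocate real (uninitialized) storage on CPU, then zero it
    # so downstream steps see concrete values rather than garbage.
    model.to_empty(device="cpu")
    for p in model.parameters():
        p.data.fill_(0)
    for b in model.buffers():
        b.data.fill_(0)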