From f6f3adaebc6d1d27281989be219625b1d9b9dd4b Mon Sep 17 00:00:00 2001
From: Lucy Qiu
Date: Fri, 9 Feb 2024 16:27:38 -0800
Subject: [PATCH] CI failures with dtype-override (#1919)

Summary:
language_llama failure: https://www.internalfb.com/sandcastle/workflow/3071454945868675212
llama stories failure: https://www.internalfb.com/intern/testinfra/diagnostics/5066549797810715.562950083098286.1707509873/

Reviewed By: larryliu0820, angelayi

Differential Revision: D53625845
---
 examples/models/llama2/model.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index 52ed14f2b53..32fe4ae04dc 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -499,6 +499,12 @@ def __init__(self, **kwargs):
         device = "cpu"
         # flake8: noqa: TOR102
         checkpoint = torch.load(checkpoint_path, map_location=device)
+        if kwargs.get("fairseq2", False):
+            print("Using fairseq2 checkpoint")
+            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
+        if "model" in checkpoint:
+            # NB: some checkpoint contains a "model" field, which is the actual weights dict
+            checkpoint = checkpoint["model"]
         # get checkpoint dtype
         self.dtype = None
         if len(checkpoint) > 0:
@@ -513,12 +519,6 @@ def __init__(self, **kwargs):
             print(
                 f"Mixed dtype model. Dtype of {first.key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
             )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
-        if "model" in checkpoint:
-            # NB: some checkpoint contains a "model" field, which is the actual weights dict
-            checkpoint = checkpoint["model"]
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         max_seq_len = 128
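
For context: the patch moves the fairseq2 conversion and the "model"-key unwrapping ahead of the
dtype detection, so the dtype is read from the actual weight tensors rather than from the raw,
un-converted checkpoint dict. Below is a minimal standalone sketch of that ordering, not the
actual Llama2Model.__init__ code; the function name load_checkpoint and the converter parameter
(standing in for convert_to_llama_checkpoint) are hypothetical and exist only to keep the example
self-contained.

    import torch


    def load_checkpoint(checkpoint_path, device="cpu", converter=None):
        # Load the raw checkpoint from disk.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Normalize the checkpoint first -- this is the block the patch moves up.
        if converter is not None:
            # e.g. a fairseq2 -> llama key-mapping function
            checkpoint = converter(checkpoint=checkpoint)
        if "model" in checkpoint:
            # some checkpoints wrap the actual state dict under a "model" key
            checkpoint = checkpoint["model"]

        # Only now inspect the dtype, so it reflects the real weight tensors.
        dtype = None
        for value in checkpoint.values():
            if isinstance(value, torch.Tensor):
                dtype = value.dtype
                break
        return checkpoint, dtype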