From aee64878f77daf61bb3556f3d0ace50ba72ea484 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Wed, 6 Nov 2024 15:20:28 +0100
Subject: [PATCH 1/3] fix: allow multiple weight mapping files for mistral
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Downloading a Mistral model fails because it includes multiple weight
mapping files. The regression was introduced in commit
`766bee9f4a1fcb187fae543a525495d3ff482097`. I'm unclear on the original
intent, but perhaps the exception was meant to apply only to Granite
models.

This isn't an ideal fix, but it does enable Mistral to be downloaded
and used for chat.

Signed-off-by: Sébastien Han
---
 torchchat/cli/convert_hf_checkpoint.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py
index f428e4cc6..06b2054d3 100644
--- a/torchchat/cli/convert_hf_checkpoint.py
+++ b/torchchat/cli/convert_hf_checkpoint.py
@@ -41,7 +41,8 @@ def convert_hf_checkpoint(
 
     # Load the json file containing weight mapping
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
+    if "mistral" not in model_name:
+        assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
     if len(model_map_json_matches):
         model_map_json = model_map_json_matches[0]
     else:

From 295ae2a2805f598be419faca18529ae49656968f Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Wed, 6 Nov 2024 10:41:21 -0700
Subject: [PATCH 2/3] fix(download): Fix safetensors/bin/pth download logic

The previous logic didn't handle .bin files, so if a model (like
mistral) has both .bin and .safetensors, it would download both.

Branch: download-fix

Signed-off-by: Gabe Goodhart
---
 torchchat/cli/download.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py
index f334eb555..4da2bc390 100644
--- a/torchchat/cli/download.py
+++ b/torchchat/cli/download.py
@@ -35,11 +35,12 @@ def _download_hf_snapshot(
         model_info = model_info(model_config.distribution_path, token=hf_token)
         model_fnames = [f.rfilename for f in model_info.siblings]
 
-        # Check the model config for preference between safetensors and pth
+        # Check the model config for preference between safetensors and pth/bin
         has_pth = any(f.endswith(".pth") for f in model_fnames)
+        has_bin = any(f.endswith(".bin") for f in model_fnames)
         has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
 
-        # If told to prefer safetensors, ignore pth files
+        # If told to prefer safetensors, ignore pth/bin files
         if model_config.prefer_safetensors:
             if not has_safetensors:
                 print(
@@ -47,10 +48,10 @@ def _download_hf_snapshot(
                     file=sys.stderr,
                 )
                 exit(1)
-            ignore_patterns = "*.pth"
+            ignore_patterns = ["*.pth", "*.bin"]
 
         # If the model has both, prefer pth files over safetensors
-        elif has_pth and has_safetensors:
+        elif (has_pth or has_bin) and has_safetensors:
             ignore_patterns = "*safetensors*"
 
         # Otherwise, download everything
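The download change above reduces to a three-way choice of ignore patterns for the
Hugging Face snapshot download: prefer safetensors (skip .pth and .bin), prefer
pth/bin when both formats are present (skip safetensors), otherwise download
everything. The following is a minimal standalone sketch of that decision for
illustration only; the helper name pick_ignore_patterns, the bare filename lists,
and the ValueError in place of the real print-and-exit handling are assumptions,
not torchchat's actual API.

from typing import List, Optional, Union


def pick_ignore_patterns(
    model_fnames: List[str], prefer_safetensors: bool
) -> Optional[Union[str, List[str]]]:
    # Standalone sketch of the preference logic changed in the patch above.
    has_pth = any(f.endswith(".pth") for f in model_fnames)
    has_bin = any(f.endswith(".bin") for f in model_fnames)
    has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)

    # Told to prefer safetensors: skip both .pth and .bin weight files.
    if prefer_safetensors:
        if not has_safetensors:
            # The real code prints an error and exits; raising keeps the sketch self-contained.
            raise ValueError("prefer_safetensors is set but no .safetensors files exist")
        return ["*.pth", "*.bin"]

    # The repo ships both formats: skip safetensors, keep pth/bin.
    if (has_pth or has_bin) and has_safetensors:
        return "*safetensors*"

    # Only one format present: download everything.
    return None


# Example: a Mistral-style repo that publishes both .bin and .safetensors shards.
fnames = ["pytorch_model-00001-of-00002.bin", "model-00001-of-00002.safetensors"]
assert pick_ignore_patterns(fnames, prefer_safetensors=False) == "*safetensors*"
assert pick_ignore_patterns(fnames, prefer_safetensors=True) == ["*.pth", "*.bin"]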
From 5747a71553147ebf88cb0f2bfc33d93d4845b4e3 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Wed, 6 Nov 2024 10:42:38 -0700
Subject: [PATCH 3/3] fix(convert hf): Better logic to handle multiple weight mapping files

With the download fix that handles .bin files, this is not actually
needed for mistral, but it may be needed for other models, so it's
worth having.

Branch: download-fix

Signed-off-by: Gabe Goodhart
---
 torchchat/cli/convert_hf_checkpoint.py | 36 +++++++++++++++++---------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py
index 06b2054d3..122ab0f28 100644
--- a/torchchat/cli/convert_hf_checkpoint.py
+++ b/torchchat/cli/convert_hf_checkpoint.py
@@ -39,20 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    if "mistral" not in model_name:
-        assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -69,11 +63,30 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
            )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could be mapped to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_index = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+            valid_mappings = {
+                k: model_dir / v
+                for (k, v) in weight_map.get("weight_map", {}).items()
+                if (model_dir / v).is_file()
+            }
+        bin_index.update(valid_mappings)
+    bin_files = set(bin_index.values())
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -97,7 +110,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
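To illustrate the convert_hf_checkpoint change, here is a self-contained sketch of
the intended merge behaviour: every *.index.json in the model directory is read,
and only mappings whose target weight file actually exists on disk are kept. The
merge_weight_maps helper and the toy file names below are hypothetical, written
only to exercise the same logic outside torchchat.

import json
import tempfile
from pathlib import Path


def merge_weight_maps(model_dir: Path) -> dict:
    # Sketch of the merged index loading: keep only mappings whose target file exists.
    bin_index = {}
    for index_file in model_dir.glob("*.index.json"):
        weight_map = json.loads(index_file.read_text()).get("weight_map", {})
        bin_index.update(
            {k: model_dir / v for k, v in weight_map.items() if (model_dir / v).is_file()}
        )
    return bin_index


# Toy model dir: a .safetensors index whose shard exists and a .bin index whose shard is missing.
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    (model_dir / "model-00001-of-00001.safetensors").write_bytes(b"")
    (model_dir / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"model.embed_tokens.weight": "model-00001-of-00001.safetensors"}})
    )
    (model_dir / "pytorch_model.bin.index.json").write_text(
        json.dumps({"weight_map": {"model.embed_tokens.weight": "pytorch_model-00001-of-00001.bin"}})
    )
    merged = merge_weight_maps(model_dir)
    # Only the safetensors mapping survives because its shard is present on disk.
    assert merged == {"model.embed_tokens.weight": model_dir / "model-00001-of-00001.safetensors"}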