diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py index adf27885d..f90b59c25 100644 --- a/torchchat/cli/convert_hf_checkpoint.py +++ b/torchchat/cli/convert_hf_checkpoint.py @@ -12,6 +12,8 @@ import torch +from torchchat.model import TransformerArgs + # support running without installing as a package wd = Path(__file__).parent.parent sys.path.append(str(wd.resolve())) @@ -32,7 +34,8 @@ def convert_hf_checkpoint( if model_name is None: model_name = model_dir.name - config = ModelArgs.from_name(model_name).transformer_args['text'] + config_args = ModelArgs.from_name(model_name).transformer_args['text'] + config = TransformerArgs.from_params(config_args) print(f"Model config {config.__dict__}") # Load the json file containing weight mapping