diff --git a/dsp/modules/hf.py b/dsp/modules/hf.py
index 7407b3f47e..cefcee9290 100644
--- a/dsp/modules/hf.py
+++ b/dsp/modules/hf.py
@@ -40,6 +40,7 @@ def __init__(
             "sequential",
         ] = "auto",
         token: Optional[str] = None,
+        model_kwargs: Optional[dict] = None,
     ):
         """wrapper for Hugging Face models
@@ -49,6 +50,7 @@ def __init__(
             is_client (bool, optional): whether to access models via client. Defaults to False.
             hf_device_map (str, optional): HF config strategy to load the model. Recommeded to use "auto", which will help loading large models using accelerate. Defaults to "auto".
+            model_kwargs (dict, optional): additional kwargs to pass to the model constructor. Defaults to None (no extra kwargs).
         """
         super().__init__(model)
@@ -59,6 +61,12 @@ def __init__(
         hf_autoconfig_kwargs = dict(token=token or os.environ.get("HF_TOKEN"))
         hf_autotokenizer_kwargs = hf_autoconfig_kwargs.copy()
         hf_automodel_kwargs = hf_autoconfig_kwargs.copy()
+
+        # Work on a copy so we never mutate the caller's dict; silently drop
+        # device_map if present, as device placement is governed by the
+        # constructor's hf_device_map option.
+        model_kwargs = dict(model_kwargs or {})
+        model_kwargs.pop("device_map", None)
+        hf_automodel_kwargs.update(model_kwargs)
         if not self.is_client:
             try:
                 import torch