Merged
7 changes: 7 additions & 0 deletions dsp/modules/hf.py
@@ -40,6 +40,7 @@ def __init__(
"sequential",
] = "auto",
token: Optional[str] = None,
model_kwargs: Optional[dict] = {},
):
"""wrapper for Hugging Face models

@@ -49,6 +50,7 @@ def __init__(
is_client (bool, optional): whether to access models via client. Defaults to False.
hf_device_map (str, optional): HF config strategy to load the model.
Recommended to use "auto", which helps load large models using accelerate. Defaults to "auto".
model_kwargs (dict, optional): additional kwargs to pass to the model constructor. Defaults to empty dict.
"""

super().__init__(model)
@@ -59,6 +61,11 @@ def __init__(
hf_autoconfig_kwargs = dict(token=token or os.environ.get("HF_TOKEN"))
hf_autotokenizer_kwargs = hf_autoconfig_kwargs.copy()
hf_automodel_kwargs = hf_autoconfig_kwargs.copy()

# silently remove device_map from model_kwargs if it is present, as the option is provided in the constructor
if "device_map" in model_kwargs:
model_kwargs.pop("device_map")
hf_automodel_kwargs.update(model_kwargs)
if not self.is_client:
try:
import torch
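
With this change, extra keyword arguments passed as model_kwargs are merged into hf_automodel_kwargs and forwarded to the Hugging Face model loader, with device_map stripped because device placement is already controlled by the hf_device_map argument. Below is a minimal usage sketch, assuming HFModel is exported from the dspy namespace and that transformers and torch are installed; the checkpoint name and the specific kwargs are illustrative, not taken from this PR.

# Illustrative sketch: forwarding extra loader kwargs via model_kwargs.
# Assumes dspy exposes HFModel (the wrapper defined in dsp/modules/hf.py).
import torch
import dspy

lm = dspy.HFModel(
    "meta-llama/Llama-2-7b-hf",        # hypothetical checkpoint name
    hf_device_map="auto",              # device placement is configured here
    model_kwargs={
        "torch_dtype": torch.float16,  # forwarded to the model constructor
        "device_map": "auto",          # silently dropped by the check added in this PR
    },
)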