34 changes: 17 additions & 17 deletions .pre-commit-config.yaml
@@ -5,15 +5,15 @@ default_stages: [commit]
default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.11
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
# - repo: https://github.com/astral-sh/ruff-pre-commit
# # Ruff version.
# rev: v0.1.11
# hooks:
# # Run the linter.
# - id: ruff
# args: [--fix]
# # Run the formatter.
# - id: ruff-format

- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
@@ -50,14 +50,14 @@ repos:
args:
- "--autofix"
- "--indent=2"
- repo: local
hooks:
- id: validate-commit-msg
name: Commit Message is Valid
language: pygrep
entry: ^(break|build|ci|docs|feat|fix|perf|refactor|style|test|ops|hotfix|release|maint|init|enh|revert)\([\w,\.,\-,\(,\),\/]+\)(!?)(:)\s{1}([\w,\W,:]+)
stages: [commit-msg]
args: [--negate]
# - repo: local
# hooks:
# - id: validate-commit-msg
# name: Commit Message is Valid
# language: pygrep
# entry: ^(break|build|ci|docs|feat|fix|perf|refactor|style|test|ops|hotfix|release|maint|init|enh|revert)\([\w,\.,\-,\(,\),\/]+\)(!?)(:)\s{1}([\w,\W,:]+)
# stages: [commit-msg]
# args: [--negate]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.3
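For reference, a minimal Python sketch of the conventional-commit pattern that the now-commented-out `validate-commit-msg` hook enforced (the sample messages are made up, not from the repo). With `--negate`, the pygrep hook fails whenever the commit message does not match this pattern.

```python
# Illustrative only: mirrors the pygrep entry from the disabled hook above.
import re

COMMIT_MSG_RE = re.compile(
    r"^(break|build|ci|docs|feat|fix|perf|refactor|style|test|ops|hotfix"
    r"|release|maint|init|enh|revert)"
    r"\([\w,\.,\-,\(,\),\/]+\)(!?)(:)\s{1}([\w,\W,:]+)"
)

for msg in ["feat(hf): pass auth token to model loaders", "fixed some stuff"]:
    print("valid" if COMMIT_MSG_RE.match(msg) else "rejected", "->", msg)
```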
76 changes: 59 additions & 17 deletions dsp/modules/hf.py
@@ -1,11 +1,13 @@
# from peft import PeftConfig, PeftModel
# from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer, AutoConfig
import os
from typing import Literal, Optional

from dsp.modules.lm import LM

# from dsp.modules.finetuning.finetune_hf import preprocess_prompt


def openai_to_hf(**kwargs):
hf_kwargs = {}
for k, v in kwargs.items():
@@ -26,8 +28,19 @@ def openai_to_hf(**kwargs):


class HFModel(LM):
def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool = False,
hf_device_map: Literal["auto", "balanced", "balanced_low_0", "sequential"] = "auto"):
def __init__(
self,
model: str,
checkpoint: Optional[str] = None,
is_client: bool = False,
hf_device_map: Literal[
"auto",
"balanced",
"balanced_low_0",
"sequential",
] = "auto",
token: Optional[str] = None,
):
"""wrapper for Hugging Face models

Args:
@@ -42,6 +55,10 @@ def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool
self.provider = "hf"
self.is_client = is_client
self.device_map = hf_device_map

hf_autoconfig_kwargs = dict(token=token or os.environ.get("HF_TOKEN"))
hf_autotokenizer_kwargs = hf_autoconfig_kwargs.copy()
hf_automodel_kwargs = hf_autoconfig_kwargs.copy()
if not self.is_client:
try:
import torch
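A hedged usage sketch of the new `token` parameter (the model name and token value are hypothetical): the token can be passed explicitly, otherwise the constructor falls back to the `HF_TOKEN` environment variable, so gated checkpoints can be loaded.

```python
# Sketch only; assumes the dsp package is installed and the checkpoint is gated.
from dsp.modules.hf import HFModel

lm = HFModel("meta-llama/Llama-2-7b-hf", token="hf_xxx")  # explicit token
# or: export HF_TOKEN=hf_xxx and call HFModel("meta-llama/Llama-2-7b-hf")
```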
@@ -52,40 +69,68 @@ def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool
) from exc
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
architecture = AutoConfig.from_pretrained(model).__dict__["architectures"][0]
self.encoder_decoder_model = ("ConditionalGeneration" in architecture) or ("T5WithLMHeadModel" in architecture)
architecture = AutoConfig.from_pretrained(
model,
**hf_autoconfig_kwargs,
).__dict__["architectures"][0]
self.encoder_decoder_model = ("ConditionalGeneration" in architecture) or (
"T5WithLMHeadModel" in architecture
)
self.decoder_only_model = ("CausalLM" in architecture) or ("GPT2LMHeadModel" in architecture)
assert self.encoder_decoder_model or self.decoder_only_model, f"Unknown HuggingFace model class: {model}"
self.tokenizer = AutoTokenizer.from_pretrained(model if checkpoint is None else checkpoint)
assert (
self.encoder_decoder_model or self.decoder_only_model
), f"Unknown HuggingFace model class: {model}"
self.tokenizer = AutoTokenizer.from_pretrained(
model if checkpoint is None else checkpoint,
**hf_autotokenizer_kwargs,
)

self.rationale = True
AutoModelClass = AutoModelForSeq2SeqLM if self.encoder_decoder_model else AutoModelForCausalLM
if checkpoint:
# with open(os.path.join(checkpoint, '..', 'compiler_config.json'), 'r') as f:
# config = json.load(f)
self.rationale = False #config['rationale']
self.rationale = False # config['rationale']
# if config['peft']:
# peft_config = PeftConfig.from_pretrained(checkpoint)
# self.model = AutoModelClass.from_pretrained(peft_config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map=hf_device_map)
# self.model = PeftModel.from_pretrained(self.model, checkpoint)
# else:
if self.device_map:
self.model = AutoModelClass.from_pretrained(checkpoint, device_map=self.device_map)
self.model = AutoModelClass.from_pretrained(
checkpoint,
device_map=self.device_map,
**hf_automodel_kwargs,
)
else:
self.model = AutoModelClass.from_pretrained(checkpoint).to(self.device)
self.model = AutoModelClass.from_pretrained(
checkpoint,
**hf_automodel_kwargs,
).to(self.device)
else:
if self.device_map:
self.model = AutoModelClass.from_pretrained(model, device_map=self.device_map)
self.model = AutoModelClass.from_pretrained(
model,
device_map=self.device_map,
**hf_automodel_kwargs,
)
else:
self.model = AutoModelClass.from_pretrained(model).to(self.device)
self.model = AutoModelClass.from_pretrained(
model,
**hf_automodel_kwargs,
).to(self.device)
self.drop_prompt_from_output = False
except ValueError:
self.model = AutoModelForCausalLM.from_pretrained(
model if checkpoint is None else checkpoint,
device_map=self.device_map,
**hf_automodel_kwargs,
)
self.drop_prompt_from_output = True
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.tokenizer = AutoTokenizer.from_pretrained(
model,
**hf_autotokenizer_kwargs,
)
self.drop_prompt_from_output = True
self.history = []
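As a side note, a small sketch of the architecture check used above (the model names are examples, not part of the diff): `AutoConfig` exposes the architecture class name, which decides between the encoder-decoder and decoder-only branches.

```python
# Rough illustration of the branch selection; requires `transformers` and
# network access to fetch the model configs.
from transformers import AutoConfig

for name in ["google/flan-t5-small", "gpt2"]:
    arch = AutoConfig.from_pretrained(name).architectures[0]
    is_enc_dec = ("ConditionalGeneration" in arch) or ("T5WithLMHeadModel" in arch)
    print(name, arch, "encoder-decoder" if is_enc_dec else "decoder-only")
```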

@@ -111,7 +156,7 @@ def _generate(self, prompt, **kwargs):
# print(prompt)
if isinstance(prompt, dict):
try:
prompt = prompt['messages'][0]['content']
prompt = prompt["messages"][0]["content"]
except (KeyError, IndexError, TypeError):
print("Failed to extract 'content' from the prompt.")
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
@@ -121,10 +166,7 @@ def _generate(self, prompt, **kwargs):
if self.drop_prompt_from_output:
input_length = inputs.input_ids.shape[1]
outputs = outputs[:, input_length:]
completions = [
{"text": c}
for c in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
]
completions = [{"text": c} for c in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]
response = {
"prompt": prompt,
"choices": completions,
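For context, a minimal sketch (field values hypothetical) of the prompt shape `_generate` accepts and the response shape it returns:

```python
# Hypothetical example data mirroring the handling in _generate above.
prompt = {"messages": [{"role": "user", "content": "What is the capital of France?"}]}
# _generate extracts prompt["messages"][0]["content"], tokenizes it, generates,
# and wraps each decoded completion as {"text": ...}:
response = {
    "prompt": "What is the capital of France?",
    "choices": [{"text": "Paris"}],
}
```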