Fix tokenizer name and revision (#842)
yadavsahil197 committed Jun 3, 2024
1 parent 728e37b commit 12f72b2
Showing 4 changed files with 29 additions and 8 deletions.
9 changes: 7 additions & 2 deletions src/autolabel/labeler.py
@@ -60,6 +60,11 @@
 }
 METRIC_TABLE_STYLE = "cyan bold"

+DEFAULT_TOKENIZATION_MODEL = {
+    "pretrained_model_name_or_path": "NousResearch/Llama-2-13b-chat-hf",
+    "revision": "d73f5fa9c4bc135502e04c27b39660747172d76b",
+}
+
 MERGE_FUNCTION = {
     AggregationFunction.MAX: np.max,
     AggregationFunction.MEAN: np.mean,
@@ -115,7 +120,7 @@ def __init__(
         if self.config.confidence_chunk_column():
             if not confidence_tokenizer:
                 self.confidence_tokenizer = AutoTokenizer.from_pretrained(
-                    "NousResearch/Llama-2-13b-chat-hf"
+                    **DEFAULT_TOKENIZATION_MODEL
                 )
             else:
                 self.confidence_tokenizer = confidence_tokenizer
@@ -705,7 +710,7 @@ def get_num_tokens(self, inp: str) -> int:
         if not self.confidence_tokenizer:
             logger.warning("Confidence tokenizer is not set. Using default tokenizer.")
             self.confidence_tokenizer = AutoTokenizer.from_pretrained(
-                "NousResearch/Llama-2-13b-chat-hf"
+                **DEFAULT_TOKENIZATION_MODEL
             )
         """Returns the number of tokens in the prompt"""
         return len(self.confidence_tokenizer.encode(str(inp)))
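For context, a minimal standalone sketch of the pattern this commit introduces (assuming the Hugging Face transformers package is installed; the constant name mirrors the one added above). Unpacking the dict passes revision as a keyword argument, so the tokenizer is pinned to an exact commit of the model repository instead of whatever the default branch currently points to.

from transformers import AutoTokenizer

# Pinning both the repo name and an exact revision (a commit SHA on the
# Hugging Face Hub) keeps tokenization reproducible even if the upstream
# repository is later updated or renamed.
DEFAULT_TOKENIZATION_MODEL = {
    "pretrained_model_name_or_path": "NousResearch/Llama-2-13b-chat-hf",
    "revision": "d73f5fa9c4bc135502e04c27b39660747172d76b",
}

# Dict unpacking below is equivalent to:
# AutoTokenizer.from_pretrained(
#     "NousResearch/Llama-2-13b-chat-hf",
#     revision="d73f5fa9c4bc135502e04c27b39660747172d76b",
# )
tokenizer = AutoTokenizer.from_pretrained(**DEFAULT_TOKENIZATION_MODEL)
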
10 changes: 8 additions & 2 deletions src/autolabel/models/refuel.py
@@ -30,7 +30,11 @@ class UnretryableError(Exception):


 class RefuelLLM(BaseModel):
-    DEFAULT_TOKENIZATION_MODEL = "NousResearch/Llama-2-13b-chat-hf"
+    DEFAULT_TOKENIZATION_MODEL = {
+        "pretrained_model_name_or_path": "NousResearch/Llama-2-13b-chat-hf",
+        "revision": "d73f5fa9c4bc135502e04c27b39660747172d76b",
+    }
+
     DEFAULT_CONTEXT_LENGTH = 3250
     DEFAULT_CONNECT_TIMEOUT = 10
     DEFAULT_READ_TIMEOUT = 120
@@ -57,7 +61,9 @@ def __init__(
         self.model_name = config.model_name()
         model_params = config.model_params()
         self.model_params = {**self.DEFAULT_PARAMS, **model_params}
-        self.tokenizer = AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZATION_MODEL)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            **self.DEFAULT_TOKENIZATION_MODEL
+        )

         # initialize runtime
         self.BASE_API = f"https://llm.refuel.ai/models/{self.model_name}/generate"
9 changes: 7 additions & 2 deletions src/autolabel/models/refuelV2.py
@@ -30,7 +30,10 @@ class UnretryableError(Exception):


 class RefuelLLMV2(BaseModel):
-    DEFAULT_TOKENIZATION_MODEL = "NousResearch/Llama-2-13b-chat-hf"
+    DEFAULT_TOKENIZATION_MODEL = {
+        "pretrained_model_name_or_path": "NousResearch/Llama-2-13b-chat-hf",
+        "revision": "d73f5fa9c4bc135502e04c27b39660747172d76b",
+    }
     DEFAULT_CONTEXT_LENGTH = 3250
     DEFAULT_CONNECT_TIMEOUT = 10
     DEFAULT_READ_TIMEOUT = 120
@@ -60,7 +63,9 @@ def __init__(
         model_params = config.model_params()
         self.model_params = {**self.DEFAULT_PARAMS, **model_params}
         self.model_endpoint = config.model_endpoint()
-        self.tokenizer = AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZATION_MODEL)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            **self.DEFAULT_TOKENIZATION_MODEL
+        )
         self.read_timeout = self.model_params.get(
             "request_timeout", self.DEFAULT_READ_TIMEOUT
         )
9 changes: 7 additions & 2 deletions src/autolabel/models/tgi.py
@@ -26,7 +26,10 @@


 class TGILLM(BaseModel):
-    DEFAULT_TOKENIZATION_MODEL = "NousResearch/Llama-2-13b-chat-hf"
+    DEFAULT_TOKENIZATION_MODEL = {
+        "pretrained_model_name_or_path": "NousResearch/Llama-2-13b-chat-hf",
+        "revision": "d73f5fa9c4bc135502e04c27b39660747172d76b",
+    }
     DEFAULT_CONTEXT_LENGTH = 3250
     DEFAULT_CONNECT_TIMEOUT = 10
     DEFAULT_READ_TIMEOUT = 120
@@ -52,7 +55,9 @@ def __init__(
         self.model_params = {**self.DEFAULT_PARAMS, **model_params}
         if self.config.confidence():
             self.model_params["decoder_input_details"] = True
-        self.tokenizer = AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZATION_MODEL)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            **self.DEFAULT_TOKENIZATION_MODEL
+        )

     @retry(
         reraise=True,
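As a usage sketch (hypothetical, not part of this diff), the pinned tokenizer is what methods such as get_num_tokens in labeler.py rely on for token counting; loading it with an explicit revision keyword is equivalent to the dict-unpacking form used across the files above.

from transformers import AutoTokenizer

# Same model and revision as DEFAULT_TOKENIZATION_MODEL, passed explicitly.
tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-13b-chat-hf",
    revision="d73f5fa9c4bc135502e04c27b39660747172d76b",
)

def get_num_tokens(inp: str) -> int:
    # encode() returns the token ids for the input; its length is the token
    # count used for chunking and context-length checks.
    return len(tokenizer.encode(str(inp)))

print(get_num_tokens("Label this example."))
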
