Skip to content

Commit

Permalink
Pass confidence endpoint in params (#841)
Browse files Browse the repository at this point in the history
* Pass confidence endpoint in params

* Fix tokenizer name and revision

* Return 0 confidence when endpoint is not passed
  • Loading branch information
yadavsahil197 committed Jun 4, 2024
1 parent 12f72b2 commit 0a0d2b3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/autolabel/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tenacity import (
before_sleep_log,
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_exponential,
)
Expand All @@ -28,10 +29,12 @@ class ConfidenceCalculator:
def __init__(
self,
score_type: str = "logprob_average",
endpoint: str = None,
llm: Optional[BaseModel] = None,
cache: Optional[BaseCache] = None,
) -> None:
self.score_type = score_type
self.endpoint = endpoint
self.llm = llm
self.cache = cache
self.tokens_to_ignore = {"<unk>", "", "\\n"}
Expand All @@ -40,7 +43,6 @@ def __init__(
"p_true": self.p_true,
"logprob_average_per_key": self.logprob_average_per_key,
}
self.BASE_API = "https://llm.refuel.ai/models/refuel-llm-v2-small/v2/confidence"
self.REFUEL_API_ENV = "REFUEL_API_KEY"
if self.REFUEL_API_ENV in os.environ and os.environ[self.REFUEL_API_ENV]:
self.REFUEL_API_KEY = os.environ[self.REFUEL_API_ENV]
Expand Down Expand Up @@ -244,6 +246,7 @@ async def calculate(self, model_generation: LLMAnnotation, **kwargs) -> float:
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=2, max=10),
before_sleep=before_sleep_log(logger, logging.WARNING),
retry=retry_if_not_exception_type(ValueError),
)
async def _call_with_retry(self, model_input, model_output):
payload = {
Expand All @@ -253,9 +256,11 @@ async def _call_with_retry(self, model_input, model_output):
]
}
headers = {"refuel_api_key": self.REFUEL_API_KEY}
if self.endpoint is None:
raise ValueError("Endpoint not provided")
async with httpx.AsyncClient() as client:
response = await client.post(
self.BASE_API, json=payload, headers=headers, timeout=30
self.endpoint, json=payload, headers=headers, timeout=30
)
# raise Exception if status != 200
response.raise_for_status()
Expand Down
2 changes: 2 additions & 0 deletions src/autolabel/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
transform_cache: Optional[BaseCache] = SQLAlchemyTransformCache(),
confidence_cache: Optional[BaseCache] = SQLAlchemyConfidenceCache(),
confidence_tokenizer: Optional[AutoTokenizer] = None,
confidence_endpoint: Optional[str] = None,
use_tqdm: Optional[bool] = False,
) -> None:
self.generation_cache = generation_cache
Expand Down Expand Up @@ -129,6 +130,7 @@ def __init__(
score_type = "logprob_average_per_key"
self.confidence = ConfidenceCalculator(
score_type=score_type,
endpoint=confidence_endpoint,
llm=self.llm,
cache=self.confidence_cache,
)
Expand Down

0 comments on commit 0a0d2b3

Please sign in to comment.