Add VLLMSoftEntailer for LLM-based conditional probability estimation #2
Merged · 3 commits · +208 −2
```diff
@@ -13,4 +13,5 @@ dependencies = [
     "numpy",
     "scipy",
     "transformers",
+    "openai",
 ]
```
```diff
@@ -1,4 +1,5 @@
 torch
 transformers
 numpy
 scipy
+openai
```
```diff
@@ -1,2 +1,3 @@
 from .entailer import Entailer, EntailerInstance
 from .soft_entailer import SoftEntailer
+from .vllm_soft_entailer import VLLMSoftEntailer
```
New file (`@@ -0,0 +1,203 @@`):

```python
"""VLLM-backed soft entailer for conditional probability estimation.

Uses a VLLM-served LLM (e.g. Zhengping/conditional-probability-regression)
that estimates p(hypothesis | premise) by decoding a distribution over
special label-level tokens and computing a weighted average score.
"""

import asyncio
import math
from typing import Text, List, Optional, Dict

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from .entailer import Entailer
from ..utils.instances import EntailerInstance


class VLLMSoftEntailer(Entailer):
    """Soft entailer backed by a VLLM OpenAI-compatible API endpoint.

    The hosted model is expected to use special ``<|label_level_N|>`` tokens
    whose softmax-weighted midpoint scores yield a probability in [0, 1].
    This is the inference protocol used by
    ``Zhengping/conditional-probability-regression``.

    All requests within a batch are dispatched concurrently via
    :class:`openai.AsyncOpenAI`. The client handles transient failures
    (connection errors, 429, >=500) with exponential-backoff retries
    controlled by *max_retries*.
    """

    _PROMPT_TEMPLATE = (
        '### Question: Given the premise "{premise}", '
        'how likely is it that the hypothesis "{hypothesis}" is true?\n\n'
    )
    _COMPLETION_PREFIX = "### Answer:"

    def __init__(
        self,
        model_name: Text,
        api_base: Text = "http://localhost:8000/v1",
        num_labels: int = 10,
        internal_batch_size: int = 16,
        cache_dir: Optional[Text] = None,
        top_logprobs: int = 20,
        api_key: Text = "EMPTY",
        max_retries: int = 3,
        timeout: float = 60.0,
    ):
        super().__init__(
            model_name=model_name,
            device="cpu",
            internal_batch_size=internal_batch_size,
            max_length=512,
            cache_dir=cache_dir,
        )
        self._api_base = api_base.rstrip("/")
        self._num_labels = num_labels
        self._top_logprobs = max(top_logprobs, num_labels)

        # Store client-construction kwargs so that a fresh AsyncOpenAI
        # client can be created inside each event loop spawned by
        # asyncio.run(). Re-using one AsyncOpenAI instance across
        # different loops would bind the underlying httpx connection pool
        # to a stale loop.
        self._client_kwargs = dict(
            base_url=self._api_base,
            api_key=api_key,
            max_retries=max_retries,
            timeout=timeout,
        )

        # Pre-compute label token strings and their midpoint score values.
        # Token format mirrors the vocabulary of the target model where each
        # ``<|label_level_i|>`` token is mapped to the midpoint of the i-th
        # uniform bin over [0, 1].
        self._label_tokens: List[Text] = [
            f" <|label_level_{i}|>" for i in range(num_labels)
        ]
        step_size = 1.0 / num_labels
        self._label_scores: List[float] = [
            i * step_size + 0.5 * step_size for i in range(num_labels)
        ]

    # ------------------------------------------------------------------
    # Override base-class hooks
    # ------------------------------------------------------------------

    def _load_model(self):
        """No local model to load — set sentinels so the base ``__call__``
        does not attempt to reload on every invocation."""
        self._model = "vllm"
        self._tokenizer = "vllm"

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _softmax(values: List[float]) -> List[float]:
        """Numerically-stable softmax over a list of logprobs."""
        max_val = max(values)
        exps = [math.exp(v - max_val) for v in values]
        total = sum(exps)
        return [e / total for e in exps]

    def _extract_score(self, completion: ChatCompletion) -> float:
        """Compute the weighted-average probability from a chat completion.

        1. Collect the log-probabilities of every ``<|label_level_*|>``
           token that appears in the ``top_logprobs`` of the first
           generated token.
        2. Apply softmax **only** over those label tokens.
        3. Return the dot product with the pre-computed midpoint scores.
        """
        choice = completion.choices[0]

        if choice.logprobs is None or not choice.logprobs.content:
            return 0.5

        first_token_info = choice.logprobs.content[0]

        # Map token string → logprob from the top_logprobs list
        token_logprob_map: Dict[Text, float] = {
            entry.token: entry.logprob
            for entry in (first_token_info.top_logprobs or [])
        }

        # Look up each label token; use a very negative value for any
        # label that did not appear in the top-k.
        label_logprobs = [
            token_logprob_map.get(tok, -100.0) for tok in self._label_tokens
        ]

        probs = self._softmax(label_logprobs)
        score = sum(p * s for p, s in zip(probs, self._label_scores))
        return score

    # ------------------------------------------------------------------
    # Async internals
    # ------------------------------------------------------------------

    async def _score_instance(
        self,
        client: AsyncOpenAI,
        instance: EntailerInstance,
    ) -> float:
        """Send a single chat-completion request and return the score."""
        messages = [
            {
                "role": "user",
                "content": self._PROMPT_TEMPLATE.format(
                    premise=instance.premise,
                    hypothesis=instance.hypothesis,
                ),
            },
            {
                "role": "assistant",
                "content": self._COMPLETION_PREFIX,
            },
        ]

        completion = await client.chat.completions.create(
            model=self._model_name,
            messages=messages,
            max_tokens=1,
            logprobs=True,
            top_logprobs=self._top_logprobs,
            temperature=0,
            extra_body={
                # vLLM-specific: continue from the assistant prefix
                # rather than starting a new assistant turn.
                "continue_final_message": True,
            },
        )
        return self._extract_score(completion)

    async def _async_call_batch(
        self, instances: List[EntailerInstance]
    ) -> List[float]:
        """Fire all requests concurrently inside a single event loop."""
        async with AsyncOpenAI(**self._client_kwargs) as client:
            tasks = [
                self._score_instance(client, inst) for inst in instances
            ]
            return list(await asyncio.gather(*tasks))

    # ------------------------------------------------------------------
    # Core batch call (sync entry-point used by the base class)
    # ------------------------------------------------------------------

    def _call_batch(
        self, instances: List[EntailerInstance]
    ) -> List[float]:
        """Query the VLLM server for a batch of instances.

        All requests are dispatched concurrently via ``asyncio.gather``.
        The :class:`AsyncOpenAI` client automatically retries transient
        errors (connection failures, HTTP 429 / >=500) with
        exponential backoff.
        """
        return asyncio.run(self._async_call_batch(instances))
```
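The `_extract_score` step can be illustrated standalone: softmax over the label-token logprobs, then a dot product with the uniform-bin midpoints. A minimal sketch (the logprob values below are made up for illustration; only the math mirrors the class above):

```python
import math

def softmax(values):
    # Numerically-stable softmax, same as VLLMSoftEntailer._softmax.
    max_val = max(values)
    exps = [math.exp(v - max_val) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

num_labels = 10
step = 1.0 / num_labels
# Midpoints of 10 uniform bins over [0, 1]: 0.05, 0.15, ..., 0.95
midpoints = [i * step + 0.5 * step for i in range(num_labels)]

# Hypothetical top-logprobs: the model puts essentially all mass on
# <|label_level_8|>; absent labels get the -100.0 fallback.
label_logprobs = [-100.0] * num_labels
label_logprobs[8] = -0.1

probs = softmax(label_logprobs)
score = sum(p * s for p, s in zip(probs, midpoints))
print(round(score, 2))  # prints 0.85
```

Because the softmax is taken over the label tokens only, probability mass assigned to unrelated vocabulary tokens is ignored, and the result always lands in [0.05, 0.95].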
`_call_batch` unconditionally wraps every batch in `asyncio.run`, which raises `RuntimeError` whenever the caller already has an active event loop (e.g., Jupyter notebooks, `pytest-asyncio`, FastAPI workers). In those common environments this new entailer cannot be used at all, so experiments that switch to `VLLMSoftEntailer` will fail before scoring; this path is reached through the normal `Entailer.__call__` flow, not just a special API.
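One loop-aware alternative (a sketch, not code from this PR; `run_coro_sync` is a hypothetical helper name): detect a running loop and, when one exists, run the coroutine on a fresh loop in a worker thread instead of calling `asyncio.run` directly.

```python
import asyncio
import concurrent.futures

def run_coro_sync(coro):
    """Run *coro* to completion whether or not an event loop is active.

    With no running loop in this thread, asyncio.run works directly.
    With a loop already running (Jupyter, pytest-asyncio, FastAPI),
    asyncio.run would raise RuntimeError, so run the coroutine on a
    fresh loop inside a worker thread and block on the result.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # simple path: no active loop
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()

async def demo():
    await asyncio.sleep(0)
    return 42

print(run_coro_sync(demo()))  # prints 42
```

`_call_batch` could then return `run_coro_sync(self._async_call_batch(instances))`; the worker-thread fallback adds a small per-batch cost but keeps the sync entry point usable from async contexts.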