Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ dependencies = [
"numpy",
"scipy",
"transformers",
"openai",
]
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
torch
transformers
numpy
scipy
scipy
openai
3 changes: 2 additions & 1 deletion src/core/entailers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .entailer import Entailer, EntailerInstance
from .soft_entailer import SoftEntailer
from .soft_entailer import SoftEntailer
from .vllm_soft_entailer import VLLMSoftEntailer
203 changes: 203 additions & 0 deletions src/core/entailers/vllm_soft_entailer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""VLLM-backed soft entailer for conditional probability estimation.

Uses a VLLM-served LLM (e.g. Zhengping/conditional-probability-regression)
that estimates p(hypothesis | premise) by decoding a distribution over
special label-level tokens and computing a weighted average score.
"""

import asyncio
import concurrent.futures
import math
from typing import Text, List, Optional, Dict

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from .entailer import Entailer
from ..utils.instances import EntailerInstance


class VLLMSoftEntailer(Entailer):
    """Soft entailer backed by a VLLM OpenAI-compatible API endpoint.

    The hosted model is expected to use special ``<|label_level_N|>`` tokens
    whose softmax-weighted midpoint scores yield a probability in [0, 1].
    This is the inference protocol used by
    ``Zhengping/conditional-probability-regression``.

    All requests within a batch are dispatched concurrently via
    :class:`openai.AsyncOpenAI`. The client handles transient failures
    (connection errors, 429, >=500) with exponential-backoff retries
    controlled by *max_retries*.
    """

    _PROMPT_TEMPLATE = (
        '### Question: Given the premise "{premise}", '
        'how likely is it that the hypothesis "{hypothesis}" is true?\n\n'
    )
    _COMPLETION_PREFIX = "### Answer:"

    def __init__(
        self,
        model_name: Text,
        api_base: Text = "http://localhost:8000/v1",
        num_labels: int = 10,
        internal_batch_size: int = 16,
        cache_dir: Optional[Text] = None,
        top_logprobs: int = 20,
        api_key: Text = "EMPTY",
        max_retries: int = 3,
        timeout: float = 60.0,
    ):
        """Configure the remote-endpoint entailer.

        Args:
            model_name: Name of the model as served by the VLLM endpoint.
            api_base: Base URL of the OpenAI-compatible API.
            num_labels: Number of ``<|label_level_*|>`` bins over [0, 1].
            internal_batch_size: Batch size used by the base-class driver.
            cache_dir: Forwarded to the base class (unused remotely).
            top_logprobs: How many top tokens to request per position;
                clamped to at least *num_labels* so every label token can
                in principle appear in the returned distribution.
            api_key: API key sent to the server ("EMPTY" for local VLLM).
            max_retries: Exponential-backoff retry count in the client.
            timeout: Per-request timeout in seconds.
        """
        super().__init__(
            model_name=model_name,
            device="cpu",
            internal_batch_size=internal_batch_size,
            max_length=512,
            cache_dir=cache_dir,
        )
        self._api_base = api_base.rstrip("/")
        self._num_labels = num_labels
        self._top_logprobs = max(top_logprobs, num_labels)

        # Store client-construction kwargs so that a fresh AsyncOpenAI
        # client can be created inside each event loop spawned by
        # asyncio.run(). Re-using one AsyncOpenAI instance across
        # different loops would bind the underlying httpx connection pool
        # to a stale loop.
        self._client_kwargs = dict(
            base_url=self._api_base,
            api_key=api_key,
            max_retries=max_retries,
            timeout=timeout,
        )

        # Pre-compute label token strings and their midpoint score values.
        # Token format mirrors the vocabulary of the target model where each
        # ``<|label_level_i|>`` token is mapped to the midpoint of the i-th
        # uniform bin over [0, 1].
        self._label_tokens: List[Text] = [
            f" <|label_level_{i}|>" for i in range(num_labels)
        ]
        step_size = 1.0 / num_labels
        self._label_scores: List[float] = [
            i * step_size + 0.5 * step_size for i in range(num_labels)
        ]

    # ------------------------------------------------------------------
    # Override base-class hooks
    # ------------------------------------------------------------------

    def _load_model(self):
        """No local model to load — set sentinels so the base ``__call__``
        does not attempt to reload on every invocation."""
        self._model = "vllm"
        self._tokenizer = "vllm"

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _softmax(values: List[float]) -> List[float]:
        """Numerically-stable softmax over a list of logprobs."""
        max_val = max(values)
        exps = [math.exp(v - max_val) for v in values]
        total = sum(exps)
        return [e / total for e in exps]

    def _extract_score(self, completion: ChatCompletion) -> float:
        """Compute the weighted-average probability from a chat completion.

        1. Collect the log-probabilities of every ``<|label_level_*|>``
           token that appears in the ``top_logprobs`` of the first
           generated token.
        2. Apply softmax **only** over those label tokens.
        3. Return the dot product with the pre-computed midpoint scores.

        Returns 0.5 (maximum uncertainty) when the server returned no
        logprobs at all.
        """
        choice = completion.choices[0]

        if choice.logprobs is None or not choice.logprobs.content:
            return 0.5

        first_token_info = choice.logprobs.content[0]

        # Map token string → logprob from the top_logprobs list
        token_logprob_map: Dict[Text, float] = {
            entry.token: entry.logprob
            for entry in (first_token_info.top_logprobs or [])
        }

        # Look up each label token; use a very negative value for any
        # label that did not appear in the top-k.
        label_logprobs = [
            token_logprob_map.get(tok, -100.0) for tok in self._label_tokens
        ]

        probs = self._softmax(label_logprobs)
        score = sum(p * s for p, s in zip(probs, self._label_scores))
        return score

    # ------------------------------------------------------------------
    # Async internals
    # ------------------------------------------------------------------

    async def _score_instance(
        self,
        client: AsyncOpenAI,
        instance: EntailerInstance,
    ) -> float:
        """Send a single chat-completion request and return the score."""
        messages = [
            {
                "role": "user",
                "content": self._PROMPT_TEMPLATE.format(
                    premise=instance.premise,
                    hypothesis=instance.hypothesis,
                ),
            },
            {
                "role": "assistant",
                "content": self._COMPLETION_PREFIX,
            },
        ]

        completion = await client.chat.completions.create(
            model=self._model_name,
            messages=messages,
            max_tokens=1,
            logprobs=True,
            top_logprobs=self._top_logprobs,
            temperature=0,
            extra_body={
                # vLLM-specific: continue from the assistant prefix
                # rather than starting a new assistant turn.
                "continue_final_message": True,
            },
        )
        return self._extract_score(completion)

    async def _async_call_batch(
        self, instances: List[EntailerInstance]
    ) -> List[float]:
        """Fire all requests concurrently inside a single event loop."""
        async with AsyncOpenAI(**self._client_kwargs) as client:
            tasks = [
                self._score_instance(client, inst) for inst in instances
            ]
            return list(await asyncio.gather(*tasks))

    # ------------------------------------------------------------------
    # Core batch call (sync entry-point used by the base class)
    # ------------------------------------------------------------------

    def _call_batch(
        self, instances: List[EntailerInstance]
    ) -> List[float]:
        """Query the VLLM server for a batch of instances.

        All requests are dispatched concurrently via ``asyncio.gather``.
        The :class:`AsyncOpenAI` client automatically retries transient
        errors (connection failures, HTTP 429 / >=500) with
        exponential backoff.

        This entry point is safe to call both from plain synchronous
        code and from contexts that already run an event loop (Jupyter
        notebooks, pytest-asyncio, FastAPI workers): ``asyncio.run``
        raises ``RuntimeError`` when a loop is already running in the
        current thread, so in that case the batch is executed on a
        private loop inside a dedicated worker thread instead.
        """
        coro = self._async_call_batch(instances)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop in this thread — the common script/CLI path.
            return asyncio.run(coro)
        # A loop is already running here; asyncio.run() would raise.
        # Run the coroutine on a fresh loop in a worker thread and block
        # synchronously on its result, preserving the sync interface.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid asyncio.run in synchronous batch scorer

_call_batch unconditionally wraps every batch in asyncio.run, which raises RuntimeError whenever the caller already has an active event loop (e.g., Jupyter notebooks, pytest-asyncio, FastAPI workers). In those common environments this new entailer cannot be used at all, so experiments that switch to VLLMSoftEntailer will fail before scoring; this path is reached through the normal Entailer.__call__ flow, not just a special API.

Useful? React with 👍 / 👎.