In [54]:
import requests
from typing import List, Dict, Any, Union, Optional, Set, Tuple
import math
from dataclasses import dataclass

config = {
    'api_url': 'http://localhost:8081',
    'logprobs': 100
}

@dataclass
class TokenProb:
    token: str
    log_prob: float


def get_logprobs_completion(prompt: Union[str, List[int]], config: Dict[str, Any]) -> List[TokenProb]:
    url: str = config['api_url'] + '/completion'
    logprobs: int = config['logprobs']    
    
    payload: Dict[str, Union[str, bool, float, int]] = {
        'prompt': prompt,
        'cache_prompt': True,
        'n_predict': 1,
        'temperature': 1.0,
        'samplers': ['temperature'],
        'n_probs': logprobs
    }
    
    try:
        response: requests.Response = requests.post(url, json=payload)
        response.raise_for_status()
        probs: List[Dict[str, Union[str, float]]] = response.json()['completion_probabilities'][0]['probs']
    except Exception as e:
        print(f"Error fetching logprobs: {e}")
        probs = []
    
    return [TokenProb(prob['tok_str'], prob['prob']) for prob in probs]


def XTC_sampler(token_probs: List[TokenProb], threshold: int = 0.2, n_tokens_above_threshold: int = 3) -> List[str]:
    """Exclude Top Choices sampler"""
    n_above_threshold = sum(1 for token_prob in token_probs if token_prob.log_prob > threshold)
    if n_above_threshold >= n_tokens_above_threshold:
        token_probs = token_probs[n_above_threshold-1:]
    return token_probs

def compute_entropy_varentropy(token_probs: List[TokenProb]) -> float:
    probs = [math.exp(token_prob.log_prob) for token_prob in token_probs]
    log_probs = [token_prob.log_prob for token_prob in token_probs]

    # entropy (E[log P(x)])
    entropy = -sum(prob * log_prob for prob, log_prob in zip(probs, log_probs))

    # expected square of log probability E[(log P(x))^2]
    expected_log_prob_squared = sum(prob * (log_prob ** 2) for prob, log_prob in zip(probs, log_probs))

    # varentropy = E[(log P(x))^2] - (E[log P(x)])^2
    varentropy = expected_log_prob_squared - (entropy ** 2)

    return varentropy, entropy


In [68]:
def get_logprobs_chat_completion(prompt: Union[str, List[int]], config: Dict[str, Any]) -> List[TokenProb]:
    url: str = config['api_url'] + '/v1/chat/completions'
    logprobs: int = config['logprobs']    
    
    payload: Dict[str, Union[str, bool, float, int]] = {
        'prompt': prompt,
        'cache_prompt': True,
        'n_predict': 1,
        'temperature': 1.0,
        #'samplers': ['top_k'],
        'n_probs': logprobs
    }
    
    try:
        response: requests.Response = requests.post(url, json=payload)
        response.raise_for_status()
        probs: List[Dict[str, Union[str, float]]] = response.json()['completion_probabilities'][0]['probs']
    except Exception as e:
        print(f"Error fetching logprobs: {e}")
        probs = []
    
    return [TokenProb(prob['tok_str'], prob['prob']) for prob in probs]

messages = [
    {"role": "system", "content": "You are a princial engineer at Google, capable of complex reasoning and reflection. Reason through the query inside <thinking> tags process the question carefully and think step-by-step about the solution, and then provide your final response inside <output> tags. If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags."},
    {"role": "user", "content": "Which number is larger, 9.9 or 9.11?"},
]

payload: Dict[str, Union[str, bool, float, int]] = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        'cache_prompt': True,
        'n_probs': 10,
        'n_predict': 1,
}

url: str = config['api_url'] + '/v1/chat/completions'
response: requests.Response = requests.post(url, json=payload)
response.raise_for_status()

print(response.json()["choices"][0]["message"]["content"])

<


In [67]:
response.json()

{'choices': [{'finish_reason': 'stop',
   'index': 0,
   'message': {'content': "To compare 9.9 and 9.11, we need to compare their decimal expansions.\n\n9.9 can be written as 9 + 0.1 (where the + sign indicates adding)\n\n9.11 can be written as 9 + 0.1 + 0.1 (where the + sign indicates adding)\n\nBoth numbers are equal to 9. Since 9 is less than 9.1, 9.9 is also less than 9.11, and therefore 9.9 is larger than 9.11.\n\n<reflection> I initially considered comparing the numbers as decimals, but that led to a different conclusion. In reality, we should have compared their integer parts (9 and 9) and then compared their fractional parts (0.1 and 0.1), but that didn't change the outcome. It was a mistake to initially consider comparing the numbers as decimals.",
    'role': 'assistant'}}],
 'created': 1727953433,
 'model': 'gpt-3.5-turbo',
 'object': 'chat.completion',
 'usage': {'completion_tokens': 189,
  'prompt_tokens': 103,
  'total_tokens': 292},
 'id': 'chatcmpl-kOopC0AO5pnxPxDNPeVP

In [63]:
response.json()["choices"][0]["message"]["content"]

"<reflection> Ah, an easy one! \n\nLet's analyze the situation step by step:\n\n1. We have two decimal numbers: 9.9 and 9.11.\n2. Since both numbers have the same decimal part (0.9), we can compare them by looking at the whole number part.\n3. The whole number part of 9.9 is 9, and the whole number part of 9.11 is also 9.\n4. Since both whole number"

In [53]:
prompt = "The quick brown fox jumps over the lazy dog."
token_probs: List[TokenProb] = get_logprobs_completion(prompt, config)
token_probs_list = [token_prob.token for token_prob in token_probs]
# print(token_probs)
XTC_sampler_probs = XTC_sampler(token_probs)
print()
# print(XTC_sampler_probs)
XTC_sampler_probs_list = [token_prob.token for token_prob in XTC_sampler_probs]
varentropy(token_probs)




(-1.3666777774741319, -1.3295235793825229)