In [3]:
import argparse
import json
import pathlib
from enum import Enum

import fireworks.client.api
import numpy as np
from fireworks.client import AsyncFireworks

In [9]:
MODELS = [
    "accounts/fireworks/models/llama-v3p1-8b-instruct",
    "accounts/fireworks/models/llama-v3p1-70b-instruct",
    "accounts/fireworks/models/llama-v3p1-405b-instruct",
]

In [6]:
def get_model_written_tokens_logprobs(
    logprobs: fireworks.client.api.LogProbs,
) -> tuple[list[str], list[float]]:
    # iterate through tokens, find the logprobs that were written by the model
    # and return them

    written_logprobs = []
    written_tokens = []

    # state machine to manage where where in the token stream we are
    class ModelState(Enum):
        MODEL_WRITTEN_TEXT = 1
        USER_WRITTEN_TEXT = 2
        EXPECT_ASSISTANT_TOKEN = 3

    state = ModelState.USER_WRITTEN_TEXT

    for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs, strict=False):
        match state:
            case ModelState.MODEL_WRITTEN_TEXT:
                if token == "<|end_header_id|>":
                    continue
                elif token == "<|eot_id|>":
                    state = ModelState.USER_WRITTEN_TEXT
                elif token == "<|start_header_id|>":
                    state = ModelState.EXPECT_ASSISTANT_TOKEN
                else:
                    written_logprobs.append(logprob)
                    written_tokens.append(token)
            case ModelState.USER_WRITTEN_TEXT:
                if token == "<|start_header_id|>":
                    state = ModelState.EXPECT_ASSISTANT_TOKEN
            case ModelState.EXPECT_ASSISTANT_TOKEN:
                if token == "assistant":
                    state = ModelState.MODEL_WRITTEN_TEXT
                else:
                    state = ModelState.USER_WRITTEN_TEXT

    return written_tokens, written_logprobs


async def compute_nll(fireworks: AsyncFireworks, model: str, data: list[dict[str, str]]) -> float:
    # construct chat completion using together api
    response = fireworks.chat.completions.create(
        model=model,
        messages=data,
        max_tokens=0,
        logprobs=1,
        echo=True,
    )
    tokens, logprobs = get_model_written_tokens_logprobs(response.choices[0].logprobs)

    print(tokens)
    print(logprobs)

    return -np.mean(logprobs)




In [13]:
fw = AsyncFireworks(api_key=pathlib.Path("~/tokens/fireworks").expanduser().read_text().strip())

In [15]:
await compute_nll(fw, MODELS[0], [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hello, my name is bob"}])

['\n\n', 'hello', ',', ' my', ' name', ' is', ' bob', '', '\n\n']
[0.0, -7.796875, -3.38671875, -13.59375, -0.46459961, -0.00195503, -8.4921875, -1.97265625, 0.0]


3.9676380155555555