# Using Intrinsics Directly on the Hugging Face Transformers Library

This notebook demonstrates how to use the shared input and output processing code for
intrinsics when performing model inference directly on top of models from the 
[Transformers library](https://huggingface.co/docs/transformers/en/index).

Note that running inference in this way is significantly slower than using `vLLM` or
another OpenAI-compatible scalable inference engine.

In [None]:
import pathlib
import itertools
import json
import granite_common.util
import torch
from granite_common.base.types import ChatCompletionResponseChoice
import os

## Constants

Change the value of the constants `intrinsic_name` and `base_model_name` in the cell that
follows to change which intrinsic will be demonstrated in the remainder of this notebook.

Other constants will automatically adjust accordingly.

In [None]:
# User-tunable settings: pick the intrinsic and the base model to demonstrate.
intrinsic_name = "answerability"
base_model_name = "granite-3.3-8b-instruct"
# base_model_name = "gpt-oss-20b"
use_alora = False

use_cuda = True  # Set to False to use default PyTorch device for this machine + model

#######################################################################################
# Everything below is derived from the settings above.

TESTDATA_DIR = "../tests/granite_common/intrinsics/rag/testdata"
KNOWN_INTRINSICS = [
    "answerability",
    "answer_relevance_classifier",
    "answer_relevance_rewriter",
    "citations",
    "context_relevance",
    "hallucination_detection",
    "query_rewrite",
    "requirement_check",
    "uncertainty",
]
# Intrinsics whose io.yaml file is read from the local test data directory.
INTRINSICS_WITH_LOCAL_YAML_FILES = []
MODEL_TO_CONSTRAINED_DECODING_PREFIX = {
    # Some base models generate their own generation prompts, for example for thinking
    # or for OpenAI channel selection. To use constrained decoding with intrinsics
    # trained on these models, the code we use in this notebook adds the most common
    # generation prompt to the chat template's prompt when applicable.
    "gpt-oss-20b": (
        "<|channel|>analysis<|message|><|end|>"
        "<|start|>assistant<|channel|>final<|message|>"
    )
}
# None for base models that have no entry in the table above.
constrained_decoding_prefix = MODEL_TO_CONSTRAINED_DECODING_PREFIX.get(
    base_model_name, None
)

# A value of None for either of these means "load from Hugging Face Hub" below.
io_yaml_file = None
lora_dir = None

# The sample request for "answerability" is stored as answerable.json; every other
# intrinsic uses a file named after the intrinsic itself.
request_json_stem = (
    "answerable" if intrinsic_name == "answerability" else intrinsic_name
)
request_json_file = f"{TESTDATA_DIR}/input_json/{request_json_stem}.json"

# Pick up the local JSON file with extra arguments only when that file is present.
arg_file = None
maybe_arg_file = f"{TESTDATA_DIR}/input_args/{intrinsic_name}.json"
if os.path.exists(maybe_arg_file):
    arg_file = maybe_arg_file

# Remaining per-intrinsic overrides, plus validation of the intrinsic name.
if intrinsic_name != "answerability":
    if intrinsic_name in INTRINSICS_WITH_LOCAL_YAML_FILES:
        # Some io.yaml files not yet delivered to Hugging Face Hub
        io_yaml_file = f"{TESTDATA_DIR}/input_yaml/{intrinsic_name}.yaml"
    elif intrinsic_name not in KNOWN_INTRINSICS:
        raise ValueError(f"Unrecognized intrinsic name '{intrinsic_name}'")

# TEMPORARY until we have gpt-oss checkpoints on HF Hub
if base_model_name == "gpt-oss-20b":
    # Point at a local checkout of the repo that holds the gpt-oss LoRAs. This
    # replaces any io.yaml location chosen above.
    peft_type = "alora" if use_alora else "lora"
    lora_dir = pathlib.Path(
        f"../../intrinsics-lib/{intrinsic_name}/{peft_type}/{base_model_name}"
    )
    io_yaml_file = lora_dir / "io.yaml"

# Anything still unset is fetched from Hugging Face Hub.
if io_yaml_file is None:
    io_yaml_file = granite_common.intrinsics.util.obtain_io_yaml(
        intrinsic_name, base_model_name
    )
if lora_dir is None:
    lora_dir = granite_common.intrinsics.util.obtain_lora(
        intrinsic_name, base_model_name, alora=use_alora
    )

# Fail early, before any model loading, if the adapter directory is missing.
if not os.path.exists(lora_dir):
    raise ValueError(f"LoRA directory {lora_dir} does not exist on this machine.")

# Show the values that the rest of the notebook will use.
print(f"{lora_dir=}")
print(f"{io_yaml_file=}")
print(f"{request_json_file=}")
print(f"{arg_file=}")

## Instantiate the input and output processing classes

The constructors for the classes `IntrinsicsRewriter` and `IntrinsicsResultProcessor`
serve as factory methods to produce input and output processors, respectively, for 
a given intrinsic.

In [None]:
# The input and output processors are both configured from the same io.yaml file.
print(
    "Instantiating input and output processing from configuration file:\n"
    f"{io_yaml_file}"
)

processor_config = {"config_file": io_yaml_file}
rewriter = granite_common.IntrinsicsRewriter(**processor_config)
result_processor = granite_common.IntrinsicsResultProcessor(**processor_config)

## Perform input processing

The cells that follow load an example OpenAI-compatible chat completion request from
a local file, then show how to apply input processing to the request.

In [None]:
# Load the sample chat completion request selected in the constants cell.
print(f"Loading request data from {request_json_file}")
request_json_str = pathlib.Path(request_json_file).read_text(encoding="utf-8")
request_json = json.loads(request_json_str)

# The JSON files used for testing leave out some parameters, such as the model name.
# Fill those parameters in here.
request_json.update({"model": intrinsic_name, "temperature": 0.0})

original_request_dump = json.dumps(request_json, indent=2)
print("Original request:")
print(original_request_dump)

In [None]:
# Besides the target chat completion request, some intrinsics accept additional
# arguments. When an argument file exists for this intrinsic, read them from it;
# otherwise no extra arguments are passed.
intrinsic_kwargs = {}
if arg_file is not None:
    intrinsic_kwargs = json.loads(pathlib.Path(arg_file).read_text(encoding="utf8"))
    print(f"Using additional arguments:\n{intrinsic_kwargs}")

In [None]:
# Apply the intrinsic's input processing to the original request, forwarding any
# additional arguments loaded above.
rewritten_request = rewriter.transform(request_json, **intrinsic_kwargs)

rewritten_request_dump = rewritten_request.model_dump_json(indent=2)
print("Request after input processing:")
print(rewritten_request_dump)

## Running inference

Passing a request through the input processing `IntrinsicsRewriter.transform()` 
turns the request into something that can be sent directly to an OpenAI-compatible
inference endpoint for the intrinsic.

The Transformers library does not have an OpenAI-compatible inference API, so the
cells that follow use functions provided by the `granite-common` library to convert
the OpenAI-compatible request into the proprietary format of the Transformers
library.

In [None]:
# Load the base model and merge LoRA weights.
# Unlike vLLM, the Transformers library does not have a facility for dynamically
# switching between many LoRA adapters.
model, tokenizer = granite_common.util.load_transformers_lora(lora_dir)

# Move the model to the GPU only when CUDA is actually usable. Calling .cuda() on a
# machine without CUDA raises, which would make the default `use_cuda = True`
# setting unusable there; in that case leave the model where it was loaded.
if use_cuda:
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        print(
            "use_cuda is True but CUDA is not available; "
            "leaving the model on its default device."
        )

In [None]:
# Convert the chat completion request into the Transformers library's proprietary
# format. The call returns two values: the inputs for generation and a second value
# that is passed along to the generation helper in a later cell.
conversion_result = granite_common.util.chat_completion_request_to_transformers_inputs(
    rewritten_request,
    tokenizer,
    model,
    # Note that this last argument is currently only needed for gpt-oss-20b
    constrained_decoding_prefix=constrained_decoding_prefix,
)
generate_input, other_input = conversion_result

generate_input

In [None]:
# Use the Transformers library's APIs to generate one or more completions,
# then convert those completions into an OpenAI-compatible chat completion
# response.
responses = granite_common.util.generate_with_transformers(
    tokenizer, model, generate_input, other_input
)
responses_dump = responses.model_dump_json(indent=2)
print(responses_dump)

## Post-process inference results

The raw output of some intrinsics requires some additional postprocessing to turn it 
into a form that is easy to consume in an application. This postprocessing occurs in
the method `IntrinsicsResultProcessor.transform()`. 

The cells that follow show how to use this method to transform the raw output of the
`generate_with_transformers()` call into the intrinsic's application-level output
value.

By convention, this application-level output value is returned in the same format as a
chat completions request result. Code in the `generate_with_transformers()` function 
has already converted the results into that format.


In [None]:
# Apply the intrinsic's output processing to the raw responses. The rewritten
# request is passed in alongside the responses.
transformed_responses = result_processor.transform(responses, rewritten_request)
transformed_responses_dump = transformed_responses.model_dump_json(indent=4)
print(transformed_responses_dump)

In [None]:
# The "content" field of the generated message holds a JSON document. Parse it and
# let the notebook display the resulting Python value.
first_choice_content = transformed_responses.choices[0].message.content
json.loads(first_choice_content)

## Show low-level results

Sometimes you may need to see the raw prompts being sent to the Transformers model.

The cells that follow contain the code of `generate_with_transformers()` with additional debug printouts.

In [None]:
# XGrammar logit processors are single-use, so the chat completion request has to
# be translated again before running generation a second time.
conversion_kwargs = {
    # Note that this argument is currently only needed for gpt-oss-20b
    "constrained_decoding_prefix": constrained_decoding_prefix,
}
generate_input, other_input = (
    granite_common.util.chat_completion_request_to_transformers_inputs(
        rewritten_request, tokenizer, model, **conversion_kwargs
    )
)
generate_input

## Inspect the prompt and run generation step by step

In [None]:
# Decode the first row of the input tokens back into text to show the exact prompt.
prompt_string = tokenizer.decode(generate_input["input_tokens"][0])
print(f"Prompt string:\n{prompt_string}")

In [None]:
# Input tokens must be passed to generate() as a positional argument, not a named
# argument. The remaining entries go into a separate kwargs dict so that the
# notebook-level `generate_input` variable is not modified by this cell.
input_tokens = generate_input["input_tokens"]
generate_kwargs = {
    name: value for name, value in generate_input.items() if name != "input_tokens"
}

generate_result = model.generate(input_tokens, **generate_kwargs)

# Result is a 2D tensor of shape (num responses, prompt + max generated tokens)
# containing tokens, plus a tuple of <max generated tokens> tensors of shape
# (num beams, vocab size) containing scores.
# This is of course not a usable format for downstream processing.
# Start by stripping off the prompt, leaving us with a tensor of shape
# (num responses, max generated tokens)
num_prompt_tokens = input_tokens.shape[1]
num_responses = generate_result.sequences.shape[0]
generated_tokens = generate_result.sequences[:, num_prompt_tokens:]

generated_scores = (
    None
    if generate_result.scores is None
    else (torch.stack(generate_result.scores).swapaxes(0, 1)[:num_responses])
)

# Iterate over the responses, stripping off EOS tokens
choices = []
for i in range(num_responses):
    response_tokens = generated_tokens[i]

    if tokenizer.eos_token_id in response_tokens:
        # Strip off the first EOS token and everything after it.
        # A response can contain several EOS tokens (for example when shorter
        # sequences in a batch are padded with EOS), so take the first match
        # explicitly instead of assuming there is exactly one.
        eos_positions = (response_tokens == tokenizer.eos_token_id).nonzero(
            as_tuple=True
        )[0]
        eos_ix = eos_positions[0].item()
        response_tokens = response_tokens[:eos_ix]

    num_response_tokens = len(response_tokens)
    response_string = tokenizer.decode(response_tokens)
    print(f"Raw response {i}: {response_string}")

    # The decode() method doesn't return offsets.
    # The only supported API to get offsets is to retokenize the string and hope you
    # get back the same tokenization.
    # This supported API doesn't work reliably, so we fall back on the unsupported
    # method of pulling token lengths out of the tokenizer.
    ends = list(
        itertools.accumulate([len(s) for s in tokenizer.batch_decode(response_tokens)])
    )
    # Each token begins where the previous one ended. Built this way, `begins` is
    # empty when the response has no tokens, so the strict zip below cannot fail on
    # an empty response.
    begins = ([0] + ends)[:-1]
    token_offsets = list(zip(begins, ends, strict=True))

    if generated_scores is None:
        logprobs_content = None
    else:
        # Only the rows for tokens that survived EOS stripping are needed.
        response_scores = generated_scores[i][:num_response_tokens]

        # Scores come back as raw logits. You need to decode them to produce
        # logprobs. For consistency with the OpenAI output format, we need to
        # decode twice: Once to get the probability of the returned token and a
        # second time to get the top k logprobs. As with the OpenAI APIs, the
        # returned token may or may not be included in the top k results.
        all_logprobs = torch.log_softmax(response_scores.to(torch.float32), 1)
        chosen_token_logprobs = [
            all_logprobs[token_ix][response_tokens[token_ix]].item()
            for token_ix in range(num_response_tokens)
        ]
        token_strings = [response_string[begin:end] for begin, end in token_offsets]
        token_bytes = [list(s.encode("utf-8")) for s in token_strings]

        # Transformers has no notion of top-k logprobs, so the parameter that
        # triggers that post-processing is passed via other_input.
        if "top_logprobs" not in other_input:
            top_logprobs = [[] for _ in range(len(token_strings))]
        else:  # if "top_logprobs" in other_input:
            top_k_values, top_k_indices = torch.topk(
                torch.nan_to_num(all_logprobs, float("-inf")),
                other_input["top_logprobs"],
            )
            top_k_token_strs = [
                [tokenizer.decode(t) for t in row_i] for row_i in top_k_indices
            ]
            top_logprobs = [
                [
                    {
                        "token": s,
                        "bytes": list(s.encode("utf8")),
                        "logprob": lp.item(),
                    }
                    for s, lp in zip(strs, lps, strict=True)
                ]
                for strs, lps in zip(top_k_token_strs, top_k_values, strict=True)
            ]

        # Use a dedicated index name here so it does not shadow the response
        # index `i` of the enclosing loop.
        logprobs_content = [
            {
                "token": token_strings[token_ix],
                "bytes": token_bytes[token_ix],
                "logprob": chosen_token_logprobs[token_ix],
                "top_logprobs": top_logprobs[token_ix],
            }
            for token_ix in range(num_response_tokens)
        ]

    response_choice_value = {
        "index": i,
        "message": {"content": response_string, "role": "assistant"},
    }
    if logprobs_content is not None:
        response_choice_value["logprobs"] = {"content": logprobs_content}
    response_choice = ChatCompletionResponseChoice.model_validate(response_choice_value)
    choices.append(response_choice)

# END code from generate_with_transformers()

In [None]:
# Print every choice built by the previous cell as indented JSON.
for choice_ix, choice in enumerate(choices):
    choice_json = choice.model_dump_json(indent=2)
    print(f"Choice {choice_ix}:\n{choice_json}")