# Examples of operations on the classes in the `granite_io.io` package.

This notebook should eventually turn into documentation.

In [None]:
import json

import aconfig

In [None]:
# Sample chat-completion requests, kept as raw JSON text so that the cells
# below can demonstrate parsing, validation, and serialization.
INPUT_JSON_STRS = [
    # [0] Multi-turn conversation (user / assistant / user) with no extra keys.
    """
{
    "messages":
    [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"}
    ]
}
""",
    # [1] Single user turn plus a top-level "thinking" flag set to true.
    """
{
    "messages":
    [
        {"role": "user", "content": "How much wood could a wood chuck chuck?"}
    ],
    "thinking": true
}
""",
]

# Choose which sample the rest of the notebook operates on; change the index
# to run the later cells against the other request.
input_json_str = INPUT_JSON_STRS[0]
print(input_json_str)

In [None]:
# Parse the selected JSON text into plain Python objects.
input_json = json.loads(input_json_str)
# Bare expression on the last line so the notebook displays the parsed value.
input_json

In [None]:
from granite_io.io.base import ChatCompletionInputs

# Build a ChatCompletionInputs object from the already-parsed request via
# `model_validate`, then display it.
input_obj = ChatCompletionInputs.model_validate(input_json)
input_obj

In [None]:
from granite_io.io.granite_3_3.input_processors.granite_3_3_input_processor import (
    Granite3Point3Inputs,
)  # noqa: E501

# Convert the generic inputs object into the Granite 3.3-specific inputs class
# by dumping it with `model_dump()` and validating the dump against
# Granite3Point3Inputs.
granite_input_obj = Granite3Point3Inputs.model_validate(input_obj.model_dump())
granite_input_obj

In [None]:
# Serialize the generic inputs object (`input_obj`, not the Granite-specific
# one) back to JSON text with 4-space indentation and print it.
reconstituted_json = input_obj.model_dump_json(indent=4)
print(reconstituted_json)

In [None]:
import transformers

# Model identifier passed to `from_pretrained`; reused by later cells.
GRANITE_MODEL_STR = "ibm-granite/granite-3.3-2b-instruct"

tokenizer = transformers.AutoTokenizer.from_pretrained(GRANITE_MODEL_STR)

# Every top-level key of the request other than "messages" is forwarded to the
# chat template as a keyword argument.
input_kwargs = {
    key: value for key, value in input_json.items() if key != "messages"
}

# Render the prompt as text (no tokenization), ending with a generation prompt.
transformers_str = tokenizer.apply_chat_template(
    input_json["messages"],
    **input_kwargs,
    tokenize=False,
    add_generation_prompt=True,
)

print(transformers_str)

In [None]:
from granite_io.io.granite_3_3.granite_3_3 import Granite3Point3InputOutputProcessor

# Parse the raw JSON text directly into a ChatCompletionInputs object (no
# intermediate `json.loads` step this time).
inputs = ChatCompletionInputs.model_validate_json(input_json_str)
# Render those inputs to a prompt string with an I/O processor constructed
# without a backend -- presumably for comparison with the chat-template output
# printed in the previous cell.
io_proc_str = Granite3Point3InputOutputProcessor().inputs_to_string(inputs)
print(io_proc_str)

In [None]:
# Load a model onto the best available device and wrap it in an I/O processor.

import torch

from granite_io.io.granite_3_3.granite_3_3 import Granite3Point3InputOutputProcessor
from granite_io.backend.transformers import TransformersBackend

# Device preference order: CUDA first, then MPS, otherwise fall back to CPU.
device_name = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

if device_name == "cpu":
    # CPU mode; prevent thrashing
    torch.set_num_threads(4)

# Backend configured with the model name and the chosen device; environment
# variable overrides are disabled.
backend = TransformersBackend(
    aconfig.Config(
        {"model_name": GRANITE_MODEL_STR, "device": device_name},
        override_env_vars=False,
    ),
)
io_processor = Granite3Point3InputOutputProcessor(backend=backend)

In [None]:
# Run a chat completion for `inputs` through the backend-backed I/O processor.
result = io_processor.create_chat_completion(inputs)
# Print the content of the next message from the first returned result.
print(result.results[0].next_message.content)