In [None]:
# type: ignore

# Fireworks Supervised Fine-Tuning

This recipe allows TensorZero users to fine-tune open-source LLMs using their own data.
Since TensorZero automatically logs all inferences and feedback, it is straightforward to fine-tune a model using your own data and any prompt you want.
We follow the Fireworks [docs](https://docs.fireworks.ai/fine-tuning/fine-tuning-models) on fine-tuning a model.


To get started:

- Set the `TENSORZERO_CLICKHOUSE_URL` environment variable. For example: `TENSORZERO_CLICKHOUSE_URL="http://chuser:chpassword@localhost:8123/tensorzero"`
- You'll also need to [install](https://docs.fireworks.ai/tools-sdks/firectl/firectl) the CLI tool `firectl` on your machine and sign in with `firectl signin`. You can test that this all worked with `firectl whoami`.
- Update the following parameters:


In [8]:
import os

CLICKHOUSE_URL = os.getenv("TENSORZERO_CLICKHOUSE_URL")
CLICKHOUSE_URL = "http://chuser:chpassword@localhost:8123/tensorzero_ui_fixtures"
assert CLICKHOUSE_URL is not None, "TENSORZERO_CLICKHOUSE_URL is not set"

In [18]:
CONFIG_PATH = "../../../../examples/data-extraction-ner/config/tensorzero.toml"

FUNCTION_NAME = "extract_entities"

# You can also fine-tune on demonstrations sent to TensorZero 
METRIC_NAME = "jaccard_similarity"

# The name of the variant to use to grab the templates used for fine-tuning
TEMPLATE_VARIANT_NAME = "gpt_4o_mini"  # It's OK that this variant uses a different model than the one we're fine-tuning

# If the metric is a float metric, you can set the threshold to filter the data
FLOAT_METRIC_THRESHOLD = 0.5

# Number of epochs to train for
NUM_EPOCHS = 1

# Maximum number of samples to use for fine-tuning (for Fireworks, NUM_EPOCHS * MAX_SAMPLES should be <= 3,000,000)
MAX_SAMPLES = 100_000

# The name of the model to fine-tune (supported models: https://docs.fireworks.ai/fine-tuning/fine-tuning-models#supported-base-models)
MODEL_NAME = "accounts/fireworks/models/llama-v3p1-8b-instruct"

# At the time of writing, Fireworks does not support tool call content blocks in assistant messages. Or the tool role.
# We will drop these invalid messages from the dataset by default.
# You can set this to False to keep the invalid messages in the dataset.
DROP_INVALID_MESSAGES = True

In [27]:
import json
import os
import subprocess
import tempfile
import warnings
from pathlib import Path
from time import sleep
from typing import Any, Dict, List, Optional

import toml
from clickhouse_connect import get_client
from IPython.display import clear_output
from minijinja import Environment
from tensorzero.util import uuid7
from tensorzero import TensorZeroGateway, FloatMetricFilter, RenderedSample

Initialize the embedded TensorZero client


In [20]:
t0 = TensorZeroGateway.build_embedded(
    config_file=CONFIG_PATH,
    clickhouse_url=CLICKHOUSE_URL,
)

Load the TensorZero configuration file.


In [21]:
config_path = Path(CONFIG_PATH)

assert config_path.exists(), f"{CONFIG_PATH} does not exist"
assert config_path.is_file(), f"{CONFIG_PATH} is not a file"

with config_path.open("r") as f:
    config = toml.load(f)

Query inference data


In [25]:
filters = FloatMetricFilter(metric_name=METRIC_NAME, value=FLOAT_METRIC_THRESHOLD, comparison_operator=">")
stored_samples = t0.experimental_list_inferences(function_name=FUNCTION_NAME, 
                                                 filters=filters,
                                                 output_source="inference", # could also be "demonstration"
                                                 limit=MAX_SAMPLES)


Template the data using the variant we chose above.


In [26]:
rendered_samples = t0.experimental_render_samples(stored_samples=stored_samples, variants={FUNCTION_NAME: TEMPLATE_VARIANT_NAME})

In [33]:
rendered_samples[0].input.messages[0].role

'user'

Convert the rendered samples to the format Fireworks expects. This is handled automatically with our built-in `experimental_launch_optimization` method but we do it explicitly here.


In [43]:
def warning_message(role: str) -> str:
    return (
        f"Fireworks does not support multiple content blocks per message. "
        f"We have chosen to concatenate the text across all content blocks for the message with role '{role}'. "
        f"You may want to manually review this behavior."
    )


def render_message(message) -> Optional[List[Dict[str, Any]]]:
    role = message.role
    assert role in ["user", "assistant"], f"Invalid role: {role}"
    content: List[Dict[str, Any]] = []
    tool_calls: List[Dict[str, Any]] = []
    rendered_messages: List[Dict[str, Any]] = []

    for content_block in message.content:
        if content_block.type not in ["text", "raw_text"] and DROP_INVALID_MESSAGES:
            warnings.warn(
                f"Fireworks may not support content block type: {content_block['type']}, dropping example.",
                UserWarning,
            )
            return None
        if content_block.type == "text":
            parsed_content = content_block.text
            content.append({"type": "text", "text": parsed_content})
        elif content_block.type == "raw_text":
            content.append({"type": "text", "text": content_block.value})
        elif content_block.type == "thought":
            content.append(
                {"type": "text", "text": f"<think>{content_block.text}</think>"}
            )
        elif (
            content_block.type == "tool_call"
            and role == "assistant"
            and not DROP_INVALID_MESSAGES
        ):
            warnings.warn(
                "Fireworks may not support tool calls in assistant messages. Including it may cause the fine-tuning job to fail.",
                UserWarning,
            )
            tool_calls.append(
                {
                    "function": {
                        "arguments": json.dumps(content_block["arguments"]),
                        "name": content_block["name"],
                    },
                    "id": content_block["id"],
                    "type": "function",
                }
            )
        elif (
            content_block.type == "tool_result"
            and role == "user"
            and not DROP_INVALID_MESSAGES
        ):
            warnings.warn(
                "Fireworks may not support tool results in user messages. Including it may cause the fine-tuning job to fail.",
                UserWarning,
            )
            # Tool results get priority so that they follow the tool call in the conversation.
            # Any other "user" content will be appended in another message below.
            rendered_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": content_block["id"],
                    "content": content_block["result"],
                }
            )
        else:
            warnings.warn(
                f"We do not support content block type: {content_block['type']}, dropping example.",
                UserWarning,
            )
            return None

    if content or tool_calls:
        role_message: Dict[str, Any] = {"role": role}
        if content:
            if len(content) > 1:
                warnings.warn(warning_message(role), UserWarning)
            role_message["content"] = "\n".join([c["text"] for c in content])
        if tool_calls:
            role_message["tool_calls"] = tool_calls
        rendered_messages.append(role_message)

    return rendered_messages


def render_output(
    output
) -> Optional[Dict[str, Any]]:
    """
    Parses the assistant message from an observation using the provided function configuration.
    """
    content: List[Dict[str, Any]] = []
    tool_calls: List[Dict[str, Any]] = []

    for content_block in output:
        if content_block.type != "text" and DROP_INVALID_MESSAGES:
                warnings.warn(
                    f"Fireworks may not support content block type: {content_block['type']}, dropping example.",
                    UserWarning,
                )
                return None
        if content_block.type == "text":
            content.append({"type": "text", "text": content_block.text})
        elif content_block.type == "thought":
                content.append(
                    {"type": "text", "text": f"<think>{content_block.text}</think>"}
                )
        elif content_block.type == "tool_call" and not DROP_INVALID_MESSAGES:
                warnings.warn(
                    "Fireworks may not support tool calls in assistant messages. Including it may cause the fine-tuning job to fail.",
                    UserWarning,
                )
                tool_calls.append(
                    {
                        "function": {
                            "arguments": json.dumps(content_block["arguments"]),
                            "name": content_block["name"],
                        },
                        "id": content_block["id"],
                        "type": "function",
                    }
                )
        else:
            warnings.warn(
                f"We do not support content block type: {content_block.type}, dropping example.",
                UserWarning,
            )
            return None

    # Once we finish collecting all blocks, create one assistant message.
    output_message: Dict[str, Any] = {"role": "assistant"}
    if content:
        if len(content) > 1:
            warnings.warn(warning_message("assistant"), UserWarning)
        output_message["content"] = "\n".join([c["text"] for c in content])
    if tool_calls:
        output_message["tool_calls"] = tool_calls

    return output_message


def rendered_sample_to_fireworks(sample: RenderedSample) -> List[Dict[str, Any]]:
    function_input = sample.input

    rendered_messages = []

    # Add the system message to the rendered messages
    # If there is data passed in or a system template there must be a system message
    system = function_input.system
    if system:
        rendered_messages.append({"role": "system", "content": system})

    # Add the input messages to the rendered messages
    for message in function_input.messages:
        rendered_message = render_message(message)
        if rendered_message is None:
            # `render_message` will return None if the message contains an unknown or unsupported content block.
            # The entire example is dropped if this is the case.
            return None
        rendered_messages.extend(rendered_message)

    # Add the output to the messages
    rendered_output = render_output(sample.output)
    if rendered_output is None:
        # `render_output` will return None if the output contains an unknown or unsupported content block.
        # The entire example is dropped if this is the case.
        return None
    rendered_messages.append(rendered_output)

    return rendered_messages

In [45]:
fireworks_samples = []
for sample in rendered_samples:
    rendered_sample = rendered_sample_to_fireworks(sample)
    if rendered_sample is not None:
        fireworks_samples.append(rendered_sample)

print(f"Found {len(fireworks_samples)} samples to fine-tune on")

# We can now fine-tune on the samples

# We can now fine-tune on the samples

Found 69 samples to fine-tune on


We'll write the conversational messages to a temporary file for the Fireworks CLI


In [46]:
dataset_id = f"t0-{uuid7()}"

with tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl") as f:
    for sample in fireworks_samples:
        f.write((json.dumps(sample) + "\n").encode("utf-8"))

    dataset_path = f.name
    result = subprocess.run(
        ["firectl", "create", "dataset", dataset_id, dataset_path], capture_output=True
    )
print(result.stdout)

b''


In [47]:
result = subprocess.run(["firectl", "get", "dataset", dataset_id], capture_output=True)
print(result.stdout.decode("utf-8"))




In [48]:
def get_job_id(stdout: str) -> str:
    for line in stdout.splitlines():
        if line.strip().startswith("Name:"):
            return line.split("/")[-1].strip()
    raise ValueError("Job ID not found in output")

Now we start the fine-tuning job. This cell will block until the job is done.


In [49]:
command = [
    "firectl",
    "create",
    "sftj",
    "--display-name",
    f"tensorzero-ft-job-{dataset_id}",
    "--dataset",
    dataset_id,
    "--base-model",
    MODEL_NAME,
]

if NUM_EPOCHS is not None:
    command.append("--epochs")
    command.append(str(NUM_EPOCHS))

print("Command: ", " ".join(command))

result = subprocess.run(command, capture_output=True)

if result.returncode != 0:
    print(result.stderr.decode("utf-8"))
else:
    stdout = result.stdout.decode("utf-8")
    print(stdout)
    job_id = get_job_id(stdout)
    print(f"job_id: {job_id}")

Command:  firectl create sftj --display-name tensorzero-ft-job-t0-0198230b-0139-7941-a41a-9627d174d363 --dataset t0-0198230b-0139-7941-a41a-9627d174d363 --base-model accounts/fireworks/models/llama-v3p1-8b-instruct --epochs 1
2025/07/19 10:16:32 There are updates available.
	Current version: 1.4.3
	Latest version: 1.5.11

	To upgrade to the latest version, run
  	$ sudo firectl upgrade

	
2025/07/19 10:16:32 Failed to execute: unknown flag: --display-name



In [50]:
while True:
    clear_output(wait=True)

    try:
        command = ["firectl", "get", "sftj", job_id]
        result = subprocess.run(command, capture_output=True)
        stdout = result.stdout.decode("utf-8")
        print(stdout)
    except Exception as e:
        print(f"Error: {e}")

    if "State: JOB_STATE_FAILED" in stdout:
        raise ValueError("Fine-tuning job failed")

    if "State: JOB_STATE_COMPLETED" in stdout:
        break

    sleep(5)

Error: name 'job_id' is not defined


NameError: name 'stdout' is not defined

In [None]:
def get_model_id(stdout: str) -> str:
    for line in stdout.splitlines():
        if line.strip().startswith("Output Model:"):
            return line.split(":")[1].strip()
    raise ValueError("Model ID not found in output")


model_id = get_model_id(stdout)

assert model_id

model_id

Now that the model is done training, we need to [deploy](https://docs.fireworks.ai/fine-tuning/fine-tuning-models#deploying-and-using-a-model) it to Fireworks serverless inference. If you need high or guaranteed throughput you can also deploy the model to [reserved capacity](https://docs.fireworks.ai/deployments/reservations) or an on-demand [deployment](https://docs.fireworks.ai/guides/ondemand-deployments).


In [None]:
command = ["firectl", "deploy", model_id]
print(" ".join(command))
result = subprocess.run(command, capture_output=True)
if result.returncode != 0:
    print(result.stderr.decode("utf-8"))
else:
    stdout = result.stdout.decode("utf-8")
    print(stdout)

In [None]:
def get_model_identifier(model_id: str) -> str:
    command = ["firectl", "get", "model", model_id]
    result = subprocess.run(command, capture_output=True)
    stdout = result.stdout.decode("utf-8")
    for line in stdout.splitlines():
        if line.strip().startswith("Name:"):
            return line.split(":")[1].strip()
    raise ValueError("Model identifier not found in output")


model_identifier = get_model_identifier(model_id)

assert model_identifier

model_identifier

Once the fine-tuning job is complete, you can add the fine-tuned model to your config file.


In [None]:
model_config = {
    "models": {
        model_identifier: {
            "routing": ["fireworks"],
            "providers": {
                "fireworks": {"type": "fireworks", "model_name": model_identifier}
            },
        }
    }
}

print(toml.dumps(model_config))

Finally, add a new variant to your function to use the fine-tuned model.


In [None]:
variant_config = {
    "type": "chat_completion",
    "weight": 0,
    "model": model_identifier,
}

system_template = variant.get("system_template")
if system_template:
    variant_config["system_template"] = system_template

user_template = variant.get("user_template")
if user_template:
    variant_config["user_template"] = user_template

assistant_template = variant.get("assistant_template")
if assistant_template:
    variant_config["assistant_template"] = assistant_template

full_variant_config = {
    "functions": {FUNCTION_NAME: {"variants": {model_identifier: variant_config}}}
}

print(toml.dumps(full_variant_config))

You're all set!

You can change the weight to enable a gradual rollout of the new model.

You might also add other parameters (e.g. `max_tokens`, `temperature`) to the variant section in the config file.
