# Together Supervised Fine-Tuning

This recipe allows TensorZero users to fine-tune Together models using their own data.
Since TensorZero automatically logs all inferences and feedback, it is straightforward to fine-tune a model using your own data and any prompt you want.


To get started:

- Set the `TENSORZERO_CLICKHOUSE_URL` environment variable. For example: `TENSORZERO_CLICKHOUSE_URL="http://chuser:chpassword@localhost:8123/tensorzero"`
- Set the `TOGETHER_API_KEY` environment variable.
- Update the following parameters:


In [2]:
CONFIG_PATH = "../../../../examples/data-extraction-ner/config/tensorzero.toml"

FUNCTION_NAME = "extract_entities"

METRIC_NAME = "exact_match"

# The name of the variant to use to grab the templates used for fine-tuning
TEMPLATE_VARIANT_NAME = "gpt_4o_mini"  # It's OK that this variant uses a different model than the one we're fine-tuning

# If the metric is a float metric, you can set the threshold to filter the data
FLOAT_METRIC_THRESHOLD = 0.5

# Fraction of the data to use for validation
VAL_FRACTION = 0.2

# Maximum number of samples to use for fine-tuning
MAX_SAMPLES = 100_000

# The name of the model to fine-tune (supported models: https://docs.together.ai/docs/fine-tuning-models)
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference"

In [3]:
import json
import os
import subprocess
import tempfile
import time
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import requests
import toml
from clickhouse_connect import get_client
from IPython.display import clear_output
from minijinja import Environment

Load the TensorZero configuration file.


In [4]:
config_path = Path(CONFIG_PATH)

assert config_path.exists(), f"{CONFIG_PATH} does not exist"
assert config_path.is_file(), f"{CONFIG_PATH} is not a file"

with config_path.open("r") as f:
    config = toml.load(f)

Retrieve the metric configuration.


In [5]:
assert "metrics" in config, "No `[metrics]` section found in config"
assert METRIC_NAME in config["metrics"], (
    f"No metric named `{METRIC_NAME}` found in config"
)

metric = config["metrics"][METRIC_NAME]

metric

{'type': 'boolean', 'level': 'inference', 'optimize': 'max'}

Retrieve the configuration for the variant with the templates we'll use for fine-tuning.


In [6]:
assert "functions" in config, "No `[functions]` section found in config"
assert FUNCTION_NAME in config["functions"], (
    f"No function named `{FUNCTION_NAME}` found in config"
)
assert "variants" in config["functions"][FUNCTION_NAME], (
    f"No variants section found for function `{FUNCTION_NAME}`"
)
assert TEMPLATE_VARIANT_NAME in config["functions"][FUNCTION_NAME]["variants"], (
    f"No variant named `{TEMPLATE_VARIANT_NAME}` found in function `{FUNCTION_NAME}`"
)

function_type = config["functions"][FUNCTION_NAME]["type"]
variant = config["functions"][FUNCTION_NAME]["variants"][TEMPLATE_VARIANT_NAME]

variant

{'type': 'chat_completion',
 'model': 'openai::gpt-4o-mini-2024-07-18',
 'system_template': 'functions/extract_entities/initial_prompt/system_template.minijinja',
 'weight': 0}

Retrieve the system, user, and assistant templates in the variant (if any), and initialize a minijinja environment with them.


In [7]:
templates = {}

if "assistant_template" in variant:
    assistant_template_path = config_path.parent / variant["assistant_template"]
    with assistant_template_path.open("r") as f:
        templates["assistant"] = f.read()

if "system_template" in variant:
    system_template_path = config_path.parent / variant["system_template"]
    with system_template_path.open("r") as f:
        templates["system"] = f.read()

if "user_template" in variant:
    user_template_path = config_path.parent / variant["user_template"]
    with user_template_path.open("r") as f:
        templates["user"] = f.read()

env = Environment(templates=templates)

Initialize the ClickHouse client.


In [None]:
assert "TENSORZERO_CLICKHOUSE_URL" in os.environ, (
    "TENSORZERO_CLICKHOUSE_URL environment variable not set"
)

clickhouse_client = get_client(dsn=os.environ["TENSORZERO_CLICKHOUSE_URL"])
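If the connection fails, it can help to see how the URL breaks down. clickhouse-connect reads the credentials, host, and port from the DSN and uses its path as the database name (here `tensorzero`). This sketch only parses the example URL from the setup list with `urllib.parse`; it does not contact ClickHouse.

```python
from urllib.parse import urlsplit

# Example URL from the setup instructions (not a real server)
example_url = "http://chuser:chpassword@localhost:8123/tensorzero"
parts = urlsplit(example_url)

# The path component carries the database name
database = parts.path.lstrip("/")

print(f"scheme={parts.scheme} host={parts.hostname} port={parts.port}")
print(f"user={parts.username} database={database}")
```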

Determine the ClickHouse table name for the function.


In [11]:
inference_table_name = {"chat": "ChatInference", "json": "JsonInference"}.get(
    function_type
)

if inference_table_name is None:
    raise ValueError(f"Unsupported function type: {function_type}")

Determine the ClickHouse table name for the metric.


In [12]:
feedback_table_name = {
    "float": "FloatMetricFeedback",
    "boolean": "BooleanMetricFeedback",
}.get(metric["type"])

if feedback_table_name is None:
    raise ValueError(f"Unsupported metric type: {metric['type']}")

Determine the correct join key to use for the metric on the inference table.


In [13]:
inference_join_key = {
    "episode": "episode_id",
    "inference": "id",
}.get(metric["level"])

if inference_join_key is None:
    raise ValueError(f"Unsupported metric level: {metric['level']}")
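The three lookups above can be bundled into one hypothetical helper, `resolve_query_targets`, that fails early with a clear message. It reuses only the mappings shown in the preceding cells; the example inputs match this notebook's metric and its (assumed) JSON function.

```python
from typing import Dict, Tuple

# Same mappings as in the cells above
INFERENCE_TABLES: Dict[str, str] = {"chat": "ChatInference", "json": "JsonInference"}
FEEDBACK_TABLES: Dict[str, str] = {
    "float": "FloatMetricFeedback",
    "boolean": "BooleanMetricFeedback",
}
JOIN_KEYS: Dict[str, str] = {"episode": "episode_id", "inference": "id"}


def resolve_query_targets(fn_type: str, metric_cfg: Dict[str, str]) -> Tuple[str, str, str]:
    """Return (inference table, feedback table, inference join key) or raise ValueError."""
    inference_table = INFERENCE_TABLES.get(fn_type)
    if inference_table is None:
        raise ValueError(f"Unsupported function type: {fn_type}")

    feedback_table = FEEDBACK_TABLES.get(metric_cfg.get("type"))
    if feedback_table is None:
        raise ValueError(f"Unsupported metric type: {metric_cfg.get('type')}")

    join_key = JOIN_KEYS.get(metric_cfg.get("level"))
    if join_key is None:
        raise ValueError(f"Unsupported metric level: {metric_cfg.get('level')}")

    return inference_table, feedback_table, join_key


print(resolve_query_targets("json", {"type": "boolean", "level": "inference", "optimize": "max"}))
# ('JsonInference', 'BooleanMetricFeedback', 'id')
```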

Query the inferences and feedback from ClickHouse.

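Only samples whose feedback meets the threshold are kept. For a float metric, the threshold is `FLOAT_METRIC_THRESHOLD`. For a boolean metric, a threshold of 0.5 selects `True` values when the metric is maximized and `False` values when it is minimized.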


In [14]:
assert "optimize" in metric, "Metric is missing the `optimize` field"

threshold = FLOAT_METRIC_THRESHOLD if metric["type"] == "float" else 0.5
comparison_operator = ">=" if metric["optimize"] == "max" else "<="

query = f"""
SELECT
    i.variant_name,
    i.input,
    i.output,
    f.value,
    i.episode_id
FROM
    {inference_table_name} i
JOIN
    (SELECT
        target_id,
        value,
        ROW_NUMBER() OVER (PARTITION BY target_id ORDER BY timestamp DESC) AS rn
    FROM
        {feedback_table_name}
    WHERE
        metric_name = %(metric_name)s
    ) f ON i.{inference_join_key} = f.target_id
WHERE
    i.function_name = %(function_name)s
    -- Keep only the most recent feedback per target, then apply the threshold
    AND f.rn = 1
    AND f.value {comparison_operator} %(threshold)s
LIMIT %(max_samples)s
"""

# `comparison_operator` is interpolated into the query string above; only values are bound as parameters
params = {
    "function_name": FUNCTION_NAME,
    "metric_name": METRIC_NAME,
    "threshold": threshold,
    "max_samples": MAX_SAMPLES,
}

df = clickhouse_client.query_df(query, params)

df.head()

|   | variant_name | input | output | value | episode_id |
|---|---|---|---|---|---|
| 0 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[\"A. Stewart\",\"Asif Muj...` | True | 01953e74-a43b-7572-863b-e5cdba48cfc5 |
| 1 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[],\"organization\":[],\"l...` | True | 01953e74-a273-74f0-8724-8a8fb98f89a5 |
| 2 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[\"Steve Stricker\"],\"org...` | True | 01953e74-f56d-7c81-930d-8933ff5855eb |
| 3 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[],\"organization\":[],\"l...` | True | 01953e74-4b09-72c0-9b50-af4785c70a53 |
| 4 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[],\"organization\":[],\"l...` | True | 01953e74-c6a1-7ca0-a79c-ab2e382ad629 |
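The `ROW_NUMBER() OVER (PARTITION BY target_id ORDER BY timestamp DESC)` subquery is meant to keep only the most recent feedback for each target, so a later correction overrides an earlier value. This pure-Python sketch mirrors that intent on made-up feedback rows: it keeps the latest value per `target_id`, then applies the same operator choice as the notebook (`>=` when maximizing, `<=` when minimizing).

```python
import operator
from typing import Dict, List, Tuple

# Hypothetical feedback rows: (target_id, timestamp, value)
feedback: List[Tuple[str, int, float]] = [
    ("a", 1, 1.0),
    ("a", 2, 0.0),  # a later, failing value for "a" overrides the earlier pass
    ("b", 1, 0.0),
    ("b", 3, 1.0),  # a later, passing value for "b" overrides the earlier fail
    ("c", 5, 1.0),
]


def select_targets(rows: List[Tuple[str, int, float]], threshold: float, optimize: str) -> List[str]:
    # Equivalent of ROW_NUMBER() ... ORDER BY timestamp DESC, keeping rn = 1
    latest: Dict[str, Tuple[int, float]] = {}
    for target_id, ts, value in rows:
        if target_id not in latest or ts > latest[target_id][0]:
            latest[target_id] = (ts, value)

    # Apply the threshold only to the latest value of each target
    compare = operator.ge if optimize == "max" else operator.le
    return sorted(t for t, (_, v) in latest.items() if compare(v, threshold))


print(select_targets(feedback, threshold=0.5, optimize="max"))  # ['b', 'c']
```

Filtering by value before ranking would instead let the older passing value for `"a"` slip through, which is why the threshold belongs after the `rn = 1` selection.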


Render the inputs using the templates.


In [15]:
def render_message(content: List[Dict[str, Any]], role: str) -> str:
    assert role in ["user", "assistant"], f"Invalid role: {role}"

    if len(content) != 1:
        raise ValueError(f"Message must have exactly one content block: {content}")

    if content[0]["type"] != "text":
        raise ValueError(f"Content block must be of type text: {content}")

    content = content[0]["value"]

    if isinstance(content, str):
        return content
    else:
        return env.render_template(role, **content)


def sample_to_conversational_messages(sample) -> Dict[str, List[Dict[str, str]]]:
    function_input = json.loads(sample["input"])

    rendered_messages = []

    # Add the system message to the rendered messages.
    # Use the system template if the variant defines one; otherwise pass through any raw system content.
    system = function_input.get("system", {})
    if "system" in templates:
        system_message = env.render_template("system", **system)
        rendered_messages.append({"role": "system", "content": system_message})
    elif len(system) > 0:
        rendered_messages.append({"role": "system", "content": system})

    # Add the input messages to the rendered messages
    for message in function_input["messages"]:
        rendered_message = render_message(message["content"], message["role"])
        rendered_messages.append({"role": message["role"], "content": rendered_message})

    # Add the output to the messages
    output = json.loads(sample["output"])

    if function_type == "chat":
        if len(output) != 1:
            raise ValueError(f"Output {output} must have exactly one content block.")

        if output[0]["type"] != "text":
            raise ValueError(f"Output {output} must be a text block.")

        rendered_messages.append({"role": "assistant", "content": output[0]["text"]})
    elif function_type == "json":
        rendered_messages.append({"role": "assistant", "content": output["raw"]})
    else:
        raise ValueError(f"Unsupported function type: {function_type}")

    return {"messages": rendered_messages}


df["conversational_messages"] = df.apply(sample_to_conversational_messages, axis=1)

df.head()

|   | variant_name | input | output | value | episode_id | conversational_messages |
|---|---|---|---|---|---|---|
| 0 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[\"A. Stewart\",\"Asif Muj...` | True | 01953e74-a43b-7572-863b-e5cdba48cfc5 | `{'messages': [{'role': 'system', 'content': 'Y...` |
| 1 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[],\"organization\":[],\"l...` | True | 01953e74-a273-74f0-8724-8a8fb98f89a5 | `{'messages': [{'role': 'system', 'content': 'Y...` |
| 2 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[\"Steve Stricker\"],\"org...` | True | 01953e74-f56d-7c81-930d-8933ff5855eb | `{'messages': [{'role': 'system', 'content': 'Y...` |
| 3 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[],\"organization\":[],\"l...` | True | 01953e74-4b09-72c0-9b50-af4785c70a53 | `{'messages': [{'role': 'system', 'content': 'Y...` |
| 4 | gpt_4o | `{"messages":[{"role":"user","content":[{"type"...` | `{"raw":"{\"person\":[],\"organization\":[],\"l...` | True | 01953e74-c6a1-7ca0-a79c-ab2e382ad629 | `{'messages': [{'role': 'system', 'content': 'Y...` |
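Before uploading, it can be worth checking that every rendered example has the expected shape: an optional leading system message, alternating user/assistant turns, and a final assistant message that serves as the training target. `validate_example` is a hypothetical helper written for this purpose, not part of TensorZero or Together; the strict alternation is an assumption that fits this single-turn data but may be too strict for other functions.

```python
from typing import Any, Dict, List


def validate_example(example: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one {"messages": [...]} example."""
    messages = example.get("messages", [])
    if not messages:
        return ["no messages"]

    problems = []
    for i, message in enumerate(messages):
        content = message.get("content")
        if not isinstance(content, str) or not content:
            problems.append(f"message {i} has empty or non-string content")

    roles = [message.get("role") for message in messages]
    if roles[0] == "system":
        roles = roles[1:]

    # Expect user, assistant, user, assistant, ... (assumed single-turn style data)
    expected = ["user", "assistant"] * (len(roles) // 2 + 1)
    if roles != expected[: len(roles)]:
        problems.append(f"unexpected role order: {roles}")
    if not roles or roles[-1] != "assistant":
        problems.append("last message is not from the assistant")
    return problems


# Hypothetical examples
good = {"messages": [
    {"role": "system", "content": "You extract entities."},
    {"role": "user", "content": "Steve Stricker won."},
    {"role": "assistant", "content": '{"person": ["Steve Stricker"]}'},
]}
bad = {"messages": [{"role": "user", "content": "Hi"}]}

print(validate_example(good))  # []
print(validate_example(bad))   # ['last message is not from the assistant']
```

In the notebook, something like `df["conversational_messages"].map(validate_example)` would flag problem rows; a raw, untemplated system payload that is still a dict would show up as non-string content.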


Split the data into training and validation sets for fine-tuning.


In [16]:
# Get unique episode_ids
unique_episode_ids = df["episode_id"].unique()

# Shuffle the unique episode_ids
np.random.seed(42)
np.random.shuffle(unique_episode_ids)

# Calculate the split index for episode_ids
split_index = int(len(unique_episode_ids) * (1 - VAL_FRACTION))

# Split the episode_ids into training and validation sets
train_episode_ids = unique_episode_ids[:split_index]
val_episode_ids = unique_episode_ids[split_index:]

# Create training and validation DataFrames based on episode_ids
train_df = df[df["episode_id"].isin(train_episode_ids)]
val_df = df[df["episode_id"].isin(val_episode_ids)]

print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Actual validation fraction: {len(val_df) / len(df):.2f}")

Training set size: 122
Validation set size: 31
Actual validation fraction: 0.20
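Splitting by `episode_id` rather than by row keeps all inferences from one episode on the same side, so related samples don't leak from training into validation. It is also why the actual validation fraction is reported: it only approximates `VAL_FRACTION` when episodes contain different numbers of rows. A self-contained sketch of the same idea on hypothetical rows, using the standard-library `random` module instead of NumPy:

```python
import random
from typing import List, Tuple

# Hypothetical (row_id, episode_id) pairs; episode "e1" has three rows
rows: List[Tuple[int, str]] = [
    (0, "e1"), (1, "e1"), (2, "e1"), (3, "e2"), (4, "e3"),
    (5, "e4"), (6, "e5"), (7, "e6"), (8, "e7"), (9, "e8"),
]


def split_by_episode(rows: List[Tuple[int, str]], val_fraction: float, seed: int = 42):
    # Shuffle unique episodes (sorted first so the result depends only on the seed)
    episodes = sorted({episode for _, episode in rows})
    random.Random(seed).shuffle(episodes)

    # Assign whole episodes to each side, then collect their rows
    split_index = int(len(episodes) * (1 - val_fraction))
    train_episodes = set(episodes[:split_index])
    train = [r for r in rows if r[1] in train_episodes]
    val = [r for r in rows if r[1] not in train_episodes]
    return train, val


train, val = split_by_episode(rows, val_fraction=0.2)
print(len(train), len(val))
```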


Write the training and validation examples to temporary JSONL files and upload them with the Together CLI.

In [18]:
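def upload_dataset_to_together(df: pd.DataFrame) -> str:
    # Write one conversation per line (JSONL) to a temporary file
    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
        for item in df["conversational_messages"]:
            json.dump(item, f)
            f.write("\n")
        dataset_path = f.name

    # Upload the closed file with the Together CLI, then remove the local copy
    try:
        result = subprocess.run(
            ["together", "files", "upload", dataset_path], capture_output=True
        )
    finally:
        os.remove(dataset_path)

    print("Stdout:")
    print(result.stdout.decode())
    print("Stderr:")
    print(result.stderr.decode())

    # Fail loudly instead of trying to parse empty or non-JSON output
    if result.returncode != 0:
        raise RuntimeError(f"`together files upload` failed with exit code {result.returncode}")

    together_result = json.loads(result.stdout)
    return together_result["id"]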


train_file_object_id = upload_dataset_to_together(train_df)
val_file_object_id = upload_dataset_to_together(val_df)

Stdout:
{
    "id": "file-73db59a6-066f-4691-9c12-fcb675611df3",
    "object": "file",
    "created_at": 1740513732,
    "purpose": "fine-tune",
    "filename": "tmpyonjgwqt.jsonl",
    "bytes": 0,
    "line_count": 0,
    "processed": false,
    "FileType": "jsonl"
}

Stderr:
Uploading file tmpyonjgwqt.jsonl: 100%|██████████| 102k/102k [00:00<00:00, 112kB/s]

Stdout:
{
    "id": "file-cffdce5f-4caf-4cc8-bf23-f912d562bcf4",
    "object": "file",
    "created_at": 1740513734,
    "purpose": "fine-tune",
    "filename": "tmpx09mwx15.jsonl",
    "bytes": 0,
    "line_count": 0,
    "processed": false,
    "FileType": "jsonl"
}

Stderr:
Uploading file tmpx09mwx15.jsonl: 100%|██████████| 25.9k/25.9k [00:00<00:00, 43.6kB/s]
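The uploaded files are JSON Lines: one JSON object per line, each with a `messages` list. A local round trip with a hypothetical `write_jsonl`/`read_jsonl` pair, using only the standard library, can confirm that a file parses line by line before it is sent to Together.

```python
import json
import os
import tempfile
from typing import Any, Dict, Iterable, List


def write_jsonl(items: Iterable[Dict[str, Any]], path: str) -> None:
    with open(path, "w", encoding="utf-8") as f:
        for item in items:
            # json.dumps escapes embedded newlines, so each item stays on one line
            f.write(json.dumps(item) + "\n")


def read_jsonl(path: str) -> List[Dict[str, Any]]:
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Hypothetical examples; the first contains a newline inside a message
examples = [
    {"messages": [{"role": "user", "content": "line one\nline two"},
                  {"role": "assistant", "content": "{}"}]},
    {"messages": [{"role": "user", "content": "Hi"},
                  {"role": "assistant", "content": "Hello"}]},
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.jsonl")
    write_jsonl(examples, path)
    with open(path, encoding="utf-8") as f:
        line_count = sum(1 for _ in f)
    round_tripped = read_jsonl(path)

print(line_count, round_tripped == examples)  # 2 True
```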



Launch the fine-tuning job.


In [19]:
url = "https://api.together.xyz/v1/fine-tunes"
print("MODEL: ", MODEL_NAME)
print("Train: ", train_file_object_id)
print("Val: ", val_file_object_id)

payload = {
    "training_file": train_file_object_id,
    "validation_file": val_file_object_id,
    "model": MODEL_NAME,
    "n_epochs": 1,
    "n_checkpoints": 1,
    "n_evals": 0,
    "batch_size": 16,
    "learning_rate": 0.00001,
    "lr_scheduler": {"lr_scheduler_args": {"min_lr_ratio": 0}},
    "warmup_ratio": 0,
    "max_grad_norm": 1,
    "weight_decay": 0,
    "train_on_inputs": "auto",
    "training_type": {"type": "Lora", "lora_r": 8, "lora_alpha": 32},
}
headers = {
    "accept": "application/json",
    "content-type": "application/json",
    "authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
}

response = requests.post(url, json=payload, headers=headers)
print("Response status: ", response.status_code)
print("Response body: ")
print(response.text)
response.raise_for_status()  # Stop here if the job was not created
fine_tune_id = response.json()["id"]

MODEL:  meta-llama/Meta-Llama-3.1-8B-Instruct-Reference
Train:  file-73db59a6-066f-4691-9c12-fcb675611df3
Val:  file-cffdce5f-4caf-4cc8-bf23-f912d562bcf4
Response status:  200
Response body: 
{"training_file":"file-73db59a6-066f-4691-9c12-fcb675611df3","validation_file":"file-cffdce5f-4caf-4cc8-bf23-f912d562bcf4","model":"meta-llama/Meta-Llama-3.1-8B-Instruct-Reference","suffix":"","n_epochs":1,"n_evals":0,"n_checkpoints":1,"batch_size":16,"training_type":{"type":"Lora","lora_r":8,"lora_alpha":32,"lora_trainable_modules":"all-linear"},"staring_epoch":0,"training_offset":0,"random_seed":"null","enable_checkpoints":false,"max_steps":-1,"save_steps":0,"eval_steps":0,"warmup_steps":0,"warmup_ratio":0,"validation_split_ratio":0,"learning_rate":0.00001,"lr_scheduler":{"lr_scheduler_type":"linear","lr_scheduler_args":{}},"gpus_per_node":8,"per_device_batch_size":0,"per_device_eval_batch_size":0,"gradient_accumulation_steps":1,"train_on_inputs":"auto","max_grad_norm":1,"weight_decay":0,"wandb_
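A back-of-the-envelope check on these hyperparameters: with the 122 training examples above, `batch_size = 16`, and `n_epochs = 1`, the job runs about ceil(122 / 16) = 8 optimizer steps. This assumes one example per batch slot; Together may pack or truncate sequences, so the real count can differ. For LoRA, the low-rank update is commonly scaled by `lora_alpha / lora_r`, which is 32 / 8 = 4 here; whether Together applies exactly this convention is an assumption.

```python
import math


def estimate_steps(num_examples: int, batch_size: int, n_epochs: int) -> int:
    # Assumes no sequence packing and that a partial last batch still counts as a step
    return math.ceil(num_examples / batch_size) * n_epochs


def lora_scaling(lora_alpha: float, lora_r: int) -> float:
    # Conventional LoRA scaling factor applied to the low-rank update
    return lora_alpha / lora_r


print(estimate_steps(122, 16, 1))  # 8
print(lora_scaling(32, 8))         # 4.0
```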

Wait for the fine-tuning job to complete.

This cell polls the job every 10 seconds until its status is `completed`, `failed`, or `cancelled`, so it can take a while to run.


In [20]:
while True:
    clear_output(wait=True)

    try:
        job_status = requests.get(
            f"https://api.together.xyz/v1/fine-tunes/{fine_tune_id}",
            headers={
                "accept": "application/json",
                "authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
            },
        ).json()
        pprint(job_status)
        print("Status: ", job_status["status"])
        if job_status["status"] in ("completed", "failed", "cancelled"):
            break
    except Exception as e:
        print(f"Error: {e}")

    time.sleep(10)
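The loop above polls until the job reaches a terminal state, with no upper bound. For unattended runs, a deadline may be preferable. `wait_for_job` is a hypothetical helper: it takes any zero-argument function returning the job's status dict (in the notebook, a wrapper around the `requests.get` call above) and raises `TimeoutError` if the job doesn't finish in time. It checks the same three terminal statuses as the loop above; unlike that loop, errors raised by the fetch function propagate instead of being printed and retried.

```python
import time
from typing import Any, Callable, Dict

TERMINAL_STATUSES = ("completed", "failed", "cancelled")


def wait_for_job(
    fetch_status: Callable[[], Dict[str, Any]],
    poll_interval: float = 10.0,
    timeout: float = 3600.0,
) -> Dict[str, Any]:
    deadline = time.monotonic() + timeout
    while True:
        status = fetch_status()
        if status.get("status") in TERMINAL_STATUSES:
            return status
        # Give up if another poll would overrun the deadline
        if time.monotonic() + poll_interval > deadline:
            raise TimeoutError(f"Job still '{status.get('status')}' after {timeout}s")
        time.sleep(poll_interval)


# Usage with a fake job that finishes on the third poll
responses = iter([{"status": "queued"}, {"status": "running"}, {"status": "completed"}])
final = wait_for_job(lambda: next(responses), poll_interval=0.01, timeout=5)
print(final["status"])  # completed
```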

{'TrainingFileNumLines': 0,
 'TrainingFileSize': 102170,
 'UsedModelName': '',
 'adapter_output_path': 's3://together-dev/finetune/65ef1c2c93431c0b7c097833/virajmehta/Meta-Llama-3.1-8B-Instruct-Reference-7b84fdf5/ft-c249848d-c649_adapter',
 'batch_size': 16,
 'checkpoints': [],
 'checkspoint_path': '',
 'continued_checkpoint_url': '',
 'created_at': '2025-02-25T20:02:20.047Z',
 'enable_checkpoints': False,
 'epochs_completed': 1,
 'eval_price': 0,
 'eval_steps': 0,
 'eval_token_count': 5417,
 'evals_completed': 0,
 'evals_paid_for': 0,
 'events': [{'adapter_path': '',
             'byoa_model_name': '',
             'checkpoint_path': '',
             'created_at': '2025-02-25T20:02:20.047Z',
             'eval_token_count': 0,
             'hash': '',
             'level': '',
             'message': 'Fine tune request created',
             'model_path': '',
             'object': 'fine-tune-event',
             'param_count': 0,
             'step': 0,
             'token_count': 0,


Once the fine-tuning job is complete, you can add the fine-tuned model to your config file.


In [21]:
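assert job_status["status"] == "completed", (
    f"Fine-tuning job ended with status {job_status['status']}"
)
fine_tuned_model = job_status["model_output_name"]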
model_config = {
    "models": {
        fine_tuned_model: {
            "routing": ["together"],
            "providers": {
                "together": {"type": "together", "model_name": fine_tuned_model}
            },
        }
    }
}

print(toml.dumps(model_config))

[models."virajmehta/Meta-Llama-3.1-8B-Instruct-Reference-7b84fdf5"]
routing = [ "together",]

[models."virajmehta/Meta-Llama-3.1-8B-Instruct-Reference-7b84fdf5".providers.together]
type = "together"
model_name = "virajmehta/Meta-Llama-3.1-8B-Instruct-Reference-7b84fdf5"



Finally, add a new variant to your function to use the fine-tuned model.


In [23]:
variant_config = {
    "type": "chat_completion",
    "weight": 0,
    "model": fine_tuned_model,
}

system_template = variant.get("system_template")
if system_template:
    variant_config["system_template"] = system_template

user_template = variant.get("user_template")
if user_template:
    variant_config["user_template"] = user_template

assistant_template = variant.get("assistant_template")
if assistant_template:
    variant_config["assistant_template"] = assistant_template

full_variant_config = {
    "functions": {FUNCTION_NAME: {"variants": {fine_tuned_model: variant_config}}}
}

print(toml.dumps(full_variant_config))

[functions.extract_entities.variants."virajmehta/Meta-Llama-3.1-8B-Instruct-Reference-7b84fdf5"]
type = "chat_completion"
weight = 0
model = "virajmehta/Meta-Llama-3.1-8B-Instruct-Reference-7b84fdf5"
system_template = "functions/extract_entities/initial_prompt/system_template.minijinja"
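Rather than pasting the printed TOML by hand, the generated snippets can be merged into an existing config dictionary. `deep_merge` is a hypothetical helper that merges nested dicts recursively, so adding a new model or variant doesn't overwrite sibling entries; the example config and model names are made up. Writing the merged result back to disk (for example with `toml.dump`) would drop any comments in the original file, so reviewing the diff is advisable.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], extra: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with `extra` merged into `base`; `extra` wins on conflicts."""
    merged = copy.deepcopy(base)
    for key, value in extra.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical existing config and generated additions
existing = {
    "models": {"my-existing-model": {"routing": ["openai"]}},
    "functions": {"extract_entities": {"type": "json", "variants": {"gpt_4o_mini": {"weight": 0}}}},
}
addition = {
    "models": {"my-fine-tuned-model": {"routing": ["together"]}},
    "functions": {"extract_entities": {"variants": {"my-fine-tuned-model": {"weight": 0}}}},
}

merged = deep_merge(existing, addition)
print(sorted(merged["models"]))  # ['my-existing-model', 'my-fine-tuned-model']
print(sorted(merged["functions"]["extract_entities"]["variants"]))  # ['gpt_4o_mini', 'my-fine-tuned-model']
```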



You're all set!

You can change the weight to enable a gradual rollout of the new model.
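To reason about a gradual rollout, it helps to see how weights translate into traffic. Assuming the gateway samples among variants in proportion to their weights (check the TensorZero documentation for the exact semantics, including how zero-weight variants are treated), each variant's expected share is its weight divided by the sum of all weights. The variant names and weights below are hypothetical.

```python
from typing import Dict


def traffic_shares(weights: Dict[str, float]) -> Dict[str, float]:
    # Expected fraction of requests per variant under weight-proportional sampling
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("At least one variant needs a positive weight")
    return {name: weight / total for name, weight in weights.items()}


# Hypothetical rollout: send roughly 10% of traffic to the fine-tuned variant
shares = traffic_shares({"gpt_4o_mini": 9, "fine_tuned_llama": 1})
for name, share in shares.items():
    print(f"{name}: {share:.0%}")  # gpt_4o_mini: 90%, then fine_tuned_llama: 10%
```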

You might also add other parameters (e.g. `max_tokens`, `temperature`) to the variant section in the config file.
