# Device Actions

Generates datasets for performing actions on devices in a synthetic home. This will generate a list
of text / voice commands that you can perform in a home. These are not labeled with the outcome
which is generated in a later step.

In [3]:
import openai
import google.generativeai as genai

from home_assistant_datasets import secrets
from home_assistant_datasets.secrets import get_secret
from home_assistant_datasets import model_client

secrets.DEFAULT_SECRETS_FILE = "../secrets.yaml"

# MODEL_ID = "gpt-3.5-turbo-0125"
# openai = openai.OpenAI(api_key=secrets.get_secret("openai_api_key"))
# model = model_client.ModelClient(openai, MODEL_ID)

# Gemini flash is higher quality and cheaper model than the GPT alternatives.
MODEL_ID = "gemini-1.5-flash"
genai.configure(api_key=secrets.get_secret("google_api_key"))
model = model_client.GoogleClient(MODEL_ID)

# Generate few-shot examples

Read the seed data used as a few-shot example.

In [69]:
import pathlib
import yaml
from synthetic_home import device_types


DATASET_DIR = pathlib.Path("../datasets/")
DEVICES_DIR = DATASET_DIR / "devices-v3"
SEEDS_DIR = pathlib.Path("./seeds")
SEED_DEVICE_ACTIONS_FILE = SEEDS_DIR / "device-actions.yaml"
SEED_DEVICE_ACTIONS_CAPABILITIES_FILE = SEEDS_DIR / "device-actions-capabilities.yaml"

with open(SEED_DEVICE_ACTIONS_FILE) as f:
    seed_device_actions = list(yaml.load_all(f.read(), Loader=yaml.Loader))

# This is a fixed list of capabilities that any particular synthetic home device type support
with open(SEED_DEVICE_ACTIONS_CAPABILITIES_FILE) as f:
    capabilities = {
        cap["device_type"]: cap["actions"]
        for cap in yaml.load(f.read(), Loader=yaml.Loader)
    }

seed_devices_prompt = "".join(yaml.dump(content, sort_keys=False, explicit_start=True) for content in seed_device_actions)
print(seed_devices_prompt)

registry = device_types.load_device_type_registry()
# Find any devices missing explicit action capabilities definitions
missing_devices = [
    {"device_type": dt, "actions": []}
    for dt in registry.device_types
    if dt not in capabilities
]
if missing_devices:
    print(yaml.dump(missing_devices, sort_keys=False))


---
home: mountain-cabin-us
device:
  name: Kitchen Overhead Light
  area: Kitchen
  device_type: light
  device_info:
    model: Smart LED Bulb
    manufacturer: Philips
    sw_version: 1.2.3
capabilities:
- Turn on
- Turn off
---
actions:
- action: Turn on
  sentences:
  - Please turn on the kitchen overhead light
  - Turn on the kitchen light
  - Kitchen light on
- action: Turn off
  sentences:
  - Please turn off the kitchen overhead light
  - Turn off the kitchen light
  - Kitchen light off



In [70]:
SUMMARY_PROMPT = f"""
You are an expert Smart Home agent who can evaluate the performance of a smart
home, and perform useful actions on behalf of a user.

A device in Home Assistant represents a physical or virtual object, represented
by different entities. A device has attributes for its configuration and state,
for example a thermostat may have a mode attribute, or target or current temperature
attributes.

You generate a simple evaluation dataset for home data. The input dataset
contains the home, description information like location, areas, and devices.
The output data are actions a user may ask to take on a devie.

This is the input yaml document and the output actions yaml document:

{seed_devices_prompt}

Generate a few sentences to control the device. Answer in yaml plain text and do not answer with markdown.
"""

In [71]:
import itertools
import random
from tqdm.auto import tqdm
import shutil
import slugify

homes = []
for path in DEVICES_DIR.glob("*.yaml"):
    with path.open("r") as f:
        content = f.read()
    home_id = path.name.split(".")[0]  # Strip the .yaml extension
    home_data = yaml.load(content, Loader=yaml.Loader)
    homes.append((home_id, home_data))

tasks = []
no_actions = 0
task_types = {}
for home_id, home in homes:
    home_template = {
            "home": home_id,
            "location": home["location"],
            "type": home["type"],
    }
    for area, devices in home["devices"].items():
        for device in devices or []:
            device_type = device["device_type"]
            if not (device_caps := capabilities.get(device_type)):
                # No supported actions
                no_actions += 1
                continue
            task_types[device_type] = task_types.get(device_type, 0) + 1
            device_info = {
                    **home_template,
                    "device": {
                        **device,
                        "area": area,
                    },
                    "capabilities": device_caps,
            }
            tasks.append(device_info)
print((len(homes), len(tasks), no_actions))
print(yaml.dump(task_types))

(40, 480, 91)
exhaust-fan: 11
fan-oscilating: 2
garage-door: 6
heat-pump: 3
hvac: 30
light: 203
light-dimmable: 85
smart-blinds: 1
smart-lock: 5
smart-plug: 34
smart-speaker: 49
smart-sprinkler: 17
smart-tv: 17
switch: 8
vacuum: 3
water-valve: 6



In [72]:
random.shuffle(tasks)
print(yaml.dump(tasks[0], sort_keys=False, explicit_start=True))

---
home: home4-us
location: Coastal town in Florida
type: Beach house
device:
  name: Kids Bathroom Light
  device_type: light
  device_info:
    model: Smart LED Bulb
    manufacturer: Philips
    sw_version: 1.2.3
  area: Kids Bathroom
capabilities:
- Turn on
- Turn off



# Generate Output

In [75]:
import slugify

# Total number of records to generate
N_DATAPOINTS = -1

DEVICE_ACTIONS_OUTPUT_DIR = DATASET_DIR / "device-actions-v2"

# Wipe existing summaries
shutil.rmtree(DEVICE_ACTIONS_OUTPUT_DIR, ignore_errors=True)
DEVICE_ACTIONS_OUTPUT_DIR.mkdir(exist_ok=True)

random.shuffle(tasks)
if N_DATAPOINTS > 0 and len(tasks) > N_DATAPOINTS:
    tasks = tasks[:N_DATAPOINTS]

skipped = 0
with tqdm(total=len(tasks)) as pbar:
    for task in tasks:
        home_id = slugify.slugify(task["home"], separator="-")
        task_id = "_".join([
              slugify.slugify(task["device"]["area"], separator="-"),
              slugify.slugify(task["device"]["name"], separator="-"),
        ])
        home_dir = DEVICE_ACTIONS_OUTPUT_DIR / home_id
        if not home_dir.exists():
            home_dir.mkdir()
        with open(DEVICE_ACTIONS_OUTPUT_DIR / home_id / f"{task_id}.yaml", "w") as action_output:
            task_yaml = yaml.dump(task, sort_keys=False, explicit_start=True)
            response_obj = None
            for i in range(3):
                response = model.complete(SUMMARY_PROMPT, task_yaml)
                try:
                    response_obj = yaml.safe_load(response)
                except yaml.YAMLError as err:
                    print(err)
                    skipped += 1
                    continue
            if response_obj is not None:
                updated_task = task.copy()
                updated_task.update({"actions": response_obj})
                action_output.write(yaml.dump(updated_task, explicit_start=True, sort_keys=False))
            pbar.set_description(f"Skipped {skipped}")
            pbar.update(1)

  0%|          | 0/480 [00:00<?, ?it/s]

Skipped 0:  10%|█         | 48/480 [02:13<15:24,  2.14s/it]

expected a single document in the stream
  in "<unicode string>", line 2, column 1:
    actions:
    ^
but found another document
  in "<unicode string>", line 20, column 1:
    ---
    ^


Skipped 1:  14%|█▍        | 69/480 [03:14<19:18,  2.82s/it]

expected a single document in the stream
  in "<unicode string>", line 2, column 1:
    actions:
    ^
but found another document
  in "<unicode string>", line 13, column 1:
    ---
    ^


Skipped 2:  21%|██        | 99/480 [04:35<16:23,  2.58s/it]

expected a single document in the stream
  in "<unicode string>", line 2, column 1:
    actions:
    ^
but found another document
  in "<unicode string>", line 13, column 1:
    ---
    ^


Skipped 4:  33%|███▎      | 158/480 [07:12<13:14,  2.47s/it]

expected a single document in the stream
  in "<unicode string>", line 2, column 1:
    actions:
    ^
but found another document
  in "<unicode string>", line 13, column 1:
    ---
    ^


Skipped 4:  39%|███▉      | 186/480 [08:47<12:00,  2.45s/it]

expected a single document in the stream
  in "<unicode string>", line 2, column 1:
    actions:
    ^
but found another document
  in "<unicode string>", line 19, column 1:
    ---
    ^


Skipped 6:  45%|████▌     | 217/480 [10:06<11:23,  2.60s/it]

expected a single document in the stream
  in "<unicode string>", line 2, column 1:
    actions:
    ^
but found another document
  in "<unicode string>", line 22, column 1:
    ---
    ^


Skipped 6:  57%|█████▋    | 272/480 [12:30<07:35,  2.19s/it]

while parsing a block mapping
  in "<unicode string>", line 13, column 3:
    - action: Set position
      ^
expected <block end>, but found '<scalar>'
  in "<unicode string>", line 17, column 16:
      - [position] the shower
                   ^


Skipped 8:  66%|██████▋   | 318/480 [14:27<07:41,  2.85s/it]

expected a single document in the stream
  in "<unicode string>", line 2, column 1:
    actions:
    ^
but found another document
  in "<unicode string>", line 15, column 1:
    ---
    ^


Skipped 9:  80%|████████  | 386/480 [17:46<04:14,  2.70s/it]

expected a single document in the stream
  in "<unicode string>", line 2, column 1:
    actions:
    ^
but found another document
  in "<unicode string>", line 21, column 1:
    ---
    ^


Skipped 9: 100%|██████████| 480/480 [22:00<00:00,  2.75s/it]


## Device Actions Fixtures

Generate test fixtures from the device actions datasets. This will create the
home inventory to power the device actions data collections steps.

In [22]:
import dataclasses
import pathlib
from synthetic_home import synthetic_home
import shutil
import yaml

DATASET_DIR = pathlib.Path("../datasets/")
DEVICES_DIR = DATASET_DIR / "devices-v3"
DEVICE_ACTIONS_DIR = DATASET_DIR / "device-actions-v2"
DEVICE_ACTIONS_FIXTURES_DIR = DATASET_DIR / "device-actions-v2-fixtures"

shutil.rmtree(DEVICE_ACTIONS_FIXTURES_DIR, ignore_errors=True)
DEVICE_ACTIONS_FIXTURES_DIR.mkdir(exist_ok=True)

homes_count = 0
devices_count = 0
sentences_count = 0
device_type_sentences = {}

@dataclasses.dataclass
class DeviceTasks:
   device: str
   area: str | None
   device_id: str | None
   entity_id: str | None
   sentences: list[str]


for devices_file in DEVICES_DIR.glob("*.yaml"):
   home_id = devices_file.name.split(".")[0]
   home = synthetic_home.load_synthetic_home(devices_file)

   home_dir = DEVICE_ACTIONS_FIXTURES_DIR / home_id
   home_dir.mkdir(exist_ok=True)

   inventory = synthetic_home.build_inventory(home)

   fixtures = home_dir / "_fixtures.yaml"
   fixtures.write_text(inventory.to_yaml())

   homes_count += 1
   category_tasks = {}
   for actions_file in (DEVICE_ACTIONS_DIR / home_id).glob("*.yaml"):
      devices_count += 1
      device_actions = yaml.load(actions_file.read_text(), Loader=yaml.CSafeLoader)
      device = device_actions["device"]
      category = device["device_type"]
      if category not in category_tasks:
         category_tasks[category] = []

      device_id: str | None = None
      for inv_device in inventory.devices:
         if inv_device.name.lower() == device["name"].lower():
            device_id = inv_device.id
            break
      assert device_id
      entity_id: str | None = None
      for inv_entity in inventory.entities:
         if inv_entity.name.lower() == device["name"].lower():
            if inv_entity.device != device_id:
               raise ValueError(f"Wrong device: {device}")
            entity_id = inv_entity.id
            break
      assert entity_id
      if entity_id.startswith("sensor") or entity_id.startswith("binary_sensor"):
         raise ValueError(f"Matched entity that does not support control {device}")

      actions = device_actions["actions"]
      if "actions" in actions:
         actions = actions["actions"]
      for action_data in actions:
         sentences = action_data["sentences"]
         category_tasks[category].append(
            DeviceTasks(device=device["name"], area=device["area"], device_id=device_id, entity_id=entity_id, sentences=sentences)
         )
      sentences_count += len(sentences)
      device_type_sentences[category] = device_type_sentences.get(category, 0) + len(sentences)

   for category, tasks in category_tasks.items():
      data = {
         "category": category,
         "tests": [
            dataclasses.asdict(task)
            for task in tasks
         ]
      }
      category_file = home_dir / f"{category}.yaml"
      category_file.write_text(yaml.dump(data, sort_keys=False, explicit_start=True))

print(homes_count, devices_count, sentences_count)
print(yaml.dump(device_type_sentences, explicit_start=True, sort_keys=False))


40 479 1908
---
light: 619
smart-plug: 109
light-dimmable: 636
hvac: 103
exhaust-fan: 33
smart-speaker: 141
switch: 25
smart-sprinkler: 56
water-valve: 28
garage-door: 18
smart-tv: 95
heat-pump: 10
smart-lock: 15
smart-blinds: 3
vacuum: 11
fan-oscilating: 6



# Assist pipeline teacher

Run the assist pipeline data collection step to generate tool calls against the fixtures

```bash
$ source venv/bin/activate
(venv) $ home-assistant-datasets assist collect --dataset ./datasets/device-actions-v2-fixtures/ --model_output_dir=./datasets/device-actions-v2-collect/ --models=assistant
```

# Save successful Assistant results

This saves all the successful results from the assistant pipeline.

In [2]:
import pathlib
import yaml
import slugify
import shutil

DATASET_DIR = pathlib.Path("../datasets/")
FIXTTURES_DIR = DATASET_DIR / "device-actions-v2-fixtures"
COLLECT_DIR = DATASET_DIR / "device-actions-v2-collect"
ASSIST_TEACHER_DIR = COLLECT_DIR / "assistant"

TRAIN_DIR = COLLECT_DIR / "train"
shutil.rmtree(TRAIN_DIR, ignore_errors=True)
TRAIN_DIR.mkdir(exist_ok=True)

# These can not be used in the training set since they are used for eval
EVAL_HOME_IDS = {
  "dom1-pl",
  "home1-us",
  "home2-ru",
  "home5-cn",
  "home7-dk",
}


success = {}
total = {}
total_sentences = 0

for path in FIXTTURES_DIR.glob("**/*.yaml"):
    if path.name == "_fixtures.yaml":
        continue
    home_id = path.parent.name
    category = path.name.split(".")[0]

    if home_id in EVAL_HOME_IDS:
        continue

    fixture_record = yaml.load(path.read_text(), Loader=yaml.CSafeLoader)
    matched_tests = []

    file_prefix = "_".join([
        slugify.slugify(home_id, separator="_"),
        slugify.slugify(category, separator="_"),
    ])
    assist_outputs = ASSIST_TEACHER_DIR.glob(f"{file_prefix}*.yaml")
    for filename in assist_outputs:
        total[category] = total.get(category, 0) + 1
        record = yaml.load(filename.read_text(), Loader=yaml.Loader)
        input_text = record["task"]["input_text"]
        if record["response"].startswith("Sorry"):
            continue
        success[category] = success.get(category, 0) + 1

        context = record["context"]
        conversation_trace = context["conversation_trace"]
        if len(conversation_trace) < 2:
            continue
        if conversation_trace[1]["event_type"] != "tool_call":
            continue
        if not (tool_call := conversation_trace[1].get("data")):
            continue

        for record in fixture_record["tests"]:
            sentences = list(record["sentences"])
            for sentence in sentences:
                if sentence == input_text:
                    total_sentences += 1
                    matched_tests.append({
                        **(record.copy()),
                        "sentences": [sentence],
                        "function": {
                            "name": tool_call["intent_name"],
                            "arguments": tool_call["slots"],
                        }
                    })
                    break

    if not matched_tests:
        continue
    output_record = fixture_record.copy()
    output_record["tests"] = matched_tests
    if output_record["tests"]:
        (TRAIN_DIR / home_id).mkdir(exist_ok=True)
        out_file = TRAIN_DIR / home_id / f"{category}.yaml"
        out_file.write_text(yaml.dump(output_record, sort_keys=False, explicit_start=True))


In [3]:
print(f"Total sentences: {total_sentences}")
for category in total:
    s = success.get(category, 0)
    t = total[category]
    print(f"{category} - {s} - {100*(s / t):0.2f}% - {100*(s / total_sentences):0.2f}%")

Total sentences: 2332
smart-tv - 26 - 13.54% - 1.11%
smart-speaker - 26 - 3.83% - 1.11%
light-dimmable - 784 - 67.64% - 33.62%
water-valve - 13 - 20.31% - 0.56%
smart-plug - 140 - 64.22% - 6.00%
light - 1824 - 77.78% - 78.22%
exhaust-fan - 66 - 100.00% - 2.83%
switch - 28 - 63.64% - 1.20%
smart-lock - 16 - 53.33% - 0.69%
fan-oscilating - 11 - 91.67% - 0.47%
smart-sprinkler - 20 - 17.86% - 0.86%
hvac - 100 - 33.33% - 4.29%
garage-door - 0 - 0.00% - 0.00%
vacuum - 2 - 9.09% - 0.09%
heat-pump - 8 - 42.11% - 0.34%
smart-blinds - 4 - 66.67% - 0.17%


# Cloud LLM Teacher

Scrape cloud responses with:

```
$ home-assistant-datasets assist collect --dataset ./datasets/device-actions-v2-fixtures/ --model_output_dir=./datasets/device-actions-v2-collect/ --models=gemini-1.5-flash
```

Then convert into lower level system messages below.

In [1]:
import pathlib
import yaml
import shutil
import itertools
import math
import json
import tqdm
import random
from typing import Any
from home_assistant_datasets.tokenizer import conversation, chat_template

DATASET_DIR = pathlib.Path("../datasets/")
COLLECT_DIR = DATASET_DIR / "device-actions-v2-collect"
TEACHER_MODEL_DIR = COLLECT_DIR / "gemini-1.5-flash"


def create_conversation(record: dict[str, Any]) -> conversation.ConversationRecord:
    conversation_trace = record["context"]["conversation_trace"]

    input_detail = next(filter(lambda x: x["event_type"] == "async_process", conversation_trace), None)
    input_text = input_detail["data"]["text"]

    agent_detail = next(filter(lambda x: x["event_type"] == "agent_detail", conversation_trace), None)
    prompt = agent_detail["data"]["prompt"]
    prompt = "\n".join(prompt.split("\n")[1:])  # Strip "Current time is..."

    tool_call_trace = next(filter(lambda x: x["event_type"] == "tool_call", conversation_trace), None)
    tool_calls: dict[str, str] | None = None
    content: str | None = None
    if tool_call_trace:
        tool_calls = [{
            "name": tool_call_trace["data"]["tool_name"],
            "arguments": tool_call_trace["data"]["tool_args"],
        }]
    else:
        content = record["response"]
    message = {
        "instructions": prompt,
        "tools": agent_detail["data"]["tools"],
        "input": input_text,
        "output": content or "",
        "tool_calls": tool_calls or None,
    }
    return conversation.ConversationRecord.from_dict(message)

teacher_files = list(TEACHER_MODEL_DIR.glob("*.yaml"))
random.shuffle(teacher_files)

def chunk_into_n(lst: list[str], n: int) -> list[list[str]]:
  size = math.ceil(len(lst) / n)
  return list(
    map(lambda x: lst[x * size:x * size + size],
    list(range(n)))
  )

NUM_SHARDS = 10
shards = list(chunk_into_n(teacher_files, NUM_SHARDS))
train_filenames = itertools.chain.from_iterable(shards[:-1])
test_filenames = itertools.chain.from_iterable(shards[-1:])

TOKENIZER_DIR = pathlib.Path("../home_assistant_datasets/tokenizer")
TOKENIZER_CONFIG_JSON = "tokenizer_config.json"
LLAMA3_TOKENIZER = TOKENIZER_DIR / "llama3" / TOKENIZER_CONFIG_JSON

CONVERSATION = "assist-llm-function-calling"
CONVERSATION_DIR = COLLECT_DIR / CONVERSATION
MESSAGES = "assist-llm-function-calling-messages"
MESSAGES_DIR = COLLECT_DIR / MESSAGES
ASSIST_CHAT = "assist-llm-function-calling-llama3-chat"
ASSIST_CHAT_DIR = COLLECT_DIR / ASSIST_CHAT

shutil.rmtree(CONVERSATION_DIR, ignore_errors=True)
CONVERSATION_DIR.mkdir(exist_ok=True)
shutil.rmtree(MESSAGES_DIR, ignore_errors=True)
MESSAGES_DIR.mkdir(exist_ok=True)
shutil.rmtree(ASSIST_CHAT_DIR, ignore_errors=True)
ASSIST_CHAT_DIR.mkdir(exist_ok=True)


for split, filenames in [("train", list(train_filenames)), ("test", list(test_filenames))]:
    conversation_dir = CONVERSATION_DIR / split
    conversation_dir.mkdir(exist_ok=True)
    messages_dir = MESSAGES_DIR / split
    messages_dir.mkdir(exist_ok=True)
    assist_chat_dir = ASSIST_CHAT_DIR / split
    assist_chat_dir.mkdir(exist_ok=True)

    conversation_file = conversation_dir / "conversation.jsonl"
    messages_file = messages_dir / "messages.jsonl"
    assist_chat_file = assist_chat_dir / "chat.jsonl"

    with conversation_file.open("w") as conversation_fd, messages_file.open("w") as messages_fd, assist_chat_file.open("w") as assist_chat_fd:
        for filename in tqdm.tqdm(filenames, desc=split):
            record = yaml.load(filename.read_text(), Loader=yaml.Loader)
            conversation_record = create_conversation(record)

            conversation_fd.write(conversation_record.to_json())
            conversation_fd.write("\n")

            messages_fd.write(conversation_record.to_messages_jsonl())
            messages_fd.write("\n")

            text = chat_template.build_prompt(
                messages=conversation_record.to_messages(),
                tools=conversation_record.tools,
                add_generation_prompt=True,
                tokenizer_config=LLAMA3_TOKENIZER,
            )
            json.dump({"text": text}, fp=assist_chat_fd)
            assist_chat_fd.write("\n")


train: 100%|██████████| 2124/2124 [01:31<00:00, 23.32it/s]
test: 100%|██████████| 233/233 [00:09<00:00, 24.66it/s]


In [4]:
import datasets
from home_assistant_datasets import secrets

REPO_MAP = {
    "assist-llm-function-calling": CONVERSATION_DIR,
    "assist-llm-function-calling-messages": MESSAGES_DIR,
    "assist-llm-function-calling-llama3-chat": ASSIST_CHAT_DIR,
}

for repo_id, path in REPO_MAP.items():
    ds = datasets.load_dataset(str(path))
    ds.push_to_hub(f"allenporter/{repo_id}", token=secrets.get_secret("huggingface_token"))

Creating parquet from Arrow format: 100%|██████████| 3/3 [00:00<00:00, 91.05ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.01s/it]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 211.01ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
Generating train split: 2124 examples [00:00, 65670.30 examples/s]
Generating test split: 233 examples [00:00, 57898.74 examples/s]
Creating parquet from Arrow format: 100%|██████████| 3/3 [00:00<00:00, 147.55ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.11it/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 172.88ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]
Generating train split: 2124 examples [00:00, 36691.98 examples/s]
Generating test split: 233 examples [00:00, 38752.99 examples/s]
Creating parquet from Arrow format: 100%|██████████| 3/3 [00:00<00:00, 75.57ba/s]
Uploading the datas

In [5]:
from huggingface_hub import DatasetCardData, DatasetCard

readme = COLLECT_DIR / "README.md"

for repo_id in REPO_MAP:
    card = DatasetCard(content=readme.read_text())
    card.push_to_hub(f"allenporter/{repo_id}", token=secrets.get_secret("huggingface_token"))

# Verify Tokenizer

This is verifying we can apply the chat template to the saved messages.

In [1]:
from transformers import PreTrainedTokenizerFast
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit")

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
import datasets

train_ds = datasets.load_dataset("allenporter/assist-llm-function-calling-messages", split="train")


In [20]:
import json
from typing import Any

def formatting_prompts_func(example: dict[str, Any]):
    messages = json.loads(example["messages"])
    tools = json.loads(example["tools"])
    input_tokens = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tools=tools,
        # tokenize=False,
    )
    # Must add EOS_TOKEN, otherwise your generation will go on forever!
    text = input_tokens + tokenizer.eos_token
    return { "text" : text }

train_ds_update = train_ds.map(formatting_prompts_func)


Map: 100%|██████████| 2124/2124 [00:01<00:00, 1282.15 examples/s]


In [28]:
my_iter = iter(train_ds_update)
r = next(my_iter)
print(r["text"])

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
When controlling Home Assistant always call the intent tools. Use HassTurnOn to lock and HassTurnOff to unlock a lock. When controlling a device, prefer passing just name and domain. When controlling an area, prefer passing just area name and domain.
When a user asks to turn on all devices of a specific type, ask user to specify an area, unless there is only one device of that type.
This device is not able to start timers.
An overview of the areas and the devices in this smart home:
- names: Thermostat
  domain: climate
  state: 'off'
  areas: Master Bedroom
  attributes:
    current_temperature: '22'
- names: Guest House Thermostat
  domain: climate
  state: 'off'
  areas: Guest House
  attributes:
