## Speech Toolformer

Set `torch._dynamo.config.recompile_limit` to 64 (last line of the install cell below) to improve stability.

The setup below was optimized by the Unsloth team for stable training and inference on Colab T4 machines.

More details here: https://docs.unsloth.ai/models/gemma-3-how-to-run-and-fine-tune/gemma-3n-how-to-run-and-fine-tune#fixes-for-gemma-3n

In [None]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2
!pip install torchcodec

import torch; torch._dynamo.config.recompile_limit = 64;

In [None]:
from unsloth import FastModel


model, processor = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E4B-it",
    dtype = None,
    max_seq_length = 1024,
    load_in_4bit = True,
    full_finetuning = False,
)

==((====))==  Unsloth 2025.11.6: Fast Gemma3N patching. Transformers: 4.56.2.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!



Gemma-3n is an instruction-tuned model and expects chat-formatted input (with user, assistant and system roles). Let's check how the model sees chat input as plain text.

In [None]:
messages = [
    {"role": "system", "content": "You are a helpful assistant for ASR projects."},
    {"role": "user", "content": "Briefly explain CTC loss."},
]

chat_template_messages = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
).strip()  # strip() removes the trailing '\n' after the generation prompt

chat_template_messages

'<bos><start_of_turn>user\nYou are a helpful assistant for ASR projects.\n\nBriefly explain CTC loss.<end_of_turn>\n<start_of_turn>model'
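The printed string shows that there is no separate system turn: the system text is prepended to the first user turn, followed by a blank line, and the generation prompt opens an empty `model` turn. Below is a minimal pure-Python sketch that reproduces this layout for plain-string contents. It is a hypothetical approximation inferred from the output above, not the real Jinja template behind `processor.apply_chat_template`, which also handles multimodal content lists.

```python
def render_gemma_chat(messages, add_generation_prompt=True):
    """Hypothetical approximation of the Gemma-3n chat layout (string contents only).

    Assumptions inferred from the printed prompt: the system text is glued to the
    first user turn with a blank line, and the "assistant" role is rendered as "model".
    """
    parts = ["<bos>"]
    pending_system = ""
    for msg in messages:
        role, content = msg["role"], msg["content"]
        if role == "system":
            # No dedicated system turn: keep the text for the next user turn.
            pending_system += content + "\n\n"
            continue
        turn_role = "model" if role == "assistant" else role
        if turn_role == "user" and pending_system:
            content, pending_system = pending_system + content, ""
        parts.append(f"<start_of_turn>{turn_role}\n{content}<end_of_turn>\n")
    if add_generation_prompt:
        # Open an empty model turn for the reply to be generated into.
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a helpful assistant for ASR projects."},
    {"role": "user", "content": "Briefly explain CTC loss."},
]
print(repr(render_gemma_chat(messages).strip()))
```

After `.strip()` the sketch yields the same string as the processor output above for this two-message example.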

Below is a helper function for streaming inference.

In [None]:
from transformers import TextStreamer
from unsloth.chat_templates import get_chat_template


# Check out available chat templates and how they are applied
# https://github.com/unslothai/unsloth/blob/main/unsloth/chat_templates.py
tokenizer = get_chat_template(processor, chat_template="gemma3n")

# Helper function for inference
def streaming_inference(messages, max_new_tokens = 128):
    _ = model.generate(
        **tokenizer.apply_chat_template(
            messages,
            add_generation_prompt = True,
            tokenize = True,
            return_dict = True,
            return_tensors = "pt",
        ).to("cuda"),
        max_new_tokens = max_new_tokens,
        do_sample=False,
        streamer = TextStreamer(processor, skip_prompt = True),
    )

Here are some possible speech tool invocation pipelines:

- Audio → ASR → tool call

- Audio → transcript + tool call in one shot

- Audio → tool call directly (no explicit transcript)
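The three variants differ in how many model calls they make and whether an explicit transcript is available for inspection. A minimal sketch with hypothetical callables standing in for model calls (in this notebook they would wrap `model.generate`); the function names are illustrative only:

```python
from typing import Callable, Optional, TypedDict


class PipelineResult(TypedDict):
    transcript: Optional[str]  # None when the pipeline never produces one
    response: str              # tool call or plain-text answer


def cascade_pipeline(audio, transcribe: Callable, respond: Callable) -> PipelineResult:
    """Audio -> ASR -> tool call: two calls; ASR errors propagate to the second call."""
    transcript = transcribe(audio)
    return {"transcript": transcript, "response": respond(transcript)}


def joint_pipeline(audio, transcribe_and_respond: Callable) -> PipelineResult:
    """Audio -> transcript + tool call in one shot: one call returns both parts."""
    transcript, response = transcribe_and_respond(audio)
    return {"transcript": transcript, "response": response}


def direct_pipeline(audio, respond_to_audio: Callable) -> PipelineResult:
    """Audio -> tool call directly: one call, no transcript to inspect."""
    return {"transcript": None, "response": respond_to_audio(audio)}
```

The cascade is the easiest to debug, since a bad transcript can be caught before the second call, at the cost of an extra generation pass.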


Let's test how well the model handles each of these.
But first, let's check whether the model understands text inputs and generates tool calls in the correct format.

**Note**: it is important that the model remains generic and triggers a tool call only when the user's request actually requires one.
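To check this behaviour automatically rather than by eye, the model output can be parsed: a well-formed `<tool_call>` block yields a name and arguments, and anything else counts as a plain-text answer. A minimal standard-library sketch, assuming the XML template used in the system prompts of this notebook; `parse_tool_call` is a hypothetical helper name:

```python
import re
import xml.etree.ElementTree as ET
from typing import Optional

# Non-greedy match of the first complete <tool_call>...</tool_call> block.
TOOL_CALL_RE = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)


def parse_tool_call(text: str) -> Optional[dict]:
    """Return {"name": ..., "arguments": {...}}, or None for plain-text / malformed output."""
    match = TOOL_CALL_RE.search(text)
    if match is None:
        return None
    try:
        root = ET.fromstring(match.group(0))
    except ET.ParseError:
        return None  # e.g. an unescaped "&" or mismatched tags
    name = (root.findtext("name") or "").strip()
    args_el = root.find("arguments")
    arguments = {} if args_el is None else {
        child.tag: (child.text or "").strip() for child in args_el
    }
    return {"name": name, "arguments": arguments} if name else None
```

A truncated block (no closing `</tool_call>`) is treated as "no tool call", which is the conservative choice for triggering a real API.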

In [None]:
"""
Let's tune the system prompt and test model performance on text inputs.
I expect the model to output a <tool_call> block when I ask about the weather, and plain text when I ask about CTC.

What happens if, instead of putting the tool instructions in the system message,
we put them into a long user message?
"""

SYSTEM_PROMPT_V2 = """You are a tool-using assistant.
You MUST respond with exactly one XML tool call without any surrounding text if the tool usage is required (user is asking about the weather), otherwise respond to user's request in plain text.
Template: <tool_call><name>weather.get_forecast</name><arguments><city>CITY</city></arguments></tool_call>. If tool usage is not required respond shortly with text
"""

test_messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": SYSTEM_PROMPT_V2,
            }
        ],
    },
    {
        "role": "user",
        "content": [
            # {"type": "text", "text": "What's the current temperature in Voronezh?"}
            # {"type": "text", "text": "Is it rainy in Berlin now?"}
            # {"type": "text", "text": "Is it rainy in Astana?"}
            # {"type": "text", "text": "is it foggy in Melbourne right now?"}
            {"type": "text", "text": "I am trying to understand how CTC training in ASR works. Can you briefly explain that?"}
            # {"type": "text", "text": "What is going to happen to me in 6 months if I run 5 km every single day?"}
        ]
    }
]

streaming_inference(test_messages)

CTC (Connectionist Temporal Classification) training in ASR (Automatic Speech Recognition) is a method used to train neural networks for sequence labeling tasks, particularly speech recognition. It addresses the challenge of aligning audio frames with characters or words in the transcription. 

Here's a brief explanation:

Instead of requiring strict alignment between input and output, CTC allows for variable-length alignments. It works by considering all possible alignments of the input sequence to the output sequence.  The core idea is to predict a probability distribution over the output vocabulary for each time step in the audio.  Then, it sums the probabilities of all possible alignments, effectively
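The docstring in the cell above asks what changes when the tool instructions move from the system message into the user message. Since the rendered template earlier shows the system text being prepended to the first user turn anyway, the two variants may end up nearly identical as text. A hypothetical helper that performs the move explicitly for list-of-parts messages such as `test_messages`, without mutating its input:

```python
import copy


def fold_system_into_user(messages):
    """Move all system content parts to the front of the first user message.

    Works on messages whose "content" is a list of parts, e.g.
    [{"type": "text", "text": ...}]; the input list is not mutated.
    """
    system_parts, others = [], []
    for msg in messages:
        if msg["role"] == "system":
            system_parts.extend(copy.deepcopy(msg["content"]))
        else:
            others.append(copy.deepcopy(msg))
    for msg in others:
        if msg["role"] == "user":
            msg["content"] = system_parts + msg["content"]
            break
    else:
        if system_parts:
            # No user turn at all: keep the instructions as a standalone user turn.
            others.insert(0, {"role": "user", "content": system_parts})
    return others
```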


Load the dataset from Google Drive. Replace the path and column names to match your dataset.

Inspect one example to verify that the audio is loaded correctly and all fields are present.
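Before handing the file to `load_dataset`, it can help to sanity-check the TSV itself. A minimal standard-library sketch, assuming a header-less file with the three columns in the order passed as `column_names` below; `read_weather_tsv` is a hypothetical helper:

```python
import csv
from pathlib import Path

# Assumed column order, mirroring `column_names` in the load_dataset call.
EXPECTED_COLUMNS = ["user_text", "user_audio", "assistant_text"]


def read_weather_tsv(fileobj, audio_prefix):
    """Parse a header-less TSV and resolve audio paths against `audio_prefix`."""
    rows = []
    # QUOTE_NONE keeps the double quotes inside the JSON targets literal.
    reader = csv.reader(fileobj, delimiter="\t", quoting=csv.QUOTE_NONE)
    for line_no, fields in enumerate(reader, start=1):
        if not fields:
            continue  # skip blank lines
        if len(fields) != len(EXPECTED_COLUMNS):
            raise ValueError(
                f"line {line_no}: expected {len(EXPECTED_COLUMNS)} fields, got {len(fields)}"
            )
        row = dict(zip(EXPECTED_COLUMNS, fields))
        # Same join as add_audio_path_prefix; with pathlib "/", an absolute
        # path in the TSV would replace the prefix entirely.
        row["user_audio"] = str(Path(audio_prefix) / row["user_audio"])
        rows.append(row)
    return rows
```

Opened with `newline=""`, the returned rows can then be checked with `Path(row["user_audio"]).exists()` to catch missing audio files early.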

In [None]:
from datasets import load_dataset, Audio
from pathlib import Path

# change to your path
data_path = "/content/drive/MyDrive/DLS Speech/local_datasets/weather_dataset_with_tts_v2.tsv"
audio_paths_prefix = "/content/drive/MyDrive/DLS Speech/"

ds = load_dataset(
    "csv",
    data_files=data_path,
    split="train",
    column_names=["user_text", "user_audio", "assistant_text"],  # change to your columns
    delimiter="\t",
)


def add_audio_path_prefix(example):
    p = Path(example["user_audio"])
    example["user_audio"] = str(Path(audio_paths_prefix) / p)
    return example

ds = ds.map(add_audio_path_prefix)


ds = ds.map(lambda batch: {"user_audio_path": batch["user_audio"]}, batched=True)  # let's copy audio paths to another column
ds = ds.cast_column("user_audio", Audio(sampling_rate=16000, decode=True))  # and cast audio column into decoded audio

# 90/10 split
splits = ds.train_test_split(test_size=0.1, seed=3407, shuffle=True)
train_ds = splits["train"]
test_ds  = splits["test"]

train_ds[100]


{'user_text': 'Weather update for Phoenix right now, please.',
 'user_audio': <datasets.features._torchcodec.AudioDecoder at 0x7fd97d394a70>,
 'assistant_text': '{"tool_call": {"name": "weather.get_forecast", "arguments": {"city": "Phoenix"}}}',
 'user_audio_path': '/content/drive/MyDrive/DLS Speech/local_datasets/audio/000416.wav'}
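Note that the `assistant_text` target in this example is JSON (`{"tool_call": ...}`), while the system prompts in this notebook ask the model for the XML `<tool_call>` template. If the two formats should agree (for instance before fine-tuning on this data), the targets can be converted. A hypothetical converter, assuming every tool-call target has the shape shown above:

```python
import json
from xml.sax.saxutils import escape


def json_tool_call_to_xml(assistant_text: str) -> str:
    """Rewrite a JSON tool-call target into the XML template used in the prompts.

    Non-JSON or non-tool-call targets are returned unchanged (assumed to be
    plain-text answers). Argument keys are assumed to be valid XML tag names.
    """
    try:
        payload = json.loads(assistant_text)
    except json.JSONDecodeError:
        return assistant_text
    if not isinstance(payload, dict) or "tool_call" not in payload:
        return assistant_text
    call = payload["tool_call"]
    args_xml = "".join(
        f"<{key}>{escape(str(value))}</{key}>" for key, value in call["arguments"].items()
    )
    return (
        f"<tool_call><name>{escape(call['name'])}</name>"
        f"<arguments>{args_xml}</arguments></tool_call>"
    )
```

It could be applied column-wise with `Dataset.map`; which format to standardize on is a design choice, not something the notebook prescribes.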

In [None]:
from IPython.display import Audio as IPyAudio  # alias avoids shadowing datasets.Audio from the loading cell

IPyAudio(train_ds[100]["user_audio_path"], autoplay=False)

In [None]:
"""
Let's also test the model's speech recognition performance.
In this cell we prompt the model only to transcribe speech into text, without a tool call.
Where does the model struggle most (numbers, cities, dates, Russian vs. English)?
"""

for sample in test_ds.select(range(5)):
    user_audio = sample["user_audio_path"]
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant that transcribes speech into text accurately.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": user_audio},
                {"type": "text", "text": "Transcribe this audio"}
            ]
        }
    ]
    print(f"Ground truth transcript:\n{sample['user_text']}\n")
    print("Recognized speech:")
    streaming_inference(messages, max_new_tokens = 128)
    print("\n==================\n")

Ground truth transcript:
Is it cold in Granada right now?

Recognized speech:
Is it cold in Granada right now?<end_of_turn>


Ground truth transcript:
Give me the current weather conditions in Miami.

Recognized speech:
j
j
j
[... the line "j" repeats 64 times in total ...]



Ground truth transcript:
Is it foggy in Melbourne right now?

Recognized speech:
Is it foggy in Melbourne right now?<end_of_turn>


Ground truth transcript:
Is it foggy in Geneva right now?

Recognized speech:
It's foggy in Geneva right now.<end_of_turn>


Ground truth transcript:
Is it stormy in San Diego at the moment?

Recognized speech:
Is it stormy in San Diego at the moment?<end_of_turn>
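To answer the question in the docstring with numbers rather than by eyeballing five samples, word error rate (WER) against `user_text` is the usual metric. A minimal self-contained sketch: word-level Levenshtein distance divided by the number of reference words, after lowercasing and dropping punctuation (this normalization is an assumption; `<end_of_turn>` should be stripped from streamed output first, while the `inference` helper defined later returns clean text via `skip_special_tokens=True`):

```python
import re


def normalize_words(text: str) -> list:
    """Lowercase and replace punctuation (except apostrophes) with spaces."""
    return re.sub(r"[^\w\s']", " ", text.lower()).split()


def word_error_rate(reference: str, hypothesis: str) -> float:
    """(substitutions + deletions + insertions) / number of reference words."""
    ref, hyp = normalize_words(reference), normalize_words(hypothesis)
    if not ref:
        raise ValueError("reference transcript has no words")
    # prev[j]: edit distance between the reference words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,                           # deletion
                curr[j - 1] + 1,                       # insertion
                prev[j - 1] + (ref_word != hyp_word),  # substitution or match
            )
        prev = curr
    return prev[-1] / len(ref)
```

For example, the `j` output above scores a WER of 1.0 against its reference, and `It's foggy in Geneva right now.` scores 2/7.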




In [None]:
"""Maybe the model can do both ASR and tool-call in a single go?
Does combining ASR + tool calling in a single prompt degrade either performance?
"""

import json


SYSTEM_PROMPT_MULTITASK_V1 = """You are a tool-using assistant that can understand audio. You are given a user's request in audio format.
You MUST transcribe user's audio and respond with exactly one XML tool call if the tool usage is required (user is asking about the weather), otherwise respond to user's request in plain text.
Template: <tool_call><name>weather.get_forecast</name><arguments><city>CITY</city></arguments></tool_call>. If tool usage is not required respond shortly with text
"""


for sample in test_ds.select(range(5)):
    user_audio = sample["user_audio"]
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT_MULTITASK_V1
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": user_audio["array"]},
                {"type": "text", "text": "Transcribe this audio and respond to the user's request accordignly"}
            ]
        }
    ]
    print(f"Ground truth transcript:\n{sample['user_text']}\n")
    print("Model response:")
    streaming_inference(messages, max_new_tokens = 256)
    print("\n==================\n")

Ground truth transcript:
Is it cold in Granada right now?

Model response:
<tool_call><name>weather.get_forecast</name><arguments><city>Granada</city></arguments></tool_call><end_of_turn>


Ground truth transcript:
Give me the current weather conditions in Miami.

Model response:
<tool_call><name>weather.get_forecast</name><arguments><city>Miami</city></arguments></tool_call><end_of_turn>


Ground truth transcript:
Is it foggy in Melbourne right now?

Model response:
<tool_call><name>weather.get_forecast</name><arguments><city>Melbourne</city></arguments></tool_call><end_of_turn>


Ground truth transcript:
Is it foggy in Geneva right now?

Model response:
<tool_call><name>weather.get_forecast</name><arguments><city>Geneva</city></arguments></tool_call><end_of_turn>


Ground truth transcript:
Is it stormy in San Diego at the moment?

Model response:
<tool_call><name>weather.get_forecast</name><arguments><city>San Diego</city></arguments></tool_call><end_of_turn>
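Five examples is a small sample. A minimal sketch for scoring tool calls over more of `test_ds`, comparing the tool name and `<city>` argument in the model output with the JSON target in `assistant_text`; it assumes targets look like the example inspected earlier and treats non-JSON targets as plain-text answers (both assumptions, and the function names are hypothetical):

```python
import json
import re

# Extracts the tool name and the <city> argument from the XML template used in the prompts.
CALL_RE = re.compile(
    r"<tool_call>\s*<name>(.*?)</name>.*?<city>(.*?)</city>.*?</tool_call>", re.DOTALL
)


def tool_call_correct(model_output: str, assistant_text: str) -> bool:
    try:
        expected = json.loads(assistant_text)["tool_call"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # Assumption: a target that is not a JSON tool call means "no tool call expected".
        return CALL_RE.search(model_output) is None
    match = CALL_RE.search(model_output)
    if match is None:
        return False
    name, city = (group.strip() for group in match.groups())
    return (name == expected["name"]
            and city.casefold() == expected["arguments"]["city"].casefold())


def tool_call_accuracy(pairs) -> float:
    """pairs: iterable of (model_output, assistant_text)."""
    pairs = list(pairs)
    return sum(tool_call_correct(out, ref) for out, ref in pairs) / len(pairs)
```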




In [None]:
import numpy as np


def inference(messages: list, user_audio: np.ndarray | None = None, max_new_tokens: int = 128):
    # Build prompt text once; no tokenization here
    prompt_text = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    # Pack text + audio (audio must be waveform, float32, 16kHz)
    inputs = processor(
        text=[prompt_text],
        audio=[user_audio],
        return_tensors="pt",
        padding=False,  # <= I noticed that padding might break ASR, don't use it here
    ).to("cuda")

    gen_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False)
    out = model.generate(**inputs, **gen_kwargs)

    input_len = inputs["input_ids"].shape[1]
    generated_ids = out[0, input_len:]
    text = processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

    return text


def inference_text(messages: list, max_new_tokens=128):
    inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cuda")

    gen_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False)

    out = model.generate(**inputs, **gen_kwargs)
    input_len = inputs["input_ids"].shape[1]
    generated_ids = out[0, input_len:]
    return processor.tokenizer.decode(generated_ids, skip_special_tokens=True)
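The comment in `inference` says the audio must be a float32 waveform sampled at 16 kHz. Audio from `test_ds` already satisfies this thanks to `cast_column(..., Audio(sampling_rate=16000))`, but audio from other sources may not. A hypothetical helper that scales integer PCM to float32, downmixes to mono and resamples by linear interpolation; linear interpolation has no anti-aliasing filter, so for real downsampling a proper resampler such as `torchaudio.functional.resample` is the safer choice:

```python
import numpy as np

TARGET_SR = 16_000  # the rate the dataset column is cast to


def prepare_waveform(wave, sr: int) -> np.ndarray:
    """Return a mono float32 waveform in [-1, 1] sampled at TARGET_SR (hypothetical helper)."""
    wave = np.asarray(wave)
    if np.issubdtype(wave.dtype, np.integer):
        # Scale integer PCM (e.g. int16) into floating point around [-1, 1].
        scale = np.iinfo(wave.dtype).max
        wave = wave.astype(np.float32) / scale
    wave = wave.astype(np.float32)
    if wave.ndim == 2:
        # Assumes a (samples, channels) layout; average channels to mono.
        wave = wave.mean(axis=1)
    if sr != TARGET_SR:
        # Crude linear-interpolation resampling onto the target time grid.
        n_out = int(round(len(wave) * TARGET_SR / sr))
        src_t = np.arange(len(wave)) / sr
        dst_t = np.arange(n_out) / TARGET_SR
        wave = np.interp(dst_t, src_t, wave)
    return np.clip(wave, -1.0, 1.0).astype(np.float32)
```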

In [None]:
"""Finally, let's do 2 sequential model calls Audio -> ASR transcript; ASR transcipt -> tool-call
"""

for sample in test_ds.select(range(5)):
    user_audio = sample["user_audio"]["array"].astype("float32")
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant that transcribes speech into text accurately.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": user_audio},
                {"type": "text", "text": "Transcribe this audio"}
            ]
        }
    ]
    print(f"Ground truth transcript:\n{sample['user_text']}\n")

    # 1st run -- Transcribe audio into text
    recognized_text = inference(messages, user_audio, max_new_tokens=128)
    print(f"Recognized speech:\n{recognized_text}\n")

    # 2nd run -- Recognized text -> tool-call
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT_V2
                },
                {"type": "text", "text": recognized_text}
            ]
        }
    ]
    model_response = inference_text(messages)
    print(f"Model response:\n{model_response}")
    print("\n==================\n")

Ground truth transcript:
Is it cold in Granada right now?

Recognized speech:
is it cold in Granada right now?

Model response:
It's currently 15 degrees Celsius in Granada.



Ground truth transcript:
Give me the current weather conditions in Miami.

Recognized speech:
j

Model response:
Okay, I understand. I will respond with exactly one XML tool call if the user's request requires it, otherwise I will respond in plain text.



Ground truth transcript:
Is it foggy in Melbourne right now?

Recognized speech:
is it foggy in Melbourne right now?

Model response:
It is foggy in Melbourne right now.


Ground truth transcript:
Is it foggy in Geneva right now?

Recognized speech:
It's foggy in Geneva right now.

Model response:
It's foggy in Geneva right now.


Ground truth transcript:
Is it stormy in San Diego at the moment?

Recognized speech:
Is it stormy in San Diego at the moment?

Model response:
Is it stormy in San Diego at the moment?
<tool_call><name>weather.get_forecast</name><arg