<a href="https://colab.research.google.com/github/vadim-vic/Foundation-ts/blob/main/sandbox/Token_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# Small base model
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Seed before growing the vocabulary so the run is repeatable.
# NOTE(review): this assumes the rows created for the new tokens are filled
# from torch's global RNG -- confirm for the installed transformers version.
torch.manual_seed(0)

# Add two new tokens, one per tool (the "toolkens"), and grow the embedding
# matrix so that each of them gets its own row.
new_tokens = ["<TOOL_CALC>", "<TOOL_WEATHER>"]
tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))

# Freeze base model parameters
for param in model.parameters():
    param.requires_grad = False

# Allow training of the toolken embeddings.
# `requires_grad` belongs to a whole tensor: indexing `weight[token_id]` builds
# a temporary tensor, and setting the flag on that temporary never reaches the
# parameter itself. The embedding matrix is therefore enabled as a whole; the
# restriction to the toolken rows has to be applied to the gradient, which the
# training loop further down does by zeroing every other row before each
# optimizer step.
embedding_layer = model.get_input_embeddings()
embedding_layer.weight.requires_grad = True


In [None]:
def tool_executor(tool_name, args=None):
    """Run the demo tool identified by ``tool_name`` and return its result as text.

    Parameters
    ----------
    tool_name : str
        ``"<TOOL_CALC>"`` or ``"<TOOL_WEATHER>"``. Any other value yields ``"?"``.
    args : dict, optional
        Tool arguments. ``"<TOOL_CALC>"`` reads ``args["a"]`` and ``args["b"]``
        (e.g. ``{"a": 2, "b": 2}``); ``"<TOOL_WEATHER>"`` reads ``args["city"]``.

    Returns
    -------
    str
        The sum of ``a`` and ``b`` rendered with ``str``, a canned weather
        sentence for ``city``, or ``"?"`` when the tool is unknown or one of
        its required arguments is missing.
    """
    # The signature advertises args as optional, so the default must not crash
    # on the first subscript.
    if args is None:
        args = {}

    try:
        if tool_name == "<TOOL_CALC>":
            return str(args["a"] + args["b"])
        if tool_name == "<TOOL_WEATHER>":
            # Dummy weather: the answer does not depend on any real data source.
            return f"The weather in {args['city']} is sunny, 25°C."
    except KeyError:
        # A required argument is absent: use the same fallback as an unknown tool.
        return "?"
    return "?"


In [None]:
# (prompt, target) pairs used for training. Every target starts with the tool
# token that should be emitted for the prompt, followed by the text expected
# after it.
train_examples = [
    # Arithmetic questions -> calculator toolken
    ("What is 2 + 2?", "<TOOL_CALC> 4"),
    ("Please add 3 and 5.", "<TOOL_CALC> 8"),
] + [
    # Weather questions -> weather toolken
    ("What’s the weather in Paris?", "<TOOL_WEATHER> The weather in Paris is sunny, 25°C."),
    ("Tell me the weather in London.", "<TOOL_WEATHER> The weather in London is sunny, 25°C."),
]


In [None]:
# --- Corrected Step 4: training loop (only toolken rows should update) ---
# torch is already imported in the first cell, so it is not imported again here.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # we will still train embeddings but keep other layers frozen

# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Make embedding matrix trainable (leaf Parameter)
embedding_layer = model.get_input_embeddings()
embedding_layer.weight.requires_grad = True

# Tool token ids (used to preserve only those rows' grads)
tool_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in new_tokens]
vocab_size, emb_dim = embedding_layer.weight.shape
print("Vocab size:", vocab_size, "Embedding dim:", emb_dim)

# Optimizer: optimize the whole embedding matrix; the gradient of every
# non-tool row is zeroed before each step. With a gradient that is always zero
# and no weight decay configured, Adam's update for those rows is zero as well,
# so they keep their current values.
optimizer = torch.optim.Adam([embedding_layer.weight], lr=1e-3)

# Rows whose gradient must be discarded (True => zero). The mask depends only
# on the vocabulary, so it is built once instead of once per training step.
mask = torch.ones(vocab_size, dtype=torch.bool, device=embedding_layer.weight.device)
mask[tool_token_ids] = False

# Tokenize every example once; the tensors are identical in every epoch.
training_batches = []
for prompt, expected in train_examples:
    # Tokenize prompt and expected separately so we can mask correctly
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)      # (1, Lp)
    expected_ids = tokenizer(expected, return_tensors="pt")["input_ids"].to(device)  # (1, Le)

    # Concatenate prompt + expected into one sequence
    input_ids = torch.cat([prompt_ids, expected_ids], dim=1)  # (1, Lp+Le)
    labels = input_ids.clone()
    prompt_len = prompt_ids.size(1)
    # Mask prompt tokens so loss is only on the expected part
    # (-100 is the label value excluded from the loss).
    labels[:, :prompt_len] = -100
    training_batches.append((input_ids, labels))

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    total_loss = 0.0  # sum of the per-example losses of this epoch
    for input_ids, labels in training_batches:
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()

        # --- IMPORTANT: zero gradients for all embedding rows except tool tokens ---
        if embedding_layer.weight.grad is not None:
            embedding_layer.weight.grad[mask] = 0.0

        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{num_epochs}  loss={total_loss:.4f}")


Vocab size: 50259 Embedding dim: 768
Epoch 1/50  loss=31.3559
Epoch 10/50  loss=14.7615
Epoch 20/50  loss=4.7273
Epoch 30/50  loss=3.1997
Epoch 40/50  loss=2.4941
Epoch 50/50  loss=1.9601


In [None]:
def run_with_tools(prompt):
    """Generate a continuation for ``prompt`` and substitute tool results into it.

    The raw decoded generation is printed first. If it contains
    ``<TOOL_CALC>`` (checked first) or ``<TOOL_WEATHER>``, the matching tool is
    run through ``tool_executor`` and every occurrence of that token in the
    decoded text is replaced by the tool's result.

    Parameters
    ----------
    prompt : str
        User request. Besides being fed to the model, it is also scanned for a
        few hard-coded phrases to pick the tool arguments (demo only).

    Returns
    -------
    str
        The decoded generation, with the tool token replaced when one was
        produced; otherwise the decoded generation unchanged.
    """
    # The training cell may have moved the model to the GPU; generate() needs
    # its inputs on the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_length=40,
        # Same value generate() falls back to on its own, passed explicitly so
        # it does not emit the "Setting `pad_token_id`" notice on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(output_ids[0])

    print("Raw model output:", decoded)

    # Intercept tool tokens
    if "<TOOL_CALC>" in decoded:
        # naive arg extraction (demo only)
        if "2 + 2" in prompt:
            args = {"a": 2, "b": 2}
        elif "3 and 5" in prompt:
            args = {"a": 3, "b": 5}
        else:
            args = {"a": 0, "b": 0}
        result = tool_executor("<TOOL_CALC>", args)
        return decoded.replace("<TOOL_CALC>", result)

    elif "<TOOL_WEATHER>" in decoded:
        if "Paris" in prompt:
            args = {"city": "Paris"}
        elif "London" in prompt:
            args = {"city": "London"}
        else:
            args = {"city": "Unknown"}
        result = tool_executor("<TOOL_WEATHER>", args)
        return decoded.replace("<TOOL_WEATHER>", result)

    return decoded


# 🔥 Quick demo: one calculator prompt, then one weather prompt
for demo_prompt in ("What is 2 + 2?", "Tell me the weather in Paris."):
    print(run_with_tools(demo_prompt))


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Raw model output: What is 2 + 2?<TOOL_CALC> 4 + 2?<TOOL_CALC> 5 + 2?<TOOL_CALC> 6 + 2?



















What is 2 + 2?4 4 + 2?4 5 + 2?4 6 + 2?



















Raw model output: Tell me the weather in Paris.<TOOL_WEATHER> The weather in Paris is sunny.
<TOOL_WEATHER> The weather in Paris is sunny.<TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER><TOOL_WEATHER>
Tell me the weather in Paris.The weather in Paris is sunny, 25°C. The weather in Paris is sunny.
The weather in Paris is sunny, 25°C. The weather in Paris is sunny.The weather in Paris is sunny, 25°C.The weather in Paris is sunny, 25°C.The weather in Paris is sunny, 25°C.The weather in Paris is sunny, 25°C.The weather in Paris is sunny, 25°C.The weather in Paris is sunny, 25°C.The weather in Paris is sunny, 25°C.The weather in Paris is sunny, 25°C.The weather in Paris 