In [2]:
from datasets import load_dataset
ds = load_dataset("PolyAI/banking77")
# ds["train"][0] -> {'text': '...', 'label': int}
label_names = ds["train"].features["label"].names
num_labels = len(label_names)  # 77

Downloading data:   0%|          | 0.00/839k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/240k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3080 [00:00<?, ? examples/s]

In [3]:
# Split training data into train and validation sets
train_data = ds["train"]
print(f"Original training set size: {len(train_data)}")

# Split into 80% train, 20% validation using train_test_split
train_val_split = train_data.train_test_split(test_size=0.2, seed=42)
train_split = train_val_split["train"]
val_split = train_val_split["test"]

print(f"Train split size: {len(train_split)}")
print(f"Validation split size: {len(val_split)}")
print(f"\nSample from train split:")
print(train_split[0])

# Prepare test set
test_data = ds["test"]
test_texts = [item["text"] for item in test_data]
test_labels = [item["label"] for item in test_data]

print(f"Test set size: {len(test_texts)}")
print(f"Sample text: {test_texts[0]}")
print(f"Sample label: {label_names[test_labels[0]]}")


Original training set size: 10003
Train split size: 8002
Validation split size: 2001

Sample from train split:
{'text': "If I bought something I didn't like, can I get a refund?", 'label': 52}
Test set size: 3080
Sample text: How do I locate my card?
Sample label: card_arrival


In [6]:
TEST_SIZE = 3080

# Create system prompt

In [4]:
# Extract few-shot examples from training/validation data (NOT from test set)
# This ensures we use real examples from the training distribution
def get_examples_for_labels(dataset_split, label_names, label_name_list, num_examples=3):
    """Get real examples from the dataset for specified labels.
    
    Args:
        dataset_split: The dataset split to search (train_split or val_split)
        label_names: List of all label names
        label_name_list: List of label names to find examples for
        num_examples: Number of examples to return per label
    
    Returns:
        Dict mapping label_name -> list of example texts
    """
    examples = {}
    for label_name in label_name_list:
        try:
            label_id = label_names.index(label_name)
            # Filter dataset for this label
            label_examples = [item for item in dataset_split if item["label"] == label_id]
            # Take up to num_examples
            examples[label_name] = [item["text"] for item in label_examples[:num_examples]]
        except ValueError:
            examples[label_name] = []
    return examples

# Get examples from training data (we can also use validation, but training has more examples)
few_shot_data = get_examples_for_labels(
    train_split, 
    label_names,
    ["card_arrival", "card_delivery_estimate", "card_linking", "activate_my_card", "lost_or_stolen_card", "getting_virtual_card"],
    num_examples=3
)

print("Few-shot examples extracted from training data:")
for label, examples in few_shot_data.items():
    print(f"  {label}: {len(examples)} examples")
    if examples:
        print(f"    - {examples[0][:80]}...")


Few-shot examples extracted from training data:
  card_arrival: 3 examples
    - Could you send me and up date on the arrival of my card?...
  card_delivery_estimate: 3 examples
    - can you express my card to me?...
  card_linking: 3 examples
    - Okay, I found my card, can I put it back in the app?...
  activate_my_card: 3 examples
    - What do I need to do for the card activation?...
  lost_or_stolen_card: 3 examples
    - I left my card at a restaurant and now its missing....
  getting_virtual_card: 3 examples
    - Where do I have access to a virtual card?...


In [5]:
# Create a system prompt for intent classification using numeric IDs
# Using IDs is more reliable and easier to parse than label names
from typing import Optional

# Create the ID mapping string
id_mapping_lines = [f"{i}: {label}" for i, label in enumerate(label_names)]
id_mapping = "\n".join(id_mapping_lines)

# Find IDs for few-shot examples
def get_label_id(label_name: str) -> Optional[int]:
    """Get the numeric ID for a label name."""
    try:
        return label_names.index(label_name)
    except ValueError:
        return None

# Get IDs for common confusing pairs
card_arrival_id = get_label_id("card_arrival")
card_delivery_estimate_id = get_label_id("card_delivery_estimate")
card_linking_id = get_label_id("card_linking")
activate_my_card_id = get_label_id("activate_my_card")
lost_or_stolen_card_id = get_label_id("lost_or_stolen_card")
getting_virtual_card_id = get_label_id("getting_virtual_card")

# Create few-shot examples using REAL examples from training data (not hardcoded)
few_shot_examples = ""
if all(id is not None for id in [card_arrival_id, card_delivery_estimate_id, card_linking_id, activate_my_card_id, lost_or_stolen_card_id, getting_virtual_card_id]):
    # Use the examples extracted from training data in the previous cell
    card_arrival_examples = few_shot_data.get("card_arrival", [])
    card_delivery_estimate_examples = few_shot_data.get("card_delivery_estimate", [])
    card_linking_examples = few_shot_data.get("card_linking", [])
    activate_my_card_examples = few_shot_data.get("activate_my_card", [])
    lost_or_stolen_card_examples = few_shot_data.get("lost_or_stolen_card", [])
    getting_virtual_card_examples = few_shot_data.get("getting_virtual_card", [])
    
    # Get additional examples for other confusing pairs identified in error analysis
    additional_labels = [
        "pin_blocked", "change_pin", "pending_cash_withdrawal", "declined_cash_withdrawal",
        "cash_withdrawal_not_recognised", "verify_my_identity", "why_verify_identity",
        "unable_to_verify_identity", "card_payment_wrong_exchange_rate",
        "wrong_exchange_rate_for_cash_withdrawal", "exchange_rate", "extra_charge_on_statement",
        "card_payment_fee_charged"
    ]
    additional_few_shot_data = get_examples_for_labels(train_split, label_names, additional_labels, num_examples=3)
    
    # Get IDs for additional confusing pairs
    pin_blocked_id = get_label_id("pin_blocked")
    change_pin_id = get_label_id("change_pin")
    pending_cash_withdrawal_id = get_label_id("pending_cash_withdrawal")
    declined_cash_withdrawal_id = get_label_id("declined_cash_withdrawal")
    cash_withdrawal_not_recognised_id = get_label_id("cash_withdrawal_not_recognised")
    verify_my_identity_id = get_label_id("verify_my_identity")
    why_verify_identity_id = get_label_id("why_verify_identity")
    unable_to_verify_identity_id = get_label_id("unable_to_verify_identity")
    card_payment_wrong_exchange_rate_id = get_label_id("card_payment_wrong_exchange_rate")
    wrong_exchange_rate_for_cash_withdrawal_id = get_label_id("wrong_exchange_rate_for_cash_withdrawal")
    exchange_rate_id = get_label_id("exchange_rate")
    extra_charge_on_statement_id = get_label_id("extra_charge_on_statement")
    card_payment_fee_charged_id = get_label_id("card_payment_fee_charged")
    
    # Build few-shot examples section using REAL examples from training data
    examples_lines = ["EXAMPLES TO HELP DISTINGUISH SIMILAR INTENTS:", ""]
    
    # 1. card_arrival vs card_delivery_estimate (CRITICAL - many errors here)
    # Key distinction: card_arrival = asking about YOUR specific card that hasn't arrived
    #                  card_delivery_estimate = asking about general delivery timeframes
    if card_arrival_examples and card_delivery_estimate_examples:
        examples_lines.append(f"1. card_arrival (ID {card_arrival_id}) vs card_delivery_estimate (ID {card_delivery_estimate_id}):")
        examples_lines.append(f"   card_arrival = asking about YOUR specific card that hasn't arrived yet (tracking, status)")
        examples_lines.append(f"   card_delivery_estimate = asking about general delivery timeframes/how long it takes")
        for ex in card_arrival_examples[:3]:  # Use 3 examples to emphasize
            examples_lines.append(f'   - Query: "{ex}" → {card_arrival_id}')
        for ex in card_delivery_estimate_examples[:3]:  # Use 3 examples
            examples_lines.append(f'   - Query: "{ex}" → {card_delivery_estimate_id}')
        examples_lines.append("")
    
    # 2. card_linking vs activate_my_card vs lost_or_stolen_card
    if card_linking_examples and activate_my_card_examples and lost_or_stolen_card_examples:
        examples_lines.append(f"2. card_linking (ID {card_linking_id}) vs activate_my_card (ID {activate_my_card_id}) vs lost_or_stolen_card (ID {lost_or_stolen_card_id}):")
        examples_lines.append(f"   card_linking = reconnecting a card you found/retrieved")
        examples_lines.append(f"   activate_my_card = activating a NEW card for first time")
        examples_lines.append(f"   lost_or_stolen_card = reporting a card as lost/stolen")
        for ex in card_linking_examples[:2]:
            examples_lines.append(f'   - Query: "{ex}" → {card_linking_id}')
        for ex in activate_my_card_examples[:2]:
            examples_lines.append(f'   - Query: "{ex}" → {activate_my_card_id}')
        for ex in lost_or_stolen_card_examples[:2]:
            examples_lines.append(f'   - Query: "{ex}" → {lost_or_stolen_card_id}')
        examples_lines.append("")
    
    # 3. pin_blocked vs change_pin
    if pin_blocked_id is not None and change_pin_id is not None:
        pin_blocked_examples = additional_few_shot_data.get("pin_blocked", [])
        change_pin_examples = additional_few_shot_data.get("change_pin", [])
        if pin_blocked_examples and change_pin_examples:
            examples_lines.append(f"3. pin_blocked (ID {pin_blocked_id}) vs change_pin (ID {change_pin_id}):")
            examples_lines.append(f"   pin_blocked = PIN is locked/blocked, need to unlock")
            examples_lines.append(f"   change_pin = want to change PIN to a new one")
            for ex in pin_blocked_examples[:2]:
                examples_lines.append(f'   - Query: "{ex}" → {pin_blocked_id}')
            for ex in change_pin_examples[:2]:
                examples_lines.append(f'   - Query: "{ex}" → {change_pin_id}')
            examples_lines.append("")
    
    # 4. pending_cash_withdrawal vs declined_cash_withdrawal vs cash_withdrawal_not_recognised
    if (pending_cash_withdrawal_id is not None and declined_cash_withdrawal_id is not None and 
        cash_withdrawal_not_recognised_id is not None):
        pending_cw_examples = additional_few_shot_data.get("pending_cash_withdrawal", [])
        declined_cw_examples = additional_few_shot_data.get("declined_cash_withdrawal", [])
        cw_not_rec_examples = additional_few_shot_data.get("cash_withdrawal_not_recognised", [])
        if pending_cw_examples and declined_cw_examples and cw_not_rec_examples:
            examples_lines.append(f"4. pending_cash_withdrawal (ID {pending_cash_withdrawal_id}) vs declined_cash_withdrawal (ID {declined_cash_withdrawal_id}) vs cash_withdrawal_not_recognised (ID {cash_withdrawal_not_recognised_id}):")
            examples_lines.append(f"   pending_cash_withdrawal = withdrawal is processing/pending")
            examples_lines.append(f"   declined_cash_withdrawal = withdrawal was rejected/declined")
            examples_lines.append(f"   cash_withdrawal_not_recognised = withdrawal not showing in account")
            for ex in pending_cw_examples[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {pending_cash_withdrawal_id}')
            for ex in declined_cw_examples[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {declined_cash_withdrawal_id}')
            for ex in cw_not_rec_examples[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {cash_withdrawal_not_recognised_id}')
            examples_lines.append("")
    
    # 5. verify_my_identity vs why_verify_identity vs unable_to_verify_identity
    if (verify_my_identity_id is not None and why_verify_identity_id is not None and 
        unable_to_verify_identity_id is not None):
        verify_examples = additional_few_shot_data.get("verify_my_identity", [])
        why_verify_examples = additional_few_shot_data.get("why_verify_identity", [])
        unable_verify_examples = additional_few_shot_data.get("unable_to_verify_identity", [])
        if verify_examples and why_verify_examples and unable_verify_examples:
            examples_lines.append(f"5. verify_my_identity (ID {verify_my_identity_id}) vs why_verify_identity (ID {why_verify_identity_id}) vs unable_to_verify_identity (ID {unable_to_verify_identity_id}):")
            examples_lines.append(f"   verify_my_identity = want to verify/complete verification")
            examples_lines.append(f"   why_verify_identity = asking why verification is needed")
            examples_lines.append(f"   unable_to_verify_identity = having trouble completing verification")
            for ex in verify_examples[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {verify_my_identity_id}')
            for ex in why_verify_examples[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {why_verify_identity_id}')
            for ex in unable_verify_examples[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {unable_to_verify_identity_id}')
            examples_lines.append("")
    
    # 6. Exchange rate related confusions
    if (card_payment_wrong_exchange_rate_id is not None and wrong_exchange_rate_for_cash_withdrawal_id is not None and
        exchange_rate_id is not None):
        card_payment_wrong_ex = additional_few_shot_data.get("card_payment_wrong_exchange_rate", [])
        cash_wrong_ex = additional_few_shot_data.get("wrong_exchange_rate_for_cash_withdrawal", [])
        exchange_rate_ex = additional_few_shot_data.get("exchange_rate", [])
        if card_payment_wrong_ex and cash_wrong_ex and exchange_rate_ex:
            examples_lines.append(f"6. card_payment_wrong_exchange_rate (ID {card_payment_wrong_exchange_rate_id}) vs wrong_exchange_rate_for_cash_withdrawal (ID {wrong_exchange_rate_for_cash_withdrawal_id}) vs exchange_rate (ID {exchange_rate_id}):")
            examples_lines.append(f"   card_payment_wrong_exchange_rate = wrong rate used for CARD payment")
            examples_lines.append(f"   wrong_exchange_rate_for_cash_withdrawal = wrong rate used for CASH withdrawal")
            examples_lines.append(f"   exchange_rate = asking about current/general exchange rates")
            for ex in card_payment_wrong_ex[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {card_payment_wrong_exchange_rate_id}')
            for ex in cash_wrong_ex[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {wrong_exchange_rate_for_cash_withdrawal_id}')
            for ex in exchange_rate_ex[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {exchange_rate_id}')
            examples_lines.append("")
    
    # 7. Charge/fee related
    if extra_charge_on_statement_id is not None and card_payment_fee_charged_id is not None:
        extra_charge_ex = additional_few_shot_data.get("extra_charge_on_statement", [])
        card_fee_ex = additional_few_shot_data.get("card_payment_fee_charged", [])
        if extra_charge_ex and card_fee_ex:
            examples_lines.append(f"7. extra_charge_on_statement (ID {extra_charge_on_statement_id}) vs card_payment_fee_charged (ID {card_payment_fee_charged_id}):")
            examples_lines.append(f"   extra_charge_on_statement = unexpected charge on statement")
            examples_lines.append(f"   card_payment_fee_charged = fee charged for card payment")
            for ex in extra_charge_ex[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {extra_charge_on_statement_id}')
            for ex in card_fee_ex[:1]:
                examples_lines.append(f'   - Query: "{ex}" → {card_payment_fee_charged_id}')
            examples_lines.append("")
    
    # 8. getting_virtual_card
    if getting_virtual_card_examples:
        examples_lines.append(f"8. getting_virtual_card (ID {getting_virtual_card_id}):")
        for ex in getting_virtual_card_examples[:2]:
            examples_lines.append(f'   - Query: "{ex}" → {getting_virtual_card_id}')
        examples_lines.append("")
    
    few_shot_examples = "\n".join(examples_lines) if len(examples_lines) > 2 else ""

system_prompt = f"""You are a banking intent classifier. Classify the user's query into one of  {num_labels} banking intents (output is a single integer ID).
                   
IDs:

{id_mapping}

CRITICAL INSTRUCTIONS:
1. Choose exactly one integer ID (0-{num_labels-1}).
2. Reply with ONLY that number. No words, no reasoning, no punctuation.
Examples: 0, 1, 42

EXAMPLES TO HELP DISTINGUISH SIMILAR INTENTS:
{few_shot_examples}

Remember: Respond with ONLY the numeric ID, nothing else."""

print("System prompt created with numeric IDs and few-shot examples")
print(f"Number of labels: {num_labels}")
print(f"ID range: 0-{num_labels-1}")


system_prompt_basic = f"""You are a banking intent classifier. Classify the user's query into one of  {num_labels} banking intents (output is a single integer ID).
                   
IDs:

{id_mapping}

CRITICAL INSTRUCTIONS:
1. Choose exactly one integer ID (0-{num_labels-1}).
2. Reply with ONLY that number. No words, no reasoning, no punctuation.
Examples: 0, 1, 42

Remember: Respond with ONLY the numeric ID, nothing else."""

print(system_prompt)

System prompt created with numeric IDs and few-shot examples
Number of labels: 77
ID range: 0-76
You are a banking intent classifier. Classify the user's query into one of  77 banking intents (output is a single integer ID).

IDs:

0: activate_my_card
1: age_limit
2: apple_pay_or_google_pay
3: atm_support
4: automatic_top_up
5: balance_not_updated_after_bank_transfer
6: balance_not_updated_after_cheque_or_cash_deposit
7: beneficiary_not_allowed
8: cancel_transfer
9: card_about_to_expire
10: card_acceptance
11: card_arrival
12: card_delivery_estimate
13: card_linking
14: card_not_working
15: card_payment_fee_charged
16: card_payment_not_recognised
17: card_payment_wrong_exchange_rate
18: card_swallowed
19: cash_withdrawal_charge
20: cash_withdrawal_not_recognised
21: change_pin
22: compromised_card
23: contactless_not_working
24: country_support
25: declined_card_payment
26: declined_cash_withdrawal
27: declined_transfer
28: direct_debit_payment_not_recognised
29: disposable_card_limits

In [24]:
ds["test"]

Dataset({
    features: ['text', 'label'],
    num_rows: 3080
})

In [25]:
test_texts[0], test_labels[0]

('How do I locate my card?', 11)

In [None]:
# Save conversations to JSONL file for batch inference
import json
import os
# Output file path
output_jsonl = "test.jsonl"

# Use inference engine directly (recommended for system prompts)
from oumi.core.types.conversation import Conversation, Message, Role

# Create conversations with system prompt, user message, and assistant response (label)
conversations = [
    Conversation(
        messages=[
            Message(role=Role.SYSTEM, content=system_prompt),
            Message(role=Role.USER, content=text),
            Message(role=Role.ASSISTANT, content=str(label)),  # Label as string (integer ID)
        ]
    )
    for text, label in zip(test_texts[:TEST_SIZE], test_labels[:TEST_SIZE])
]



# Save conversations to JSONL format
# Each line is a JSON object representing one conversation
with open(output_jsonl, "w") as f:
    for conv in conversations:
        # Use to_dict() to convert Conversation to dict, then write as JSON line
        json.dump(conv.to_dict(), f)
        f.write("\n")

print(f"Saved {len(conversations)} conversations to {output_jsonl}")
print(f"File size: {os.path.getsize(output_jsonl) / 1024 / 1024:.2f} MB")

# Show first conversation as example
if conversations:
    print("\nFirst conversation example:")
    print(json.dumps(conversations[0].to_dict(), indent=2))


Saved 3080 conversations to test.jsonl
File size: 20.28 MB

First conversation example:
{
  "messages": [
    {
      "content": "You are a banking intent classifier. Classify the user's query into one of  77 banking intents (output is a single integer ID).\n\nIDs:\n\n0: activate_my_card\n1: age_limit\n2: apple_pay_or_google_pay\n3: atm_support\n4: automatic_top_up\n5: balance_not_updated_after_bank_transfer\n6: balance_not_updated_after_cheque_or_cash_deposit\n7: beneficiary_not_allowed\n8: cancel_transfer\n9: card_about_to_expire\n10: card_acceptance\n11: card_arrival\n12: card_delivery_estimate\n13: card_linking\n14: card_not_working\n15: card_payment_fee_charged\n16: card_payment_not_recognised\n17: card_payment_wrong_exchange_rate\n18: card_swallowed\n19: cash_withdrawal_charge\n20: cash_withdrawal_not_recognised\n21: change_pin\n22: compromised_card\n23: contactless_not_working\n24: country_support\n25: declined_card_payment\n26: declined_cash_withdrawal\n27: declined_transfer\n2

In [16]:
# CORRECTED: Convert banking77 data to SFT format WITH system prompt using numeric IDs
# This ensures training format matches inference format
import json

# Reuse the system prompt from Cell 7 (inference system prompt)
# This ensures training and inference use the exact same prompt
training_system_prompt = system_prompt

def convert_to_sft_format_with_system(dataset_split, label_names, system_prompt):
    """Convert banking77 dataset to SFT conversation format WITH system prompt.
    
    Uses numeric IDs instead of label names for more reliable parsing.
    """
    conversations = []
    for item in dataset_split:
        user_text = item["text"]
        label_idx = item["label"]  # This is already the numeric ID (0-76)
        
        # Use the numeric ID as string for the assistant response
        conversation = {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": str(label_idx)}  # Use ID instead of label name
            ]
        }
        conversations.append(conversation)
    return conversations

# Re-convert with system prompt
print("Re-converting training data WITH system prompt...")
train_conversations = convert_to_sft_format_with_system(train_split, label_names, training_system_prompt)
val_conversations = convert_to_sft_format_with_system(val_split, label_names, training_system_prompt)

print(f"Converted {len(train_conversations)} training conversations (with system prompt)")
print(f"Converted {len(val_conversations)} validation conversations (with system prompt)")
print(f"\nSample conversation (with system prompt):")
print(json.dumps(train_conversations[0], indent=2))


# Re-save the corrected training data with system prompt
# This will overwrite the previous files
from pathlib import Path
output_dir = Path("data")
output_dir.mkdir(exist_ok=True)

train_path = output_dir / "train.jsonl"
val_path = output_dir / "validation.jsonl"

# Write train data with system prompt
with open(train_path, "w") as f:
    for conv in train_conversations:
        f.write(json.dumps(conv) + "\n")

# Write validation data with system prompt
with open(val_path, "w") as f:
    for conv in val_conversations:
        f.write(json.dumps(conv) + "\n")

print(f"Saved corrected training data (WITH system prompt + numeric IDs) to: {train_path}")
print(f"Saved corrected validation data (WITH system prompt + numeric IDs) to: {val_path}")
print(f"\n⚠️  IMPORTANT: You need to re-train the model with this corrected data!")
print(f"   Changes:")
print(f"   1. Training data now includes system prompt (matches inference)")
print(f"   2. Training data uses numeric IDs instead of label names (more reliable)")
print(f"   3. System prompt instructs model to output numeric IDs (0-76)")
print(f"\n   To retrain:")
print(f"   oumi train -c {output_dir / 'train.yaml'}")


Re-converting training data WITH system prompt...
Converted 8002 training conversations (with system prompt)
Converted 2001 validation conversations (with system prompt)

Sample conversation (with system prompt):
{
  "messages": [
    {
      "role": "system",
      "content": "You are a banking intent classifier. Classify the user's query into one of  77 banking intents (output is a single integer ID).\n\nIDs:\n\n0: activate_my_card\n1: age_limit\n2: apple_pay_or_google_pay\n3: atm_support\n4: automatic_top_up\n5: balance_not_updated_after_bank_transfer\n6: balance_not_updated_after_cheque_or_cash_deposit\n7: beneficiary_not_allowed\n8: cancel_transfer\n9: card_about_to_expire\n10: card_acceptance\n11: card_arrival\n12: card_delivery_estimate\n13: card_linking\n14: card_not_working\n15: card_payment_fee_charged\n16: card_payment_not_recognised\n17: card_payment_wrong_exchange_rate\n18: card_swallowed\n19: cash_withdrawal_charge\n20: cash_withdrawal_not_recognised\n21: change_pin\n22: 

# Baseline Inference

Qwen3-4B-Instruct-2507

## Run Inference on Cluster

1. **Submit inference job:**
   ```bash
   cd notebooks/scripts
   # Basic usage (uses defaults: data/test.jsonl, configs/4b_instruct_vllm_infer.yaml, output)
   ./submit_inference_rsync.sh
   
   # With custom parameters:
   # ./submit_inference_rsync.sh [input_file] [config_file] [output_name] [cluster_host]
   ./submit_inference_rsync.sh data/test.jsonl configs/4b_instruct_vllm_infer.yaml baseline_results ryan@exun
   ```

2. **Check job status:**
   ```bash
   ssh ryan@exun 'squeue -u ryan'
   ```

3. **View logs (optional):**
   ```bash
   ssh ryan@exun 'tail -f /home/ryan/code/oumi/lab/banking77/notebooks/logs/banking77_inference_qwen3_4b_*.log'
   ```

5. **Download results:**
   ```bash
   cd notebooks/scripts
   # Download most recent output (default: output_*.jsonl)
   ./download_output.sh
   
   # Download specific job ID
   ./download_output.sh 2739
   
   # Download with custom output name
   ./download_output.sh 2739 baseline_results
   # OR just find most recent with that name:
   ./download_output.sh "" baseline_results
   ```

The output will be saved to `notebooks/data/<output_name>_<JOB_ID>.jsonl`

In [18]:
# Reusable evaluation function for inference results
import json
import re
from typing import List, Tuple, Optional

def evaluate_predictions(
    output_file: str,
    test_labels: List[int],
    test_texts: List[str],
    label_names: List[str],
    num_labels: int,
    model_name: str = "Model"
) -> Tuple[List[Optional[int]], float, int, int]:
    """
    Evaluate predictions from an inference output JSONL file.
    
    Args:
        output_file: Path to the JSONL file with predictions
        test_labels: List of true label IDs
        test_texts: List of test texts (for display)
        label_names: List of label names
        num_labels: Number of labels
        model_name: Name of the model (for display)
    
    Returns:
        Tuple of (predictions, accuracy, correct_count, total_count)
    """
    predictions = []
    
    print(f"Reading predictions from {output_file}...")
    with open(output_file, "r") as f:
        for line in f:
            data = json.loads(line)
            # Extract the predicted label ID from the assistant's last message
            messages = data.get("messages", [])
            if messages and messages[-1].get("role") == "assistant":
                content = messages[-1].get("content", "").strip()
                # Extract first integer from the response (handles cases where model adds reasoning)
                match = re.search(r'\b(\d+)\b', content)
                if match:
                    try:
                        pred_id = int(match.group(1))
                        # Validate it's in the valid range
                        if 0 <= pred_id < num_labels:
                            predictions.append(pred_id)
                        else:
                            print(f"Warning: Prediction {pred_id} out of range [0, {num_labels-1}]")
                            predictions.append(None)
                    except ValueError:
                        print(f"Warning: Could not parse prediction from '{content[:50]}...'")
                        predictions.append(None)
                else:
                    print(f"Warning: No integer found in response: '{content[:50]}...'")
                    predictions.append(None)
            else:
                predictions.append(None)
    
    print(f"Loaded {len(predictions)} predictions")
    print(f"Test labels: {len(test_labels)}")
    
    # Calculate accuracy
    correct = sum(1 for pred, true in zip(predictions, test_labels) if pred is not None and pred == true)
    total = len(predictions)
    accuracy = correct / total if total > 0 else 0.0
    
    print(f"\n{model_name} Accuracy: {accuracy:.2%} ({correct}/{total})")
    
    # Show some examples
    print("\nSample predictions:")
    for i in range(min(10, len(predictions))):
        pred_label = label_names[predictions[i]] if predictions[i] is not None else "None"
        true_label = label_names[test_labels[i]]
        match = "✓" if predictions[i] == test_labels[i] else "✗"
        print(f"  {i+1}. {match} Pred: {pred_label:30s} True: {true_label:30s} | Text: {test_texts[i][:50]}...")
    
    return predictions, accuracy, correct, total

print("Evaluation function defined!")


Evaluation function defined!


# Evaluate 

In [19]:
# Evaluate baseline model results
output_file = "data/baseline_results_2740.jsonl"
baseline_predictions, baseline_accuracy, baseline_correct, baseline_total = evaluate_predictions(
    output_file=output_file,
    test_labels=test_labels,
    test_texts=test_texts,
    label_names=label_names,
    num_labels=num_labels,
    model_name="Baseline"
)


Reading predictions from data/baseline_results_2740.jsonl...
Loaded 3080 predictions
Test labels: 3080

Baseline Accuracy: 54.90% (1691/3080)

Sample predictions:
  1. ✗ Pred: lost_or_stolen_card            True: card_arrival                   | Text: How do I locate my card?...
  2. ✗ Pred: card_delivery_estimate         True: card_arrival                   | Text: I still have not received my new card, I ordered o...
  3. ✓ Pred: card_arrival                   True: card_arrival                   | Text: I ordered a card but it has not arrived. Help plea...
  4. ✗ Pred: card_delivery_estimate         True: card_arrival                   | Text: Is there a way to know when my card will arrive?...
  5. ✓ Pred: card_arrival                   True: card_arrival                   | Text: My card has not arrived yet....
  6. ✗ Pred: card_delivery_estimate         True: card_arrival                   | Text: When will I get my card?...
  7. ✓ Pred: card_arrival                   True: card_

# Training

# Banking77 Intent Classification

This notebook prepares data and provides instructions for training and inference on the Banking77 dataset.
## Training on Cluster
1.**Prepare training/validation datasets** (run cells below to generate JSONL files)

2. **Submit training job:**   

```bash   cd notebooks/scripts   # Basic usage (uses defaults: data/train.jsonl, data/validation.jsonl, configs/qwen4b_train_lora.yaml)   

./submit_training_rsync.sh      # With custom datasets:   ./submit_training_rsync.sh data/train.jsonl data/validation.jsonl configs/qwen4b_train_lora.yaml my_training ryan@exun      # With wandb project name:   

./submit_training_rsync.sh data/train.jsonl data/validation.jsonl configs/qwen4b_train_lora.yaml my_training ryan@exun my-wandb-project      # With wandb project and entity (team):   

./submit_training_rsync.sh data/train.jsonl data/validation.jsonl configs/qwen4b_train_lora.yaml my_training ryan@exun my-wandb-project my-team      # Or use environment variables:   export WANDB_PROJECT=my-wandb-project   

./submit_training_rsync.sh   ```

3. **Check job status:**   

```bash   ssh ryan@exun 'squeue -u ryan'   ```

4. **Kill job (if needed):**

   ```bash
   cd notebooks/scripts
   ./kill_training.sh  # Kills most recent training job
   # OR kill specific job:
   ./kill_training.sh 2744 ryan@exun
   ```

5. **View logs:**   

```bash   ssh ryan@exun 'tail -f /home/ryan/code/oumi/lab/banking77/notebooks/logs/banking77_training_qwen3_4b_*.log'   ```

**Note:** Make sure `WANDB_API_KEY` is set in your environment on the cluster for wandb logging to work.## Inference on Cluster

1.**Prepare test dataset** (run cells below to generate test.jsonl)

2. **Submit inference job:**   

```bash   cd notebooks/scripts   # Basic usage (uses defaults: data/test.jsonl, configs/4b_instruct_vllm_infer.yaml, output)   
./submit_inference_rsync.sh      # With custom parameters:   # 
./submit_inference_rsync.sh [input_file] [config_file] [output_name] [cluster_host]   
./submit_inference_rsync.sh data/test.jsonl configs/4b_instruct_vllm_infer.yaml baseline_results ryan@exun   ```

3. **Check job status:**   

```bash   ssh ryan@exun 'squeue -u ryan'   ```

4. **Download results:**   

```bash   cd notebooks/scripts   ./download_output.sh  # Downloads most recent output   # OR specify job ID and output name:   ./download_output.sh 2740 baseline_results   ```The output will be saved to `notebooks/data/<output_name>_<JOB_ID>.jsonl`

# Run inference on trained model checkpoint

## Run inference with LoRA adapter
cd notebooks/scripts
./submit_inference_rsync.sh \
    data/test.jsonl \
    configs/4b_instruct_vllm_infer.yaml \
    system_prompt_v2_lora_results \
    output/system_prompt_v2_lora_2757 \
    ryan@exun

## download the results
./download_output.sh 2760 system_prompt_v2_lora_results

# Evaluate finetune model

# Evaluate another baseline result file


In [21]:
output_file = "data/system_prompt_v2_lora_results_2760.jsonl"
baseline2_predictions, baseline2_accuracy, baseline2_correct, baseline2_total = evaluate_predictions(
    output_file=output_file,
    test_labels=test_labels,
    test_texts=test_texts,
    label_names=label_names,
    num_labels=num_labels,
    model_name="Baseline (2743)"
)

Reading predictions from data/system_prompt_v2_lora_results_2760.jsonl...
Loaded 3080 predictions
Test labels: 3080

Baseline (2743) Accuracy: 68.77% (2118/3080)

Sample predictions:
  1. ✗ Pred: card_not_working               True: card_arrival                   | Text: How do I locate my card?...
  2. ✓ Pred: card_arrival                   True: card_arrival                   | Text: I still have not received my new card, I ordered o...
  3. ✓ Pred: card_arrival                   True: card_arrival                   | Text: I ordered a card but it has not arrived. Help plea...
  4. ✗ Pred: card_delivery_estimate         True: card_arrival                   | Text: Is there a way to know when my card will arrive?...
  5. ✓ Pred: card_arrival                   True: card_arrival                   | Text: My card has not arrived yet....
  6. ✗ Pred: card_delivery_estimate         True: card_arrival                   | Text: When will I get my card?...
  7. ✓ Pred: card_arrival          