# Run inference on the validation set using the model checkpoint

## Run inference
  
cd notebooks/scripts
./submit_inference_rsync.sh \
    data/banking77_val_no_label.jsonl \
    configs/4b_instruct_vllm_infer.yaml \
    system_prompt_v2_sft_results \
    output/banking77_qwen3_4b_full_3065 \
    ryan@exun


In [1]:
from utils import measure_accuracy, read_jsonl
from utils import (
    display_message,
    display_text
)

# Predictions written by the inference run. The path is relative to the
# notebook's working directory (the `notebooks/` folder), the same base the
# hard-example cell uses for `./data`, so the notebook is not tied to one
# user's home directory.
inference_file = './data/system_prompt_v2_sft_results_3072.jsonl'
inference_data = read_jsonl(inference_file)
# incorrect_list holds the misclassified rows that the error analysis below
# consumes as (row_index, predicted_label, ground_truth_label) entries.
accuracy, correct, total, errors, incorrect_list = measure_accuracy(inference_data)


Accuracy: 91.60% (1833/2001)


In [2]:
from collections import Counter

# Drop the row index and order each (predicted, ground-truth) pair so that
# A->B and B->A mistakes collapse into the same "low-high" key.
wihout_index = [
    sorted((predicted, actual)) for _row, predicted, actual in incorrect_list
]
wihout_index_str = [f'{low}-{high}' for low, high in wihout_index]

# Rank the confused pairs by frequency and keep the keys of the top 20.
cnt = Counter(wihout_index_str)
most_common_items = cnt.most_common(20)
most_common_pairs = [pair_key for pair_key, _count in most_common_items]

In [3]:
# Comma-separated list of the ten most frequently confused label pairs.
print(*most_common_pairs[:10], sep=',')

47-59,5-67,7-35,47-62,48-66,16-22,42-44,15-64,16-28,25-27


In [4]:
# For each frequently confused pair, list the row indices of the
# misclassified examples, whichever of the two labels was predicted.
for pair in most_common_pairs:
    id_parts = pair.split('-')
    first_id, second_id = int(id_parts[0]), int(id_parts[1])
    # The pair key is direction-agnostic, so accept both orderings.
    either_order = ((first_id, second_id), (second_id, first_id))

    index_list_of_incorrect_pairs = [
        item[0] for item in incorrect_list if item[1:] in either_order
    ]

    print(f"rows with incorrect pairs for {pair}")
    print(index_list_of_incorrect_pairs)
    print("-" * 100)

rows with incorrect pairs for 47-59
[180, 656, 892, 903, 982, 1012, 1347, 1769]
----------------------------------------------------------------------------------------------------
rows with incorrect pairs for 5-67
[204, 636, 1040, 1052, 1395, 1431, 1583]
----------------------------------------------------------------------------------------------------
rows with incorrect pairs for 7-35
[272, 550, 585, 883, 1023, 1374]
----------------------------------------------------------------------------------------------------
rows with incorrect pairs for 47-62
[24, 890, 1004, 1222, 1300]
----------------------------------------------------------------------------------------------------
rows with incorrect pairs for 48-66
[281, 511, 1512, 1527, 1873]
----------------------------------------------------------------------------------------------------
rows with incorrect pairs for 16-22
[609, 653, 1073, 1868]
-----------------------------------------------------------------------------------

In [5]:
# Get label name mapping from the system prompt
import re

def get_label_name(label_id, system_message):
    """Extract the label name for a given ID from the system message.

    The system message is expected to list the intents as ``<id>: <name>``
    entries, one per line.

    Args:
        label_id: Label ID to look up (an int, or any value whose ``str()``
            equals the ID as written in the message).
        system_message: Prompt text containing the ``<id>: <name>`` entries.

    Returns:
        The stripped label name, or ``"Unknown_<label_id>"`` when the message
        has no entry for the ID.
    """
    # The lookbehind stops a short ID from matching the tail of a longer one
    # (e.g. 5 inside "15: ..."), which previously only worked because the
    # prompt happens to list IDs in ascending order. re.escape keeps the ID
    # literal inside the regex.
    pattern = rf'(?<!\d){re.escape(str(label_id))}:\s*([^\n]+)'
    match = re.search(pattern, system_message)
    if match:
        return match.group(1).strip()
    return f"Unknown_{label_id}"

# Investigate incorrect classifications for the most common confused pairs.
NUM_PAIRS_TO_SHOW = 5      # how many of the most-confused pairs to print
NUM_EXAMPLES_PER_PAIR = 5  # how many misclassified rows to print per pair

# Label names are read from the first row's system message; every row is
# assumed to carry the same prompt, so look it up once, outside the loop.
system_message = inference_data[0]['messages'][0]['content']

for pair_idx, pair in enumerate(most_common_pairs[:NUM_PAIRS_TO_SHOW]):
    pair_split = pair.split('-')
    label1_id, label2_id = int(pair_split[0]), int(pair_split[1])

    label1_name = get_label_name(label1_id, system_message)
    label2_name = get_label_name(label2_id, system_message)

    print(f"\n{'='*120}")
    print(f"PAIR #{pair_idx + 1}: {label1_id} ({label1_name}) <-> {label2_id} ({label2_name})")
    print(f"{'='*120}\n")

    # Find all incorrect items for this pair, in either
    # (predicted, ground-truth) order.
    p1 = (label1_id, label2_id)
    p2 = (label2_id, label1_id)
    incorrect_items_for_pair = [
        item for item in incorrect_list if item[1:] in (p1, p2)
    ]

    # Report the number of examples actually printed, which can be fewer
    # than the cap when a pair has only a handful of errors.
    examples_to_show = incorrect_items_for_pair[:NUM_EXAMPLES_PER_PAIR]
    print(f"Total misclassifications: {len(incorrect_items_for_pair)}")
    print(f"Showing first {len(examples_to_show)} examples:\n")

    for i, (row_idx, predicted_label, gt_label) in enumerate(examples_to_show):
        row_data = inference_data[row_idx]
        user_query = row_data['messages'][1]['content']  # User message
        model_response = row_data['messages'][2]['content']  # Assistant response
        # Ground-truth name comes from the row's metadata; the predicted name
        # is resolved from the ID listing in the system prompt.
        gt_label_name = row_data['metadata']['label_name']
        predicted_label_name = get_label_name(predicted_label, system_message)

        print(f"Example {i+1} (Row {row_idx}):")
        print(f"  Query: {user_query}")
        print(f"  Ground Truth: {gt_label} ({gt_label_name})")
        print(f"  Predicted: {predicted_label} ({predicted_label_name})")
        print(f"  Model Response: '{model_response}'")
        print()

    print("-"*120)
    print()



PAIR #1: 47 (pending_top_up) <-> 59 (top_up_failed)

Total misclassifications: 8
Showing first 5 examples:

Example 1 (Row 180):
  Query: I added money but my top-up wasn't processed by the app.
  Ground Truth: 59 (top_up_failed)
  Predicted: 47 (pending_top_up)
  Model Response: '47'

Example 2 (Row 656):
  Query: My top-up hasn't gone through
  Ground Truth: 47 (pending_top_up)
  Predicted: 59 (top_up_failed)
  Model Response: '59'

Example 3 (Row 892):
  Query: Why hasn't my top-up gone through?
  Ground Truth: 47 (pending_top_up)
  Predicted: 59 (top_up_failed)
  Model Response: '59'

Example 4 (Row 903):
  Query: The top-up is broken.
  Ground Truth: 47 (pending_top_up)
  Predicted: 59 (top_up_failed)
  Model Response: '59'

Example 5 (Row 982):
  Query: It appears my top-up has not gone through.
  Ground Truth: 47 (pending_top_up)
  Predicted: 59 (top_up_failed)
  Model Response: '59'

----------------------------------------------------------------------------------------------

In [6]:
# Number of misclassified rows collected by measure_accuracy.
len(incorrect_list)

168

In [7]:
# Peek at one raw record to see its structure (conversation_id, messages, ...).
inference_data[0]

{'conversation_id': '4d029644-36bd-5421-b126-225a16ebc3d5',
 'messages': [{'content': 'You are a banking intent classifier. Classify the user\'s query into one of  77 banking intents (output is a single integer ID).\n\nIDs:\n\n0: activate_my_card\n1: age_limit\n2: apple_pay_or_google_pay\n3: atm_support\n4: automatic_top_up\n5: balance_not_updated_after_bank_transfer\n6: balance_not_updated_after_cheque_or_cash_deposit\n7: beneficiary_not_allowed\n8: cancel_transfer\n9: card_about_to_expire\n10: card_acceptance\n11: card_arrival\n12: card_delivery_estimate\n13: card_linking\n14: card_not_working\n15: card_payment_fee_charged\n16: card_payment_not_recognised\n17: card_payment_wrong_exchange_rate\n18: card_swallowed\n19: cash_withdrawal_charge\n20: cash_withdrawal_not_recognised\n21: change_pin\n22: compromised_card\n23: contactless_not_working\n24: country_support\n25: declined_card_payment\n26: declined_cash_withdrawal\n27: declined_transfer\n28: direct_debit_payment_not_recognised\n29

In [8]:
from utils import generate_hard_examples_batch
import os

# The ID -> label-name listing is taken from the first row's system prompt.
system_message = inference_data[0]['messages'][0]['content']

# Destination for the generated examples; create the folder if it is missing.
output_dir = './data'
output_file = os.path.join(output_dir, 'hard_examples_all_pairs.jsonl')
os.makedirs(output_dir, exist_ok=True)

# All arguments for the batch run, gathered in one place and passed as
# keywords in the same order as before.
batch_settings = {
    'pairs': most_common_pairs,
    'get_label_name_func': get_label_name,
    'system_message': system_message,
    'num_examples': 40,
    'num_class_a': 20,
    'num_class_b': 20,
    'model': "gpt-5",
    'temperature': 1.0,
    'max_workers': 1000,
    'output_file': output_file,
    'show_progress': True,
}

# Generate hard examples for all pairs in parallel
batch_result = generate_hard_examples_batch(**batch_settings)

Processing 20 pairs with 1000 workers...
Output file: ./data/hard_examples_all_pairs.jsonl

[1/20] ✓ pending_top_up <-> top_up_failed: 40 examples
[2/20] ✓ beneficiary_not_allowed <-> failed_transfer: 40 examples
[3/20] ✓ Refund_not_showing_up <-> request_refund: 40 examples
[4/20] ✓ wrong_amount_of_cash_received <-> wrong_exchange_rate_for_cash_withdrawal: 40 examples
[5/20] ✓ pending_top_up <-> topping_up_by_card: 40 examples
[6/20] ✓ lost_or_stolen_phone <-> passcode_forgotten: 40 examples
[7/20] ✓ transfer_not_received_by_recipient <-> transfer_timing: 40 examples
[8/20] ✓ activate_my_card <-> card_linking: 40 examples
[9/20] ✓ fiat_currency_support <-> supported_cards_and_currencies: 40 examples
[10/20] ✓ extra_charge_on_statement <-> transfer_fee_charged: 40 examples
[11/20] ✓ beneficiary_not_allowed <-> declined_transfer: 40 examples
[12/20] ✓ pending_transfer <-> transfer_not_received_by_recipient: 40 examples
[13/20] ✓ card_arrival <-> lost_or_stolen_card: 40 examples
[14/20] 

## Add system prompt
python add_system_prompt.py /Users/ryanarman/code/lab/banking77/notebooks/data/hard_examples_all_pairs.jsonl /Users/ryanarman/code/lab/banking77/notebooks/data/banking77_train_improved_hard.jsonl

