# LLMClassifier Test

Test the `LLMClassifier` on quark/gluon jet data using the Google Gemini API.

This notebook demonstrates:
- Zero-shot jet classification with Gemini 2.5 Flash-Lite
- Thinking budget control (512-24,576 tokens for Flash-Lite)
- Token usage and cost tracking


In [9]:
%load_ext autoreload
%autoreload 2

import sys
import os
import numpy as np
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / 'src'))

from vibe_jet_tagging import LLMClassifier
from sklearn.metrics import roc_auc_score, accuracy_score


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load Data

Load the quark/gluon jet dataset.


In [10]:
# Load data
data_path = Path.cwd().parent / 'data' / 'qg_jets.npz'
data = np.load(data_path)

X = data['X']
y = data['y']

print(f"Loaded {len(X)} jets")
print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")
print(f"Quark jets: {(y == 1).sum()}")
print(f"Gluon jets: {(y == 0).sum()}")


Loaded 10000 jets
X shape: (10000, 139, 4)
y shape: (10000,)
Quark jets: 5074
Gluon jets: 4926


## Initialize LLMClassifier

Set up the classifier with the Google Gemini API.

**Note:** You need a Gemini API key. Get one from [Google AI Studio](https://aistudio.google.com/app/apikey).

Set it as an environment variable:
```bash
export GEMINI_API_KEY="your-key-here"
```

Or add it to your `.env` file in the project root:
```
GEMINI_API_KEY='your-key-here'
```


In [11]:
# Load API key from .env file
from dotenv import load_dotenv
load_dotenv()

# Check if API key is set
if 'GEMINI_API_KEY' not in os.environ:
    print("WARNING: GEMINI_API_KEY not set. Please set it in .env file.")
    print("Create a .env file with: GEMINI_API_KEY='your-key-here'")
else:
    print("✓ Gemini API key found")


✓ Gemini API key found


In [12]:
# Initialize classifier with Gemini API
clf = LLMClassifier(
    model_name="gemini-2.5-flash-lite-preview-09-2025",  # No "google/" prefix
    template_name="simple_list",
    format_type="list",
    templates_dir=str(Path.cwd().parent / 'templates'),
    thinking_budget=1000,         # Control thinking tokens (512-24,576 for Flash-Lite)
    max_tokens=2000               # Max output tokens (need enough for thinking + output)
)

# Fit (no-op for zero-shot)
clf.fit([], [])

print("Classifier initialized")
print(f"Model: {clf.model_name}")
print(f"Template: {clf.template_name}")
print(f"Format: {clf.format_type}")
print(f"Thinking budget: {clf.thinking_budget}")


Classifier initialized
Model: gemini-2.5-flash-lite-preview-09-2025
Template: simple_list
Format: list
Thinking budget: 1000


## Test Single Jet Prediction


In [13]:
# Test on a single jet
test_jet = X[0]
true_label = y[0]

print(f"True label: {true_label} ({'quark' if true_label == 1 else 'gluon'})")
print(f"\nJet shape: {test_jet.shape}")
print(f"Number of particles (pt > 0): {(test_jet[:, 0] > 0).sum()}")

# Make prediction
prediction = clf.predict([test_jet], verbose=True)[0]
print(f"\nPredicted label: {prediction} ({'quark' if prediction == 1 else 'gluon'})")
print(f"Correct: {prediction == true_label}")


True label: 1.0 (quark)

Jet shape: (139, 4)
Number of particles (pt > 0): 18

🔧 API PARAMETERS
Model: gemini-2.5-flash-lite-preview-09-2025
Max output tokens: 2000
Thinking budget: 1000


────────────────────────────────────────────────────────────
📊 TOKEN USAGE
────────────────────────────────────────────────────────────
Prompt tokens:     757
Completion tokens: 1
Thinking tokens:   998
├─ Thinking:       998
└─ Output:         1
Total tokens:      1,756

💰 COST
Input cost:        $0.000057
Output cost:       $0.000300
Call cost:         $0.000356

✨ RESPONSE
────────────────────────────────────────────────────────────
Content: 0
────────────────────────────────────────────────────────────


════════════════════════════════════════════════════════════
📈 CUMULATIVE STATISTICS
════════════════════════════════════════════════════════════
Total prompt tokens:     757
Total completion tokens: 1
Total thinking tokens:   998
Total tokens:            1,756

💰 Total estimated cost: $0.000356
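
The cost figures printed above are consistent with rates of $0.075 per million input tokens and $0.30 per million output tokens, with thinking tokens billed as output. These rates are inferred from the printed numbers, not read from a pricing page or from `LLMClassifier` itself, and `estimate_cost` is a hypothetical helper. Treat the sketch as a sanity check on the arithmetic rather than the library's cost logic.

```python
# Rates inferred from the printed output above (assumption, USD per 1M tokens).
INPUT_RATE_PER_M = 0.075
OUTPUT_RATE_PER_M = 0.30


def estimate_cost(prompt_tokens, completion_tokens, thinking_tokens=0):
    """Return (input_cost, output_cost, total_cost) in USD."""
    input_cost = prompt_tokens * INPUT_RATE_PER_M / 1_000_000
    # Assumption: thinking tokens are billed at the output rate.
    output_tokens = completion_tokens + thinking_tokens
    output_cost = output_tokens * OUTPUT_RATE_PER_M / 1_000_000
    return input_cost, output_cost, input_cost + output_cost


# Token counts from the single-jet call above.
input_cost, output_cost, total_cost = estimate_cost(757, 1, 998)
print(f"Input cost:  ${input_cost:.6f}")
print(f"Output cost: ${output_cost:.6f}")
print(f"Call cost:   ${total_cost:.6f}")
```

With these assumed rates the three printed values agree with the per-call figures shown in the output above.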


In [14]:
clf.preview_prompt(test_jet)

PROMPT PREVIEW
Model: gemini-2.5-flash-lite-preview-09-2025
Template: simple_list
Format: list
Max output tokens: 2000
Thinking budget: 1000

--------------------------------------------------------------------------------
PROMPT:
--------------------------------------------------------------------------------
You are a particle physics expert. Your task is to classify whether a jet is initiated by a quark (label: 1) or a gluon (label: 0).

A jet consists of particles, each with the following properties:
- pt: transverse momentum (GeV)
- y: rapidity
- phi: azimuthal angle (radians)
- pid: particle ID

Here is the jet data:
Particle 1: pt=0.269 GeV, y=0.357, phi=4.741, pid=22
Particle 2: pt=0.160 GeV, y=-0.256, phi=4.550, pid=22
Particle 3: pt=1.149 GeV, y=-0.062, phi=4.504, pid=-211
Particle 4: pt=4.132 GeV, y=0.174, phi=4.766, pid=-321
Particle 5: pt=1.696 GeV, y=-0.212, phi=4.797, pid=-211
Particle 6: pt=2.194 GeV, y=-0.052, phi=4.576, pid=22
Particle 7: pt=1.619 GeV, y=-0.068, phi=4
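
The particle lines in the preview follow a fixed pattern, one line per particle with columns `(pt, y, phi, pid)`. The sketch below shows one way such a list could be produced. It is not the actual `simple_list` template code: `format_jet_as_list` is a hypothetical helper, and skipping rows with `pt == 0` is an assumption based on the zero-padding count used earlier in this notebook.

```python
def format_jet_as_list(jet):
    """Format one jet as 'Particle i: ...' lines, skipping zero-padded rows.

    `jet` is any iterable of (pt, y, phi, pid) rows, e.g. one entry of X.
    """
    lines = []
    for pt, y, phi, pid in jet:
        if pt <= 0:  # assumption: padding rows have pt == 0
            continue
        lines.append(
            f"Particle {len(lines) + 1}: "
            f"pt={pt:.3f} GeV, y={y:.3f}, phi={phi:.3f}, pid={int(pid)}"
        )
    return "\n".join(lines)


# First two particles from the preview above, plus one padding row.
example_jet = [
    (0.269, 0.357, 4.741, 22.0),
    (0.160, -0.256, 4.550, 22.0),
    (0.0, 0.0, 0.0, 0.0),
]
formatted = format_jet_as_list(example_jet)
print(formatted)
```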

## Thinking Budget Control

Test how different thinking budgets affect performance and token usage.

For Gemini 2.5 Flash-Lite:
- **Minimum**: 512 tokens (or 0 to disable)
- **Maximum**: 24,576 tokens
- **Recommended**: 512-2000 for simple tasks, 2000-5000 for complex reasoning
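
The limits above can be checked before any API call is made. The helper below is a hypothetical sketch, not part of `LLMClassifier`. The optional `max_tokens` check assumes, following the comment in the classifier cell, that thinking tokens count against the output limit.

```python
# Limits quoted in this notebook for Gemini 2.5 Flash-Lite.
MIN_BUDGET = 512
MAX_BUDGET = 24_576


def check_thinking_budget(budget, max_tokens=None):
    """Return `budget` if it looks valid, otherwise raise ValueError."""
    if budget != 0 and not (MIN_BUDGET <= budget <= MAX_BUDGET):
        raise ValueError(
            f"thinking_budget must be 0 or in [{MIN_BUDGET}, {MAX_BUDGET}], got {budget}"
        )
    # Assumption: thinking tokens count against the output limit,
    # so the budget has to leave room for the answer itself.
    if max_tokens is not None and budget >= max_tokens:
        raise ValueError(
            f"thinking_budget={budget} leaves no room within max_tokens={max_tokens}"
        )
    return budget


print(check_thinking_budget(1000, max_tokens=2000))
```

Under that assumption, a budget of 5000 combined with `max_tokens=3000` would be rejected, which is worth keeping in mind when sweeping budgets.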


In [15]:
# Test different thinking budgets on a single jet
test_jet = X[0]

budgets = [0, 512, 2000, 5000]
results = []

for budget in budgets:
    print(f"\n{'='*60}")
    print(f"Testing with thinking_budget={budget}")
    print(f"{'='*60}")
    
    clf_test = LLMClassifier(
        model_name="gemini-2.5-flash-lite-preview-09-2025",
        template_name="simple_list",
        format_type="list",
        templates_dir=str(Path.cwd().parent / 'templates'),
        thinking_budget=budget,
        max_tokens=3000  # Note: below the 5000 budget tested here; thinking + output must fit in this limit
    )
    clf_test.fit([], [])
    
    pred = clf_test.predict([test_jet], verbose=True)[0]
    
    results.append({
        'budget': budget,
        'prediction': pred,
        'thinking_tokens': clf_test.total_thinking_tokens,
        'total_tokens': clf_test.total_prompt_tokens + clf_test.total_completion_tokens + clf_test.total_thinking_tokens,
        'cost': clf_test.total_cost
    })

# Summary
print(f"\n{'='*60}")
print("SUMMARY")
print(f"{'='*60}")
for r in results:
    print(f"Budget={r['budget']:5d}: Thinking={r['thinking_tokens']:4d}, Total={r['total_tokens']:5d}, Cost=${r['cost']:.6f}, Pred={r['prediction']}")



Testing with thinking_budget=0

🔧 API PARAMETERS
Model: gemini-2.5-flash-lite-preview-09-2025
Max output tokens: 3000
Thinking budget: 0


────────────────────────────────────────────────────────────
📊 TOKEN USAGE
────────────────────────────────────────────────────────────
Prompt tokens:     757
Completion tokens: 1
Total tokens:      758

💰 COST
Input cost:        $0.000057
Output cost:       $0.000000
Call cost:         $0.000057

✨ RESPONSE
────────────────────────────────────────────────────────────
Content: 0
────────────────────────────────────────────────────────────


════════════════════════════════════════════════════════════
📈 CUMULATIVE STATISTICS
════════════════════════════════════════════════════════════
Total prompt tokens:     757
Total completion tokens: 1
Total tokens:            758

💰 Total estimated cost: $0.000057
════════════════════════════════════════════════════════════


Testing with thinking_budget=512

🔧 API PARAMETERS
Model: gemini-2.5-flash-lite-previe

## Test on a Subset of Jets

Run the classifier on the first `n_test` jets (10 in the run shown here) and compute metrics.


In [53]:
# Select the first n_test jets (10 in this run)
n_test = 10
X_test = X[:n_test]
y_test = y[:n_test]

print(f"Testing on {n_test} jets...")
print(f"True distribution: {(y_test == 1).sum()} quark, {(y_test == 0).sum()} gluon")


Testing on 10 jets...
True distribution: 9 quark, 1 gluon
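
Taking the first jets in file order gives a lopsided sample here (9 quark, 1 gluon), which makes accuracy hard to interpret. One option is to draw the same number of jets from each class. The sketch below is a hypothetical helper using only the standard library; `balanced_indices` and its parameters are not part of the notebook's code.

```python
import random


def balanced_indices(labels, n_per_class, seed=0):
    """Pick n_per_class indices for each of the labels 1 (quark) and 0 (gluon)."""
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    quark = [i for i, label in enumerate(labels) if label == 1]
    gluon = [i for i, label in enumerate(labels) if label == 0]
    chosen = rng.sample(quark, n_per_class) + rng.sample(gluon, n_per_class)
    rng.shuffle(chosen)  # avoid all quarks first, then all gluons
    return chosen


toy_labels = [1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0]
idx = balanced_indices(toy_labels, n_per_class=3)
print(idx, [toy_labels[i] for i in idx])
```

With the arrays loaded earlier, `idx = balanced_indices(y, n_per_class=5)` followed by `X[idx]` and `y[idx]` would give a 5/5 split of the same total size.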


In [54]:
# Make predictions (this will take a while)
from tqdm.auto import tqdm

predictions = []
for i, jet in enumerate(tqdm(X_test)):
    pred = clf.predict([jet])[0]
    predictions.append(pred)
    
    # Print progress every 10 jets
    if (i + 1) % 10 == 0:
        acc = accuracy_score(y_test[:i+1], predictions)
        print(f"After {i+1} jets: Accuracy = {acc:.3f}")

predictions = np.array(predictions)


  0%|          | 0/10 [00:00<?, ?it/s]

 50%|█████     | 5/10 [01:46<02:01, 24.36s/it]

Finish reason: length
Message object: ChatCompletionMessage(content='', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_details=[{'type': 'reasoning.encrypted', 'data': 'gAAAAABo5l1y...' [encrypted reasoning payload truncated]

100%|██████████| 10/10 [02:59<00:00, 17.93s/it]

After 10 jets: Accuracy = 0.800
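
The `Finish reason: length` message above shows a reply that hit the output limit and came back with empty content. How `LLMClassifier` turns such a reply into a label is not shown in this notebook. As a hedged illustration, the hypothetical `parse_label` helper below returns `None` for anything that is not a clean `0` or `1`, so failed calls can be counted separately instead of being scored as a class.

```python
def parse_label(content):
    """Return 0 or 1 if the reply is exactly one of those labels, else None."""
    text = (content or "").strip()
    if text in ("0", "1"):
        return int(text)
    # Empty or unexpected reply, e.g. one truncated by the output limit.
    return None


replies = ["0", " 1\n", "", None, "gluon"]
parsed = [parse_label(r) for r in replies]
n_failed = sum(p is None for p in parsed)
print(parsed, f"failed: {n_failed}")
```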





## Evaluate Performance


In [55]:
# Calculate metrics
accuracy = accuracy_score(y_test, predictions)
# Note: predictions are hard 0/1 labels, not scores, so this AUC equals (TPR + TNR) / 2
auc = roc_auc_score(y_test, predictions)

print("\n" + "="*50)
print("RESULTS")
print("="*50)
print(f"Accuracy: {accuracy:.3f}")
print(f"AUC Score: {auc:.3f}")
print(f"\nPredicted distribution: {(predictions == 1).sum()} quark, {(predictions == 0).sum()} gluon")
print(f"True distribution: {(y_test == 1).sum()} quark, {(y_test == 0).sum()} gluon")



RESULTS
Accuracy: 0.800
AUC Score: 0.889

Predicted distribution: 7 quark, 3 gluon
True distribution: 9 quark, 1 gluon


In [56]:
# Confusion matrix
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, predictions)
print("\nConfusion Matrix:")
print("                Predicted")
print("                Gluon  Quark")
print(f"True  Gluon     {cm[0,0]:5d}  {cm[0,1]:5d}")
print(f"      Quark     {cm[1,0]:5d}  {cm[1,1]:5d}")



Confusion Matrix:
                Predicted
                Gluon  Quark
True  Gluon         1      0
      Quark         2      7
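
The reported metrics can be recovered directly from the confusion matrix above. Because the predictions are hard 0/1 labels rather than scores, the ROC curve has a single operating point and its area reduces to the mean of the true positive and true negative rates. The sketch below redoes that arithmetic in plain Python with quark (label 1) as the positive class.

```python
# Counts read from the confusion matrix above (quark = positive class).
tn, fp = 1, 0  # true gluons: predicted gluon, predicted quark
fn, tp = 2, 7  # true quarks: predicted gluon, predicted quark

accuracy = (tp + tn) / (tn + fp + fn + tp)
tpr = tp / (tp + fn)  # quark efficiency
tnr = tn / (tn + fp)  # gluon rejection
auc_hard_labels = (tpr + tnr) / 2

print(f"Accuracy: {accuracy:.3f}")
print(f"AUC from hard labels: {auc_hard_labels:.3f}")
```

Both values agree with the `RESULTS` output (0.800 and 0.889). With a single gluon in the sample, the gluon rejection term rests on one jet, so the AUC figure carries little statistical weight.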
