# Custom "Next Tool Predictor" with Confidence Score

### 💡 Import Data and Libraries

In [1]:
# Import all libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
from sentence_transformers import SentenceTransformer
from collections import defaultdict, Counter
import torch.nn.functional as F

# Import data
from data import tool_vocab, tool_descriptions, tool_patterns, common_workflows, test_histories

  from tqdm.autonotebook import tqdm, trange
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



### 💡Set Configuration
 - Set device to "cuda" if cuda is available or default to "cpu"
 - Set pad_token_id to 0. Padding token is used as indice in place of padding 

In [None]:
# Configuration Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
print(f"Using device: {device}")

# Set padding token id
pad_token_id = 0

# Mappings
tool_to_id = {tool: idx + 1 for idx, tool in enumerate(tool_vocab)}
id_to_tool = {idx + 1: tool for idx, tool in enumerate(tool_vocab)}

# Adding +1 to tool_vocab to account for padding token
vocab_size = len(tool_vocab) + 1

# Setting context length to 6 for this demo
context_len = 6 

Using device: cuda


### 💡Generate Embeddings
- Using pretrained embedding model via sentence transformer for description embedding generation 

In [None]:
# Generate Description Embeddings
'''
Using sentence transformers
'''
# Using pre trained embedding model for embedding generation
desc_model = SentenceTransformer('all-MiniLM-L6-v2')
desc_texts = [tool_descriptions[tool] for tool in tool_vocab]
desc_embeddings = desc_model.encode(desc_texts, normalize_embeddings=True)
desc_id_to_embedding = {
    tool_to_id[tool]: torch.tensor(desc_embeddings[i], dtype=torch.float32).to(device)
    for i, tool in enumerate(tool_vocab)
}

# Print size of generated embeddings 
print(f"Shape of embeddings: {desc_embeddings.shape}") 
desc_dim = desc_embeddings.shape[1]

# Add padding token embedding
desc_id_to_embedding[pad_token_id] = torch.zeros(desc_dim, dtype=torch.float32).to(device)

Shape of embeddings: (55, 384)


  attn_output = torch.nn.functional.scaled_dot_product_attention(


### 💡Data Preparation

In [None]:
# Data Preparation
class ToolDataset(Dataset):
    def __init__(self, sequences, context_len, augment_data=True):
        self.samples = []
        self.pattern_freq = defaultdict(int)
        self.context_len = context_len
        
        # First pass: collect pattern frequencies
        for seq in sequences:
            token_ids = [tool_to_id[t] for t in seq]
            for i in range(1, len(token_ids)):
                context = token_ids[max(0, i - context_len):i]
                label = token_ids[i]
                
                # Using last 3 tokens as patterns
                pattern = tuple(context[-min(3, len(context)):])
                self.pattern_freq[pattern] += 1
        
        # Second pass: create samples with weights
        for seq in sequences:
            token_ids = [tool_to_id[t] for t in seq]
            for i in range(1, len(token_ids)):
                context = token_ids[max(0, i - context_len):i]
                label = token_ids[i]
                context = [pad_token_id] * (context_len - len(context)) + context
                
                # Calculate sample weight based on pattern frequency
                pattern = tuple(context[-min(3, len(context)):])
                # Reduce weight for common patterns
                weight = 1.0 / (1 + self.pattern_freq[pattern] * 0.1)  
                
                self.samples.append((context, label, weight))
    
    def __len__(self):
        # Returns length of samples 
        return len(self.samples)
    
    def __getitem__(self, idx):
        # Returns data present at an index
        context, label, weight = self.samples[idx]
        return (torch.tensor(context, dtype=torch.long), 
                torch.tensor(label, dtype=torch.long),
                torch.tensor(weight, dtype=torch.float32))


### 💡Custom Model Architecture

📌 Key Components & Design Highlights
**Input Embedding Layers:**

**1. Token Embedding**: Learns dense representations for input tokens.

**2. Positional Embedding**: Encodes positional information to preserve token order in sequences.

**3. Description Embedding**: Projects external, pre-computed description vectors (desc_emb_table) into the model's embedding space using a linear layer (desc_proj).

**Embedding Fusion & Normalization:**

- Combines token, positional, and description embeddings for each token.

- Apply LayerNorm and dropout for stabilization and regularization.

**Transformer Decoder Layers:**

- A stack of enhanced nn.TransformerDecoderLayer modules with residual connections.

- Each layer processes the input using masked self-attention to prevent future token leakage.

**Contextual Attention Mechanism:**

- A separate MultiheadAttention layer re-weights token-level hidden states to capture contextual importance.

- Output is added back to the hidden representation via a residual connection.

**Multiple Output Heads:**

- Logits Head: Generates token predictions using a linear projection.

- Confidence Head: Outputs a scalar confidence score (0 to 1) using a sigmoid activation.

**Weight Initialization:**

- Linear layers use Xavier uniform initialization for stability.

- Embeddings are initialized with a normal distribution.

In [None]:
# Custom Model Architecture 
class ToolPredictor(nn.Module):
    def __init__(self, vocab_size, embed_dim, desc_dim, n_heads, num_layers, context_len, desc_emb_table, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.context_len = context_len
        self.desc_dim = desc_dim
        
        # Enhanced embeddings
        self.token_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        self.pos_embedding = nn.Embedding(context_len, embed_dim)
        self.desc_proj = nn.Linear(desc_dim, embed_dim)
        self.desc_emb_table = desc_emb_table
        
        # Layer normalization for embeddings
        self.embed_norm = nn.LayerNorm(embed_dim)
        
        # Enhanced transformer with residual connections
        self.transformer_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=embed_dim, 
                nhead=n_heads, 
                dropout=dropout,
                batch_first=True
            )
            for _ in range(num_layers)
        ])
        
        # Attention mechanism for context weighting
        self.context_attention = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
        
        # Multiple prediction heads
        self.output_layer = nn.Linear(embed_dim, vocab_size)

        # Confidence scoring
        self.confidence_head = nn.Linear(embed_dim, 1)  
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        # self.modules() is a built-in PyTorch method which finds every submodule in model
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # If a module is a Linear layer, its weights are initialized using Xavier Uniform
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

            elif isinstance(module, nn.Embedding):
                # If a module is an Embedding layer, its weights are initialized from a normal distribution
                nn.init.normal_(module.weight, std=0.02)
    
    def forward(self, x):
        batch_size, seq_len = x.size()
        
        # Token embeddings
        # Output size: (batch_size, seq_len, embed_dim)
        tok_emb = self.token_embedding(x)
        
        # Positional embeddings
        # Output size: (batch_size, seq_len, embed_dim)
        pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_embedding(pos_ids)
        
        # Description embeddings
        # Output size: (batch_size, seq_len, embed_dim)
        desc_emb = torch.stack([
            torch.stack([
                self.desc_proj(self.desc_emb_table.get(tok.item(), torch.zeros(self.desc_dim).to(x.device)))
                for tok in row
            ])
            for row in x
        ])
        
        # Combine embeddings
        # Output size: (batch_size, seq_len, embed_dim)
        x_emb = tok_emb + pos_emb + desc_emb
        x_emb = self.embed_norm(x_emb)
        x_emb = self.dropout(x_emb)
        
        # Create causal mask
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        
        # Pass through transformer layers
        # Output size: (batch_size, seq_len, embed_dim)
        hidden = x_emb
        for layer in self.transformer_layers:
            hidden = layer(hidden, hidden, tgt_mask=tgt_mask)
        
        # Context attention for final representation
        # Output size: (batch_size, seq_len, embed_dim)
        attn_out, _ = self.context_attention(hidden, hidden, hidden)

        # Residual connection (Skip Connection)
        # Output size: (batch_size, seq_len, embed_dim)
        final_hidden = hidden + attn_out  
        
        # Use last token for prediction 
        # Output Size: (batch_size, embed_dim)
        last_hidden = final_hidden[:, -1, :]
        
        # Output Raw Logits
        # Output size: (batch_size, vocab_size)
        logits = self.output_layer(last_hidden)

        # Prediction Confidence
        # Output size: (batch_size, 1)
        confidence = torch.sigmoid(self.confidence_head(last_hidden)) 
        
        return logits, confidence

In [None]:
# Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

### 💡Training Loop

📌 **Key Components & Techniques**
1. **Optimizer and Learning Rate Scheduling**
- Optimizer: Uses AdamW, a variation of Adam with weight decay, which improves generalization
- Scheduler: ReduceLROnPlateau reduces the learning rate when validation performance (loss) stagnates, helping fine-tune learning in later epochs

2. **Loss Functions**
- Focal Loss: Applied to main token prediction to focus learning on hard examples and handle class imbalance more effectively
- Binary Cross Entropy (BCE) Loss: Used for confidence prediction, encouraging the model to assign high confidence to correct predictions and lower confidence otherwise.

3. **Training Loop**
- The model is set to training mode (model.train()).
For each batch:
Inputs, labels, and sample weights are moved to the target device.
Forward pass yields both class logits and prediction confidence.
Two losses are computed:

- pred_loss: Main task loss from FocalLoss
- conf_loss: Confidence score loss using BCE against correctness of predictions

The total loss is a combination:
- weighted_loss = (pred_loss + 0.1 * conf_loss) * sample_weights

Backpropagation and optimization are performed with:

- Gradient clipping (max norm = 1.0) to prevent exploding gradients
- Zeroing gradients and calling .backward() and .step()


4. **Model Checkpointing**
- The model is saved whenever a new lowest average loss is achieved.
Saved checkpoint includes:
- Epoch number
- Model state
- Optimizer state
- Best loss

5. **Logging & Monitoring**
- Logs training loss and learning rate every 5 epochs.

Scheduler dynamically adjusts the learning rate based on average loss

In [None]:
# Setup training loop
def train_model(model, dataloader, epochs=100, lr=1e-3, save_path="best_model.pth"):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
    
    # Use focal loss for better handling of class imbalance
    loss_fn = FocalLoss(alpha=1, gamma=2)
    confidence_loss_fn = nn.BCELoss()
    
    model.train()
    best_loss = float('inf')
    
    for epoch in range(epochs):
        total_loss = 0
        total_samples = 0
        
        for batch_data in dataloader:
            context, label, weights = batch_data
            context, label, weights = context.to(device), label.to(device), weights.to(device)
            
            logits, confidence = model(context)
            
            # Main prediction loss
            pred_loss = loss_fn(logits, label)
            
            # Confidence loss (high confidence for correct predictions)
            pred_correct = (torch.argmax(logits, dim=-1) == label).float()
            conf_loss = confidence_loss_fn(confidence.squeeze(), pred_correct)
            
            # Weighted total loss
            total_loss_batch = pred_loss + 0.1 * conf_loss
            weighted_loss = (total_loss_batch * weights).mean()
            
            optimizer.zero_grad()
            weighted_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            
            total_loss += weighted_loss.item()
            total_samples += len(context)
        
        avg_loss = total_loss / len(dataloader)
        scheduler.step(avg_loss)
        
        if avg_loss < best_loss:
            best_loss = avg_loss
            print(f"Epoch {epoch+1}/{epochs}: New best model found with loss: {avg_loss:.4f}. Saving checkpoint...")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, save_path)
        
        if epoch % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    return model

In [None]:
# ==== Enhanced Inference with Confidence ====
def predict_next_with_confidence(model, history, top_k=3):
    model.eval()
    with torch.no_grad():
        context = [tool_to_id[t] for t in history][-context_len:]
        context = [pad_token_id] * (context_len - len(context)) + context
        context = torch.tensor([context], dtype=torch.long).to(device)
        
        logits, confidence = model(context)
        probabilities = F.softmax(logits, dim=-1)
        
        # Get top-k predictions
        top_probs, top_indices = torch.topk(probabilities, top_k, dim=-1)
        
        predictions = []
        for i in range(top_k):
            tool_id = top_indices[0, i].item()
            prob = top_probs[0, i].item()
            conf = confidence[0, 0].item()
            tool_name = id_to_tool.get(tool_id, "<UNK>")
            predictions.append((tool_name, prob, conf))
        
        return predictions

In [None]:
# ==== Enhanced Data Generation ====
def generate_realistic_sequences(tool_vocab, tool_patterns, num_sequences=200, min_len=3, max_len=8):
    """Generate more realistic tool usage sequences"""
    sequences = []
    
    # Pattern-based generation
    for _ in range(num_sequences // 2):
        seq = []
        current_tool = random.choice(tool_vocab)
        seq.append(current_tool)
        
        length = random.randint(min_len, max_len)
        for _ in range(length - 1):
            if current_tool in tool_patterns:
                # 70% chance to follow pattern, 30% random
                if random.random() < 0.7:
                    next_tool = random.choice(tool_patterns[current_tool])
                else:
                    next_tool = random.choice(tool_vocab)
            else:
                next_tool = random.choice(tool_vocab)
            seq.append(next_tool)
            current_tool = next_tool
        
        sequences.append(seq)
    
    # Pure random generation
    for _ in range(num_sequences // 2):
        seq = random.choices(tool_vocab, k=random.randint(min_len, max_len))
        sequences.append(seq)
    
    return sequences

In [None]:
# ==== Enhanced Sample Data ====
user_sequences = common_workflows * 3  # Repeat common patterns
user_sequences += generate_realistic_sequences(tool_vocab, tool_patterns, 300)

print(f"Generated {len(user_sequences)} training sequences")

# ==== Run Enhanced Training ====
dataset = ToolDataset(user_sequences, context_len, augment_data=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

print(f"Training samples: {len(dataset)}")

model = ToolPredictor(
    vocab_size=vocab_size,
    embed_dim=128,  # Increased embedding dimension
    desc_dim=desc_dim,
    n_heads=8,  # More attention heads
    num_layers=4,  # More layers
    context_len=context_len,
    desc_emb_table=desc_id_to_embedding,
    dropout=0.1
).to(device)

In [None]:
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

# Train the model
model = train_model(model, dataloader, epochs=100, lr=2e-4)

In [None]:
# ==== Enhanced Evaluation ====
def evaluate_model(model, test_sequences):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for seq in test_sequences:
            if len(seq) < 2:
                continue
            
            for i in range(1, len(seq)):
                history = seq[:i]
                true_next = seq[i]
                
                predictions = predict_next_with_confidence(model, history, top_k=1)
                predicted_tool = predictions[0][0]
                
                if predicted_tool == true_next:
                    correct += 1
                total += 1
    
    accuracy = correct / total if total > 0 else 0
    return accuracy

In [None]:
# Generate test sequences
test_sequences = generate_realistic_sequences(tool_vocab, tool_patterns, 50, min_len=3, max_len=6)
test_accuracy = evaluate_model(model, test_sequences)
print(f"Test Accuracy: {test_accuracy:.3f}")

print("\n=== Prediction Examples ===")
for history in test_histories:
    predictions = predict_next_with_confidence(model, history, top_k=3)
    print(f"History: {history}")
    for i, (tool, prob, conf) in enumerate(predictions):
        print(f"  {i+1}. {tool} (prob: {prob:.3f}, conf: {conf:.3f})")
    print()


  from tqdm.autonotebook import tqdm, trange
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



Using device: cuda


  attn_output = torch.nn.functional.scaled_dot_product_attention(


Generated 390 training sequences
Training samples: 1947
Model parameters: 2,768,185
Epoch 1/100: New best model found with loss: 3.5436. Saving checkpoint...
Epoch 1/100, Loss: 3.5436, LR: 0.000200
Epoch 2/100: New best model found with loss: 3.2496. Saving checkpoint...
Epoch 3/100: New best model found with loss: 2.9594. Saving checkpoint...
Epoch 4/100: New best model found with loss: 2.6776. Saving checkpoint...
Epoch 5/100: New best model found with loss: 2.5002. Saving checkpoint...
Epoch 6/100: New best model found with loss: 2.3041. Saving checkpoint...
Epoch 6/100, Loss: 2.3041, LR: 0.000200
Epoch 7/100: New best model found with loss: 2.1845. Saving checkpoint...
Epoch 8/100: New best model found with loss: 2.0165. Saving checkpoint...
Epoch 9/100: New best model found with loss: 1.9151. Saving checkpoint...
Epoch 10/100: New best model found with loss: 1.7832. Saving checkpoint...
Epoch 11/100: New best model found with loss: 1.6676. Saving checkpoint...
Epoch 11/100, Loss: 