In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
from sentence_transformers import SentenceTransformer
from collections import defaultdict, Counter
import torch.nn.functional as F

# ==== Configuration ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tool_vocab = [
    # Basic Text Editing
    "cut", "copy", "paste", "select", "delete", "undo", "redo",
    
    # Text Formatting
    "bold", "italic", "underline", "strike_through", "font_size", "font_family", "text_color", "highlight",
    
    # Paragraph Formatting & Lists
    "align_left", "align_center", "align_right", "align_justify",
    "indent", "outdent", "bullet_list", "numbered_list", "line_spacing",
    
    # Document Management & Files
    "save", "save_as", "open", "new_document", "print", "export_pdf",
    
    # Find & Replace
    "find", "replace", "find_and_replace_all",
    
    # Insertions & Objects
    "insert_image", "insert_table", "insert_link", "insert_shape", "insert_chart", "insert_header", "insert_footer",
    
    # View & Navigation
    "zoom_in", "zoom_out", "page_layout", "read_mode", "web_layout",
    
    # Collaboration & Review
    "add_comment", "track_changes", "accept_change", "reject_change", "share_document",
    
    # Other Utilities
    "format_painter", "spell_check", "grammar_check", "word_count", "toggle_ruler"
]

tool_descriptions = {
    # Basic Text Editing
    "cut": "removes selected content and places it in the clipboard, typically followed by paste",
    "copy": "duplicates selected content to the clipboard without removing it, often followed by paste",
    "paste": "inserts content from the clipboard at the current cursor position or replaces selected content",
    "select": "highlights content (text, image, etc.) for further operation like cut, copy, delete, or formatting",
    "delete": "removes content permanently; can often be undone",
    "undo": "reverses the last action performed, crucial for correcting mistakes",
    "redo": "re-applies a previously undone action, moving forward in the action history",

    # Text Formatting
    "bold": "applies bold formatting to selected text, making it stand out",
    "italic": "applies italic formatting to selected text, often for emphasis or titles",
    "underline": "applies an underline to selected text, commonly used for links or emphasis",
    "strike_through": "draws a line through the middle of the selected text",
    "font_size": "changes the size of the selected text",
    "font_family": "changes the typeface or style of the selected text (e.g., Arial, Times New Roman)",
    "text_color": "changes the color of the selected text",
    "highlight": "applies a colored background to the selected text, like using a highlighter pen",

    # Paragraph Formatting & Lists
    "align_left": "aligns selected text or objects to the left margin",
    "align_center": "centers selected text or objects horizontally on the page",
    "align_right": "aligns selected text or objects to the right margin",
    "align_justify": "aligns text to both the left and right margins, adding space between words as needed",
    "indent": "increases the indentation of selected paragraphs or list items",
    "outdent": "decreases the indentation of selected paragraphs or list items",
    "bullet_list": "converts selected text into an unordered list with bullet points",
    "numbered_list": "converts selected text into an ordered list with numbers",
    "line_spacing": "adjusts the amount of vertical space between lines of text in a paragraph",

    # Document Management & Files
    "save": "stores the current state of the document to a file on disk",
    "save_as": "saves the current document with a different name or location",
    "open": "opens an existing document from a file",
    "new_document": "creates a blank new document, usually starting a fresh project",
    "print": "sends the current document to a printer for a hard copy",
    "export_pdf": "saves the document in PDF format, commonly used for sharing final versions",

    # Find & Replace
    "find": "opens a dialog to search for specific text within the document",
    "replace": "opens a dialog to find text and replace it with new text, often used after 'find'",
    "find_and_replace_all": "finds all occurrences of text and replaces them automatically",

    # Insertions & Objects
    "insert_image": "adds an image from a file into the document",
    "insert_table": "adds a structured grid of rows and columns to the document",
    "insert_link": "creates a hyperlink to a web page or location within the document",
    "insert_shape": "adds a geometric shape like a square or circle",
    "insert_chart": "adds a data visualization from a spreadsheet or other source",
    "insert_header": "adds content to the top margin of a document, appearing on every page",
    "insert_footer": "adds content to the bottom margin of a document, appearing on every page",

    # View & Navigation
    "zoom_in": "magnifies the view of the document, making content appear larger",
    "zoom_out": "reduces the magnification of the document, showing more content at once",
    "page_layout": "changes the view to show how the document will look when printed",
    "read_mode": "optimizes the view for reading, hiding toolbars and menus",
    "web_layout": "shows the document as a web page, without page breaks",

    # Collaboration & Review
    "add_comment": "inserts a comment bubble attached to a specific piece of text, often for collaboration",
    "track_changes": "activates a mode where all edits are marked for review",
    "accept_change": "applies a tracked change permanently to the document",
    "reject_change": "discards a tracked change, restoring the original text",
    "share_document": "opens a dialog to share the document with other users for collaboration",

    # Other Utilities
    "format_painter": "copies formatting from one piece of content and applies it to another",
    "spell_check": "initiates a check for spelling errors in the document",
    "grammar_check": "initiates a check for grammatical errors in the document",
    "word_count": "displays a count of words, characters, and pages in the document",
    "toggle_ruler": "shows or hides the horizontal and vertical rulers"
}

tool_patterns = {
    # Basic Editing
    "cut": ["paste", "undo", "copy", "select", "save"],
    "copy": ["paste", "undo", "select", "cut"],
    "paste": ["undo", "select", "cut", "bold", "italic", "format_painter"],
    "select": ["cut", "copy", "delete", "bold", "italic", "underline", "align_left", "format_painter"],
    "delete": ["undo", "select", "cut", "save"],
    "undo": ["redo", "cut", "copy", "paste", "delete", "save"],
    "redo": ["undo", "paste", "select", "bold", "italic"],
    "bold": ["italic", "underline", "select", "align_left"],
    "italic": ["bold", "underline", "select", "align_center"],
    "underline": ["bold", "italic", "select", "align_right"],

    # File Management
    "save": ["open", "new_document", "print", "export_pdf"],
    "save_as": ["save", "open", "print", "share_document"],
    "open": ["new_document", "save"],
    "new_document": ["save", "open", "insert_table", "insert_image"],
    "print": ["save", "new_document", "page_layout"],
    "export_pdf": ["save", "share_document"],

    # Find & Replace
    "find": ["replace", "copy", "delete", "find_and_replace_all"],
    "replace": ["find", "undo", "save"],
    "find_and_replace_all": ["undo", "save", "find"],

    # Formatting and Lists
    "align_left": ["align_center", "align_right", "align_justify", "select", "indent"],
    "align_center": ["align_left", "align_right", "align_justify", "select"],
    "align_right": ["align_left", "align_center", "align_justify", "select"],
    "align_justify": ["align_left", "align_center", "align_right", "select"],
    "indent": ["outdent", "bullet_list", "numbered_list"],
    "outdent": ["indent", "bullet_list", "numbered_list"],
    "bullet_list": ["numbered_list", "indent", "outdent", "select"],
    "numbered_list": ["bullet_list", "indent", "outdent", "select"],
    "format_painter": ["select", "paste", "bold", "italic"],
    "text_color": ["highlight", "select", "bold"],
    "highlight": ["text_color", "select"],

    # Insertions
    "insert_image": ["select", "cut", "copy", "delete"],
    "insert_table": ["select", "insert_link", "insert_chart"],
    "insert_link": ["select", "copy"],
    "insert_shape": ["select", "cut", "copy", "delete"],
    "insert_chart": ["insert_table", "select", "cut", "copy"],
    "insert_header": ["insert_footer", "page_layout", "insert_link"],
    "insert_footer": ["insert_header", "page_layout"],

    # View & Navigation
    "zoom_in": ["zoom_out", "read_mode", "page_layout"],
    "zoom_out": ["zoom_in", "page_layout", "web_layout"],
    "read_mode": ["page_layout", "web_layout", "zoom_in"],

    # Collaboration
    "add_comment": ["share_document", "track_changes", "select"],
    "track_changes": ["accept_change", "reject_change", "add_comment", "share_document"],
    "accept_change": ["reject_change", "track_changes", "save"],
    "reject_change": ["accept_change", "track_changes"],
    "share_document": ["save", "add_comment", "export_pdf"],
    
    # Utilities
    "spell_check": ["grammar_check", "undo", "save"],
    "grammar_check": ["spell_check", "undo", "save"],
    "word_count": ["save", "print"],
}

pad_token_id = 0
tool_to_id = {tool: idx + 1 for idx, tool in enumerate(tool_vocab)}
id_to_tool = {idx + 1: tool for idx, tool in enumerate(tool_vocab)}
vocab_size = len(tool_vocab) + 1
context_len = 6  # Increased context length

# ==== Enhanced Description Embeddings ====
desc_model = SentenceTransformer('all-MiniLM-L6-v2')
desc_texts = [tool_descriptions[tool] for tool in tool_vocab]
desc_embeddings = desc_model.encode(desc_texts, normalize_embeddings=True)
desc_id_to_embedding = {
    tool_to_id[tool]: torch.tensor(desc_embeddings[i], dtype=torch.float32).to(device)
    for i, tool in enumerate(tool_vocab)
}
desc_dim = desc_embeddings.shape[1]

# Add padding token embedding
desc_id_to_embedding[pad_token_id] = torch.zeros(desc_dim, dtype=torch.float32).to(device)

# ==== Enhanced Dataset with Pattern Mining ====
class EnhancedToolDataset(Dataset):
    def __init__(self, sequences, context_len, augment_data=True):
        self.samples = []
        self.pattern_freq = defaultdict(int)
        self.context_len = context_len
        
        # First pass: collect pattern frequencies
        for seq in sequences:
            token_ids = [tool_to_id[t] for t in seq]
            for i in range(1, len(token_ids)):
                context = token_ids[max(0, i - context_len):i]
                label = token_ids[i]
                pattern = tuple(context[-min(3, len(context)):])  # Use last 3 tokens as pattern
                self.pattern_freq[pattern] += 1
        
        # Second pass: create samples with weights
        for seq in sequences:
            token_ids = [tool_to_id[t] for t in seq]
            for i in range(1, len(token_ids)):
                context = token_ids[max(0, i - context_len):i]
                label = token_ids[i]
                context = [pad_token_id] * (context_len - len(context)) + context
                
                # Calculate sample weight based on pattern frequency
                pattern = tuple(context[-min(3, len(context)):])
                weight = 1.0 / (1 + self.pattern_freq[pattern] * 0.1)  # Reduce weight for common patterns
                
                self.samples.append((context, label, weight))
        
        # Data augmentation
        if augment_data:
            self._augment_data()
    
    def _augment_data(self):
        """Add synthetic samples based on known patterns"""
        augmented_samples = []
        for tool, likely_next in tool_patterns.items():
            base_context = [pad_token_id] * (self.context_len - 1) + [tool_to_id[tool]]
            for next_tool in likely_next:
                augmented_samples.append((base_context, tool_to_id[next_tool], 0.5))
        
        self.samples.extend(augmented_samples)
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        context, label, weight = self.samples[idx]
        return (torch.tensor(context, dtype=torch.long), 
                torch.tensor(label, dtype=torch.long),
                torch.tensor(weight, dtype=torch.float32))

# ==== Enhanced Model Architecture ====
class EnhancedToolPredictor(nn.Module):
    def __init__(self, vocab_size, embed_dim, desc_dim, n_heads, num_layers, context_len, desc_emb_table, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.context_len = context_len
        self.desc_dim = desc_dim
        
        # Enhanced embeddings
        self.token_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        self.pos_embedding = nn.Embedding(context_len, embed_dim)
        self.desc_proj = nn.Linear(desc_dim, embed_dim)
        self.desc_emb_table = desc_emb_table
        
        # Layer normalization for embeddings
        self.embed_norm = nn.LayerNorm(embed_dim)
        
        # Enhanced transformer with residual connections
        self.transformer_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=embed_dim, 
                nhead=n_heads, 
                dropout=dropout,
                batch_first=True
            )
            for _ in range(num_layers)
        ])
        
        # Attention mechanism for context weighting
        self.context_attention = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
        
        # Multiple prediction heads
        self.output_layer = nn.Linear(embed_dim, vocab_size)
        self.confidence_head = nn.Linear(embed_dim, 1)  # Confidence scoring
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=0.02)
    
    def forward(self, x):
        batch_size, seq_len = x.size()
        
        # Token embeddings
        tok_emb = self.token_embedding(x)
        
        # Positional embeddings
        pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_embedding(pos_ids)
        
        # Description embeddings
        desc_emb = torch.stack([
            torch.stack([
                self.desc_proj(self.desc_emb_table.get(tok.item(), torch.zeros(self.desc_dim).to(x.device)))
                for tok in row
            ])
            for row in x
        ])
        
        # Combine embeddings
        x_emb = tok_emb + pos_emb + desc_emb
        x_emb = self.embed_norm(x_emb)
        x_emb = self.dropout(x_emb)
        
        # Create causal mask
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        
        # Pass through transformer layers
        hidden = x_emb
        for layer in self.transformer_layers:
            hidden = layer(hidden, hidden, tgt_mask=tgt_mask)
        
        # Context attention for final representation
        attn_out, _ = self.context_attention(hidden, hidden, hidden)
        final_hidden = hidden + attn_out  # Residual connection
        
        # Use last token for prediction
        last_hidden = final_hidden[:, -1, :]
        
        # Output predictions
        logits = self.output_layer(last_hidden)
        confidence = torch.sigmoid(self.confidence_head(last_hidden))
        
        return logits, confidence

# ==== Enhanced Training with Focal Loss ====
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

def train_model(model, dataloader, epochs=100, lr=1e-3, save_path="best_model.pth"):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
    
    # Use focal loss for better handling of class imbalance
    loss_fn = FocalLoss(alpha=1, gamma=2)
    confidence_loss_fn = nn.BCELoss()
    
    model.train()
    best_loss = float('inf')
    
    for epoch in range(epochs):
        total_loss = 0
        total_samples = 0
        
        for batch_data in dataloader:
            context, label, weights = batch_data
            context, label, weights = context.to(device), label.to(device), weights.to(device)
            
            logits, confidence = model(context)
            
            # Main prediction loss
            pred_loss = loss_fn(logits, label)
            
            # Confidence loss (high confidence for correct predictions)
            pred_correct = (torch.argmax(logits, dim=-1) == label).float()
            conf_loss = confidence_loss_fn(confidence.squeeze(), pred_correct)
            
            # Weighted total loss
            total_loss_batch = pred_loss + 0.1 * conf_loss
            weighted_loss = (total_loss_batch * weights).mean()
            
            optimizer.zero_grad()
            weighted_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            
            total_loss += weighted_loss.item()
            total_samples += len(context)
        
        avg_loss = total_loss / len(dataloader)
        scheduler.step(avg_loss)
        
        if avg_loss < best_loss:
            best_loss = avg_loss
            print(f"Epoch {epoch+1}/{epochs}: New best model found with loss: {avg_loss:.4f}. Saving checkpoint...")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, save_path)
        
        if epoch % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    return model

# ==== Enhanced Inference with Confidence ====
def predict_next_with_confidence(model, history, top_k=3):
    model.eval()
    with torch.no_grad():
        context = [tool_to_id[t] for t in history][-context_len:]
        context = [pad_token_id] * (context_len - len(context)) + context
        context = torch.tensor([context], dtype=torch.long).to(device)
        
        logits, confidence = model(context)
        probabilities = F.softmax(logits, dim=-1)
        
        # Get top-k predictions
        top_probs, top_indices = torch.topk(probabilities, top_k, dim=-1)
        
        predictions = []
        for i in range(top_k):
            tool_id = top_indices[0, i].item()
            prob = top_probs[0, i].item()
            conf = confidence[0, 0].item()
            tool_name = id_to_tool.get(tool_id, "<UNK>")
            predictions.append((tool_name, prob, conf))
        
        return predictions

# ==== Enhanced Data Generation ====
def generate_realistic_sequences(tool_vocab, tool_patterns, num_sequences=200, min_len=3, max_len=8):
    """Generate more realistic tool usage sequences"""
    sequences = []
    
    # Pattern-based generation
    for _ in range(num_sequences // 2):
        seq = []
        current_tool = random.choice(tool_vocab)
        seq.append(current_tool)
        
        length = random.randint(min_len, max_len)
        for _ in range(length - 1):
            if current_tool in tool_patterns:
                # 70% chance to follow pattern, 30% random
                if random.random() < 0.7:
                    next_tool = random.choice(tool_patterns[current_tool])
                else:
                    next_tool = random.choice(tool_vocab)
            else:
                next_tool = random.choice(tool_vocab)
            seq.append(next_tool)
            current_tool = next_tool
        
        sequences.append(seq)
    
    # Pure random generation
    for _ in range(num_sequences // 2):
        seq = random.choices(tool_vocab, k=random.randint(min_len, max_len))
        sequences.append(seq)
    
    return sequences

# ==== Common workflow patterns ====
common_workflows = [
    # 1. Drafting a New Document with Initial Formatting
    # A user starts a new document, titles it, saves it, and adds initial text formatting.
    ["new_document", "save_as", "select", "bold", "font_size", "align_center", "insert_header", "save"],

    # 2. Editing and Reorganizing a Report Section
    # The user opens a file, cuts a section, pastes it elsewhere, and then applies formatting.
    ["open", "select", "cut", "paste", "select", "align_justify", "line_spacing", "save"],

    # 3. Creating a List and Adjusting its Structure
    # A user creates a bulleted list, adds a sub-list, and then changes it to a numbered list.
    ["select", "bullet_list", "indent", "numbered_list", "outdent", "save"],

    # 4. Finalizing a Document After Find & Replace
    # The user performs a mass text replacement, then checks for errors before saving and exporting.
    ["find_and_replace_all", "undo", "spell_check", "grammar_check", "word_count", "save", "export_pdf"],
    
    # 5. Reviewing a Collaborative Document
    # A user opens a shared document, reviews changes, and adds a comment before sharing it again.
    ["share_document", "track_changes", "accept_change", "reject_change", "add_comment", "save"],

    # 6. Inserting and Formatting a Visual Element
    # The user inserts an image, adjusts its size, and then adds a formatted caption.
    ["insert_image", "select", "zoom_out", "insert_link", "align_center", "save"],
    
    # 7. Copying and Pasting Content with Formatting
    # The user copies a formatted section, pastes it, and then uses the format painter to apply the style to a new section.
    ["select", "copy", "paste", "select", "format_painter", "paste", "undo", "save"],
    
    # 8. Handling a File for Printing
    # A user saves a document, adjusts the page layout, zooms in to review, and then sends it to the printer.
    ["save", "page_layout", "zoom_in", "zoom_out", "print"],

    ["insert_table", "select", "paste", "bold", "align_center", "insert_link", "save"],

    # 10. Document Setup with Headers and Footers
    # The user focuses on document structure, adding headers and footers with a page layout change.
    ["new_document", "page_layout", "insert_header", "insert_footer", "save_as", "print"],
    
    # 11. Finalizing a Document with Error Checks
    # A user performs a full review, correcting errors and checking word count before saving.
    ["open", "find", "spell_check", "grammar_check", "word_count", "save", "export_pdf"],
    
    # 12. Working with Visuals and Shapes
    # The user inserts a chart, then adds an explanatory shape and text to highlight a point.
    ["insert_chart", "insert_shape", "select", "add_comment", "save"],
    
    # 13. Complex Text Formatting and Indentation
    # The user formats a paragraph with a mix of tools before adjusting its indentation.
    ["select", "text_color", "highlight", "underline", "align_justify", "indent", "outdent", "save"],
    
    # 14. Preparing a Document for Different View Modes
    # The user adjusts the zoom and switches between view modes to review the document for different purposes.
    ["zoom_in", "zoom_out", "read_mode", "page_layout", "save"],
    
    # 15. A User Reverting Multiple Actions
    # The user makes a series of changes, then decides to revert them using multiple undo actions.
    ["cut", "paste", "bold", "delete", "undo", "undo", "undo", "save"],

    # 16. Applying a Consistent Format Across a Document
    # A user selects text, formats it, and then repeatedly uses the format painter.
    ["select", "bold", "italic", "format_painter", "select", "format_painter", "select", "save"],
    
    # 17. Creating Different Versions of a Document
    # A user saves a document, then uses 'save as' to create a new version for a different purpose.
    ["open", "track_changes", "save", "save_as", "export_pdf"],
    
    # 18. Inserting and Editing a Hyperlink
    # The user selects text, adds a link, then revisits the link or its surrounding text.
    ["select", "insert_link", "select", "underline", "save"],
    
    # 19. Formatting for Readability
    # The user adjusts line spacing and text alignment for a paragraph.
    ["select", "line_spacing", "align_justify", "indent", "save"],
    
    # 20. Starting a New, Complex Project
    # A user begins a new project by inserting a header and a table, then saves the empty structure.
    ["new_document", "insert_header", "insert_table", "save_as"],
    
    # 21. Fine-tuning Collaborative Changes
    # The user is in review mode and selectively accepts and rejects changes.
    ["track_changes", "accept_change", "accept_change", "reject_change", "accept_change", "save"],
    
    # 22. Navigating and Reviewing the Document View
    # The user checks the document at different zoom levels and in different layouts.
    ["zoom_in", "page_layout", "zoom_out", "read_mode", "save"],
    
    # 23. A User Changing Their Mind and Correcting
    # A user formats a section, then deletes it and uses undo to retrieve it.
    ["select", "bold", "italic", "delete", "undo", "redo", "save"],

    # 24. A User Finding and Reusing Content
    # The user finds a specific word, copies it, and pastes it elsewhere.
    ["find", "copy", "paste", "save"],

    # 25. A Quick File Conversion and Sharing Task
    # The user opens a document and immediately exports it for sharing.
    ["open", "export_pdf", "share_document"],

    # 26. Building a Visual-Heavy Document
    # The user inserts a sequence of different visual elements and then saves.
    ["new_document", "insert_image", "insert_shape", "insert_chart", "save"],
    
    # 27. Correcting a List Hierarchy
    # The user converts a bulleted list to a numbered list and adjusts indentation, then reverts part of it.
    ["select", "bullet_list", "indent", "numbered_list", "undo", "save"],
    
    # 28. Simple Text Correction and Reformatting
    # The user corrects a typo with delete/undo and then applies basic formatting.
    ["delete", "undo", "select", "bold", "save"],

    # 29. Fine-tuning Document Layout
    # The user adjusts line spacing and switches between different layouts to see the impact.
    ["select", "line_spacing", "web_layout", "page_layout", "save"],
    
    # 30. A User Performing a Quick Formatting Change
    # A simple but common sequence where a user makes a quick, isolated change.
    ["select", "text_color", "save"],
]

# ==== Enhanced Sample Data ====
user_sequences = common_workflows * 3  # Repeat common patterns
user_sequences += generate_realistic_sequences(tool_vocab, tool_patterns, 300)

print(f"Generated {len(user_sequences)} training sequences")

# ==== Run Enhanced Training ====
dataset = EnhancedToolDataset(user_sequences, context_len, augment_data=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

print(f"Training samples: {len(dataset)}")

model = EnhancedToolPredictor(
    vocab_size=vocab_size,
    embed_dim=128,  # Increased embedding dimension
    desc_dim=desc_dim,
    n_heads=8,  # More attention heads
    num_layers=4,  # More layers
    context_len=context_len,
    desc_emb_table=desc_id_to_embedding,
    dropout=0.1
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

# Train the model
model = train_model(model, dataloader, epochs=100, lr=2e-4)

# ==== Enhanced Evaluation ====
def evaluate_model(model, test_sequences):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for seq in test_sequences:
            if len(seq) < 2:
                continue
            
            for i in range(1, len(seq)):
                history = seq[:i]
                true_next = seq[i]
                
                predictions = predict_next_with_confidence(model, history, top_k=1)
                predicted_tool = predictions[0][0]
                
                if predicted_tool == true_next:
                    correct += 1
                total += 1
    
    accuracy = correct / total if total > 0 else 0
    return accuracy

# Generate test sequences
test_sequences = generate_realistic_sequences(tool_vocab, tool_patterns, 50, min_len=3, max_len=6)
test_accuracy = evaluate_model(model, test_sequences)
print(f"Test Accuracy: {test_accuracy:.3f}")

# ==== Enhanced Prediction Examples ====
test_histories = [
    # Basic editing patterns
    ["select", "copy"],
    ["cut", "paste"],
    ["select", "delete"],
    ["undo", "undo"],
    
    # Text formatting workflows
    ["select", "bold"],
    ["select", "italic", "underline"],
    ["align_center"],
    ["font_size", "font_family"],
    
    # Document management
    ["new_document", "save_as"],
    ["open"],
    ["save", "print"],
    
    # Search and replace
    ["find", "replace"],
    
    # Insertions
    ["insert_image", "select"],
    ["insert_table", "insert_link"],
    
    # Lists and formatting
    ["bullet_list", "indent"],
    ["numbered_list", "outdent"],
    
    # Collaboration and review
    ["share_document", "add_comment"],
    ["track_changes", "accept_change"],
    
    # Complex, multi-step sequences
    ["select", "copy", "paste", "undo"],
    ["open", "find", "replace", "save"],
    ["new_document", "insert_header", "insert_footer"],
    ["save_as", "export_pdf", "share_document"],
]

print("\n=== Prediction Examples ===")
for history in test_histories:
    predictions = predict_next_with_confidence(model, history, top_k=3)
    print(f"History: {history}")
    for i, (tool, prob, conf) in enumerate(predictions):
        print(f"  {i+1}. {tool} (prob: {prob:.3f}, conf: {conf:.3f})")
    print()


  from tqdm.autonotebook import tqdm, trange
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



Using device: cuda


  attn_output = torch.nn.functional.scaled_dot_product_attention(


Generated 390 training sequences
Training samples: 1947
Model parameters: 2,768,185
Epoch 1/100: New best model found with loss: 3.5436. Saving checkpoint...
Epoch 1/100, Loss: 3.5436, LR: 0.000200
Epoch 2/100: New best model found with loss: 3.2496. Saving checkpoint...
Epoch 3/100: New best model found with loss: 2.9594. Saving checkpoint...
Epoch 4/100: New best model found with loss: 2.6776. Saving checkpoint...
Epoch 5/100: New best model found with loss: 2.5002. Saving checkpoint...
Epoch 6/100: New best model found with loss: 2.3041. Saving checkpoint...
Epoch 6/100, Loss: 2.3041, LR: 0.000200
Epoch 7/100: New best model found with loss: 2.1845. Saving checkpoint...
Epoch 8/100: New best model found with loss: 2.0165. Saving checkpoint...
Epoch 9/100: New best model found with loss: 1.9151. Saving checkpoint...
Epoch 10/100: New best model found with loss: 1.7832. Saving checkpoint...
Epoch 11/100: New best model found with loss: 1.6676. Saving checkpoint...
Epoch 11/100, Loss: 

In [1]:
# Library
import random 
import numpy
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset, DataLoader

  from tqdm.autonotebook import tqdm, trange
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



In [2]:
# Configure device to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Prepare dataset
tool_vocab = [
    "cut", "copy", "paste", "select", "delete", "undo", "redo",
    "bold", "italic", "underline", "align_left", "align_center", "align_right",
    "indent", "outdent", "bullet_list", "numbered_list", "find", "replace",
    "save", "open", "new_document", "print", "zoom_in", "zoom_out",
    "insert_image", "insert_table", "format_painter"
]
tool_descriptions = {
    "cut": "removes the selected content and places it in the clipboard, typically followed by paste",
    "copy": "duplicates the selected content to the clipboard without removing it, often followed by paste",
    "paste": "inserts content from the clipboard at the current cursor position or replaces selected content",
    "select": "highlights content (text, image, etc.) for further operation like cut, copy, delete, or formatting",
    "delete": "removes content permanently without placing it in the clipboard; can often be undone",
    "undo": "reverses the last action performed, crucial for correcting mistakes",
    "redo": "re-applies a previously undone action, moving forward in the action history",
    "bold": "applies bold formatting to selected text, making it stand out",
    "italic": "applies italic formatting to selected text, often for emphasis or titles",
    "underline": "applies an underline to selected text, commonly used for links or emphasis",
    "align_left": "aligns selected text or objects to the left margin",
    "align_center": "centers selected text or objects horizontally on the page",
    "align_right": "aligns selected text or objects to the right margin",
    "indent": "increases the indentation of selected paragraphs or list items, moving them further from the margin",
    "outdent": "decreases the indentation of selected paragraphs or list items, moving them closer to the margin",
    "bullet_list": "converts selected text into an unordered list with bullet points",
    "numbered_list": "converts selected text into an ordered list with numbers",
    "find": "opens a dialog to search for specific text within the document",
    "replace": "opens a dialog to find text and replace it with new text, often used after 'find'",
    "save": "stores the current state of the document to a file on disk",
    "open": "opens an existing document from a file",
    "new_document": "creates a blank new document, usually starting a fresh project",
    "print": "sends the current document to a printer for a hard copy",
    "zoom_in": "magnifies the view of the document, making content appear larger",
    "zoom_out": "reduces the magnification of the document, showing more content at once",
    "insert_image": "adds an image from a file into the document",
    "insert_table": "adds a structured grid of rows and columns to the document",
    "format_painter": "copies formatting from one piece of content and applies it to another"
}

tool_patterns = {
    "cut": ["paste", "undo", "copy", "select"],
    "copy": ["paste", "undo", "select"],
    "paste": ["undo", "select", "cut", "format_painter"],
    "select": ["cut", "copy", "delete", "bold", "italic", "underline", "align_left", "format_painter"],
    "delete": ["undo", "select", "cut"],
    "undo": ["redo", "cut", "copy", "paste", "delete", "save"],
    "redo": ["undo", "paste", "select", "bold", "italic"],
    "bold": ["italic", "underline", "select", "align_left"],
    "italic": ["bold", "underline", "select", "align_center"],
    "underline": ["bold", "italic", "select", "align_right"],
    "align_left": ["align_center", "align_right", "select"],
    "align_center": ["align_left", "align_right", "select"],
    "align_right": ["align_left", "align_center", "select"],
    "indent": ["outdent", "bullet_list", "numbered_list"],
    "outdent": ["indent", "bullet_list", "numbered_list"],
    "bullet_list": ["numbered_list", "indent", "outdent", "select"],
    "numbered_list": ["bullet_list", "indent", "outdent", "select"],
    "find": ["replace", "copy", "delete"],
    "replace": ["find", "undo", "save"],
    "save": ["new_document", "open", "print"],
    "open": ["save", "new_document", "print"],
    "new_document": ["save"],
    "print": ["save", "new_document", "open"],
    "zoom_in": ["zoom_out"],
    "zoom_out": ["zoom_in"],
    "insert_image": ["select", "delete"],
    "insert_table": ["select"],
    "format_painter": ["select", "paste"]
}

Using device: cuda


In [3]:
# Padding token id
pad_token_id = 0

# Configation for data 
tool_to_id = {tool : idx+1 for idx, tool in enumerate(tool_vocab)}
id_to_tool = {idx+1 : tool for idx, tool in enumerate(tool_vocab)}
vocab_size = len(tool_vocab) + 1
context_len = 6

# Description Embeddings 
desc_model = SentenceTransformer('all-MiniLM-L6-v2')
desc_texts = [tool_descriptions[tool] for tool in tool_vocab]
desc_embeddings = desc_model.encode(desc_texts, normalize_embeddings= True)
desc_id_to_embedding = {
    tool_to_id[tool] : torch.tensor(desc_embeddings[i], dtype= torch.float32).to(device) for i, tool in enumerate(tool_vocab)
    }
desc_emb = desc_embeddings.shape[1]

# Add padding token embedding 
desc_id_to_embedding[pad_token_id] = torch.zeros(desc_emb, dtype=torch.float32).to(device)


  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [36]:
class ToolDataset(Dataset):
    def __init__(self, sequences, context_len, augment_data=True):
        super().__init__()
        self.context_len = context_len
        self.samples = []
        self.pattern_freq = defaultdict(int)

        # Collect pattern frequencies 
        for seq in sequences:
            token_ids= [tool_to_id[tool] for tool in seq]
            for i in range(1, len(token_ids)):
                context = token_ids[max(i - context_len, 0):i]
                label = token_ids[i]
                pattern = tuple(context[-min(3, len(context)):])
                self.pattern_freq[pattern] += 1

        # Create samples with weights 
        for seq in sequences:
            token_ids = [tool_to_id[tool] for tool in seq]
            for i in range(1, len(token_ids)):
                context = token_ids[max(0, i - context_len):i]
                label = token_ids[i]
                context = [pad_token_id] * (context_len - len(context)) + context

                # Calculate weights based on sample frequency 
                pattern = tuple(context[-min(3, len(context)):])
                weight = 1 / (1 + self.pattern_freq[pattern] * 0.1)

                self.samples.append((context, label, weight))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        context, label, weight = self.samples[idx]
        return (torch.tensor(context, dtype=torch.long),
                torch.tensor(label, dtype=torch.long),
                torch.tensor(weight, dtype=torch.float32))
    
class ToolPredictor(nn.Module):
    def __init__(self, vocab_size, embed_dim, context_len, desc_dim, desc_emb_table, n_heads, num_layers, dropout=0.1):
        self.embed_dim = embed_dim
        self.desc_dim = desc_dim
        self.context_len = context_len

        # Embeddings 
        self.token_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx= pad_token_id)
        self.pos_embedding = nn.Embedding(context_len, embed_dim)
        self.desc_proj = nn.Linear(desc_dim, embed_dim)
        self.desc_emb_table = desc_emb_table

        # Layer normalization 
        self.embed_norm = nn.LayerNorm(embed_dim)

        # Transformer 
        self.transformer_layer = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model= embed_dim,
                nhead= n_heads,
                dropout= dropout,
                batch_first= True
            )
            for _ in range(num_layers)
        ])

        # Context Attention 
        self.context_attention = nn.MultiheadAttention(embed_dim= embed_dim, num_heads=n_heads, dropout= dropout, batch_first= True)

        # Multiple prediction heads 
        self.output = nn.Linear(embed_dim, vocab_size)
        self.confidence_head = nn.Linear(embed_dim, 1)

        # Dropout regularization 
        self.dropout = nn.Dropout(dropout)

        # Initialize weights 
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=0.02)

    def forward(self, x):
        batch_size, seq_len = x.size()

        # Token embeddings 
        tok_emb = self.token_embedding(x)

        # Positional Embeddings 
        pos_ids = torch.arange(seq_len, device= x.device).unsqueeze(0).expand(batch_size,-1)
        pos_emb = self.pos_embedding(pos_ids)

        # Description Embeddings 
        desc_emb = torch.stack([
            torch.stack([
                self.desc_proj(self.desc_emb_table.get(tok.item(), torch.zeros(self.desc_dim))) 
                for tok in row
                ]) 
                for row in x
            ])
        
        # Combine embeddings
        x_emb = tok_emb + pos_emb + desc_emb

        # Norm nad dropout
        x_emb = self.embed_norm(x_emb)
        self.dropout(x_emb)

        # Causal Mask
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)

        # Pass through transformer layer 
        hidden = x_emb
        for layer in self.transformer_layer:
            hidden= layer(hidden, hidden, tgt_mask= tgt_mask)

        # Context attention for final representation 
        attn_out = self.context_attention(hidden, hidden, hidden)

        # Final hidden layer
        final_hidden = attn_out + hidden
        
        # Use last tokens for prediction
        last_hidden = final_hidden[:,-1,:]

        # Output
        logits = self.output(last_hidden)
        confidence= torch.sigmoid(self.confidence_head(last_hidden))

        return logits, confidence
         

In [37]:
# Using focal loss as Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [None]:
# Training loop
def train_model(model, dataloader, epochs=50, lr=1e-3):
    optimizer= torch.optim.adamw(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    # Use focal loss for better handling of class imbalance 
    loss_fn = FocalLoss(alpha=1, gamma=2)
    confidence_loss_fn = nn.BCELoss()

    model.train()
    best_loss = float('inf')

    for epoch in range(epochs):
        train_loss = 0
        total_samples = 0

        for batch_data in dataloader:
            context, label, weights = batch_data
            context, label, weights = context.to(device), label.to(device), weights.to(device)

            logits, confidence = model(context)

            # Prediction Loss 
            pred_loss = loss_fn(logits, label)

            # Confidence Loss (high confidence for correct prediction)
            pred_correct = 

            
