In [1]:
import torch
from transformers import DistilBertModel, DistilBertTokenizer

class AdaptiveTokenPruner:
    """Build a boolean keep-mask by thresholding first-layer attention.

    A DistilBERT model is run with ``output_attentions=True``; the first
    entry of the returned attentions is averaged over dimension 1 and
    compared against ``threshold``.
    """

    def __init__(self, threshold: float = 0.1) -> None:
        """Load the pretrained model/tokenizer and store the cut-off.

        Args:
            threshold: Scores strictly greater than this value are kept.
        """
        self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.threshold = threshold

    def compute_token_importance(self, attention_scores: torch.Tensor) -> torch.Tensor:
        """Return ``attention_scores`` averaged over dimension 1.

        Args:
            attention_scores: Attention tensor with at least two dimensions.
                Presumably (batch, heads, query, key) as produced by the
                model - TODO confirm against the caller.

        Returns:
            The input with dimension 1 reduced away by a mean.
        """
        # Aggregate attention scores across heads (assumed to live on dim 1).
        return torch.mean(attention_scores, dim=1)

    def prune_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Compute a boolean mask of scores above ``self.threshold``.

        Args:
            input_ids: Token ids forwarded to the model unchanged.
            attention_mask: Attention mask forwarded to the model unchanged.

        Returns:
            Boolean tensor, ``True`` where the averaged first-layer attention
            is strictly greater than ``self.threshold``.

        NOTE(review): the mask has the shape of the averaged attention map
        (dimension 1 removed), which may be one entry per query/key pair
        rather than one per token - confirm what callers expect.
        NOTE(review): ``attention_mask`` is only passed to the model; it is
        not applied to the returned mask, so padded positions are not
        filtered out here.
        """
        # Inference only: the result is a boolean mask that cannot carry
        # gradients, so skip building an autograd graph for the forward pass.
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=True,
            )

            # Only the first layer's attention is used for scoring.
            importance = self.compute_token_importance(outputs.attentions[0])

            # Strict comparison: scores equal to the threshold are dropped.
            keep_mask = importance > self.threshold

        return keep_mask

  torch.utils._pytree._register_pytree_node(
