# Perturbation-Based Feature Attribution

1. The Basic Idea

    "If I remove/change this token, how much does the prediction change?"

2. Step-by-Step Process

    a. Get Baseline: Run the model on the original sequence → get prediction + confidence 

    b. For each token position:
      
        -- Create a modified version (mask the token, or replace it with a random token)
        -- Run the model on this modified sequence
        -- Measure how much the confidence for the original prediction dropped
        -- Repeat with fresh random replacements for statistical stability (masking is deterministic, so a single run suffices)


3. Calculate Importance: Average the confidence drops across multiple perturbations



## The Two Perturbation Strategies

1. Masking: Removes information (what happens without this token?)
2. Random: Adds noise (what happens with meaningless information here?)



## Perturbation Method Mathematics
Question: "What happens to the prediction if I remove this specific token?"

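Math: φᵢ = f(x) − E[f(x⁽ⁱ⁾)]
 
Where:

    f(x) = baseline prediction confidence (probability of the originally predicted next token)
    x⁽ⁱ⁾ = the input x with token i masked or replaced by a random token
    E[f(x⁽ⁱ⁾)] = expected prediction confidence when token i is perturbed

The expectation is taken over multiple random perturbations (with masking there is only one possible perturbation, so no averaging is needed).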



In [None]:
# Simple explanation method that doesn't require SHAP
# This is much more memory-efficient and easier to use
# Enhanced version with token decoding

import torch
import numpy as np
import matplotlib.pyplot as plt
import gc

def explain_transformer_simple(model, input_sequence, tokenizer, n_perturbations=50, method='mask',
                               mask_token_id=0, vocab_size=None):
    """
    Simple explanation method using perturbation (occlusion) analysis.

    For every position i the token is replaced, and the drop in the model's
    confidence for its original next-token prediction is recorded:
    ``importance[i] = p(pred | x) - mean(p(pred | x with token i perturbed))``.

    Args:
        model: Trained transformer. It is called as ``model(ids)`` with a
            ``(1, seq_len)`` LongTensor and must return a 3-tuple. The first
            element is read as ``logits[0, -1, :]``; the third is returned
            unchanged as ``attention_weights``.
        input_sequence: Token ids (list, 1-D numpy array or 1-D tensor).
        tokenizer: Object with a ``decode(list_of_ids)`` method (e.g. SentencePiece).
        n_perturbations: Number of random replacements averaged per position.
            Only used for ``method='random'``; masking is deterministic, so it is
            evaluated once per position.
        method: 'mask' (replace with ``mask_token_id``) or 'random' (replace with
            a uniformly random id in ``[1, vocab_size)``).
        mask_token_id: Id used when masking (default 0).
        vocab_size: Exclusive upper bound for random ids. Defaults to the size of
            the model's output logits, so sampled ids are always valid.

    Returns:
        Tuple ``(importance_scores, baseline_prediction, baseline_confidence,
        attention_weights, predicted_token_text)``. ``importance_scores`` is a
        numpy array with one score per position; positive means that
        perturbing the token lowered the confidence.
    """
    device = next(model.parameters()).device
    model.eval()

    # Normalise to a plain list of Python ints (handles lists, numpy arrays and tensors).
    input_sequence = [int(tok) for tok in input_sequence]
    seq_len = len(input_sequence)

    def _next_token_probs(ids):
        """Return (softmax over the last position's logits, attention output) for one sequence."""
        tensor = torch.tensor([ids], dtype=torch.long, device=device)
        with torch.no_grad():
            logits, _, attn = model(tensor)
        return torch.softmax(logits[0, -1, :], dim=-1), attn

    # Baseline prediction on the unmodified input
    baseline_probs, attention_weights = _next_token_probs(input_sequence)
    baseline_prediction = torch.argmax(baseline_probs).item()
    baseline_confidence = baseline_probs[baseline_prediction].item()
    if vocab_size is None:
        # Sample random replacements from the model's real vocabulary instead of a fixed range.
        vocab_size = baseline_probs.shape[-1]

    # Decode the predicted token
    predicted_token_text = tokenizer.decode([baseline_prediction])
    print(f"Baseline prediction: '{predicted_token_text}' (token {baseline_prediction}, confidence: {baseline_confidence:.4f})")

    importance_scores = np.zeros(seq_len)

    # Masking always produces the same sequence, so repeated runs would only redo identical work.
    n_trials = 1 if method == 'mask' else n_perturbations

    # Test importance of each position
    for pos in range(seq_len):
        impacts = []
        for _ in range(n_trials):
            perturbed_seq = list(input_sequence)
            if method == 'mask':
                perturbed_seq[pos] = mask_token_id
            else:  # random
                perturbed_seq[pos] = int(np.random.randint(1, vocab_size))

            perturbed_probs, _ = _next_token_probs(perturbed_seq)
            # Impact = drop in confidence for the original prediction
            impacts.append(baseline_confidence - perturbed_probs[baseline_prediction].item())

        importance_scores[pos] = np.mean(impacts)

        # Progress indicator
        if (pos + 1) % 10 == 0 or pos == seq_len - 1:
            print(f"Progress: {pos + 1}/{seq_len} positions analyzed")

    return importance_scores, baseline_prediction, baseline_confidence, attention_weights, predicted_token_text

def plot_explanation_results(input_sequence, importance_scores, attention_weights,
                             predicted_token, predicted_token_text, confidence, tokenizer, top_k=15):
    """
    Plot explanation results as four panels with decoded token labels.

    Panels:
        1. Importance score per position (blue = positive, red = negative).
        2. The ``top_k`` positions with the largest absolute importance,
           labelled with position, decoded text and token id.
        3. Last-layer attention from the final position (batch 0, head 0).
        4. Scatter of attention vs. importance with their Pearson correlation.

    Args:
        input_sequence: Token ids that were explained.
        importance_scores: One score per position (list or 1-D array).
        attention_weights: Attention output of the model, indexed here as
            ``attention_weights[-1][0, 0, -1, :]``. This assumes a sequence of
            per-layer tensors shaped (batch, heads, query, key); TODO confirm
            against the model. If None or empty, panels 3-4 show a placeholder.
        predicted_token: Predicted token id (not used by the plot; kept for
            interface compatibility).
        predicted_token_text: Decoded predicted token, shown in the title.
        confidence: Baseline confidence, shown in the title.
        tokenizer: Object with a ``decode(list_of_ids)`` method.
        top_k: Number of tokens shown in the ranking panel.
    """
    importance_scores = np.asarray(importance_scores, dtype=float)
    positions = np.arange(len(importance_scores))

    # Create figure with subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(18, 12))

    # 1. Token importance scores
    colors = ['red' if score < 0 else 'blue' for score in importance_scores]
    ax1.bar(positions, importance_scores, color=colors, alpha=0.7)
    ax1.set_title(f'Token Importance Scores\nPredicted: "{predicted_token_text}" (conf: {confidence:.3f})')
    ax1.set_xlabel('Token Position')
    ax1.set_ylabel('Importance Score')
    ax1.grid(True, alpha=0.3)

    # 2. Top influential tokens with decoded text.
    # Reverse first, then slice, so that top_k=0 selects nothing ([-0:] would select everything).
    top_indices = np.argsort(np.abs(importance_scores))[::-1][:top_k]
    top_scores = importance_scores[top_indices]
    top_colors = ['red' if score < 0 else 'blue' for score in top_scores]

    token_labels = []
    for idx in top_indices:
        token_id = int(input_sequence[idx])
        try:
            decoded_token = tokenizer.decode([token_id])
        except Exception:
            # Fallback if decoding fails
            token_labels.append(f"pos_{idx}\n(id_{token_id})")
            continue
        # Escape before truncating so the limit applies to the displayed text.
        decoded_token = decoded_token.replace('\n', '\\n').replace('\t', '\\t')
        if len(decoded_token) > 10:
            decoded_token = decoded_token[:7] + "..."
        token_labels.append(f"pos_{idx}\n'{decoded_token}'\n(id_{token_id})")

    ax2.barh(range(len(top_indices)), top_scores, color=top_colors, alpha=0.7)
    ax2.set_yticks(range(len(top_indices)))
    ax2.set_yticklabels(token_labels, fontsize=8)
    ax2.set_title(f'Top {top_k} Most Influential Tokens')
    ax2.set_xlabel('Importance Score')
    ax2.grid(True, alpha=0.3)

    # 3. Attention weights (last layer, first head)
    if attention_weights is not None and len(attention_weights) > 0:
        # Attention *from* the final position (the one that predicts the next token).
        # detach/float so grad-tracking or half/bfloat16 tensors convert to numpy.
        last_layer_attention = attention_weights[-1][0, 0, -1, :].detach().float().cpu().numpy()

        ax3.bar(np.arange(len(last_layer_attention)), last_layer_attention, alpha=0.7, color='green')
        ax3.set_title('Attention Weights (Last Layer, Head 0)')
        ax3.set_xlabel('Token Position')
        ax3.set_ylabel('Attention Weight')
        ax3.grid(True, alpha=0.3)

        # 4. Importance vs Attention comparison
        ax4.set_xlabel('Attention Weight')
        ax4.set_ylabel('Importance Score')
        ax4.set_title('Importance vs Attention Correlation')
        ax4.grid(True, alpha=0.3)

        if len(last_layer_attention) == len(importance_scores):
            ax4.scatter(last_layer_attention, importance_scores, alpha=0.7)
            # Pearson correlation is undefined (NaN + RuntimeWarning) for constant inputs.
            if np.std(last_layer_attention) > 0 and np.std(importance_scores) > 0:
                correlation = np.corrcoef(last_layer_attention, importance_scores)[0, 1]
                correlation_text = f'Correlation: {correlation:.3f}'
            else:
                correlation_text = 'Correlation: undefined (constant input)'
        else:
            correlation_text = 'Length mismatch: attention vs. importance'

        ax4.text(0.05, 0.95, correlation_text,
                 transform=ax4.transAxes, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    else:
        ax3.text(0.5, 0.5, 'Attention weights not available',
                 ha='center', va='center', transform=ax3.transAxes)
        ax4.text(0.5, 0.5, 'Attention weights not available',
                 ha='center', va='center', transform=ax4.transAxes)

    plt.tight_layout()
    plt.show()

def analyze_explanation(input_sequence, importance_scores, predicted_token, predicted_token_text, tokenizer, top_k=10):
    """
    Print a textual analysis of an explanation, with decoded tokens.

    Prints the sequence length, the prediction, a decoded preview of the first
    50 tokens, the summed positive and negative scores, and a table of the
    ``top_k`` positions ranked by absolute score.

    Args:
        input_sequence: Token ids that were explained (list, array or tensor).
        importance_scores: One score per position (list or 1-D array).
        predicted_token: Predicted token id (printed).
        predicted_token_text: Decoded predicted token (printed).
        tokenizer: Object with a ``decode(list_of_ids)`` method.
        top_k: Number of rows in the ranking table.

    Returns:
        dict with keys ``top_indices`` (positions sorted by descending |score|,
        truncated to ``top_k``), ``positive_influence`` and
        ``negative_influence`` (sums of positive / negative scores), and
        ``total_influence`` (sum of all scores).
    """
    importance_scores = np.asarray(importance_scores, dtype=float)
    # Plain Python ints: some tokenizers reject numpy integer types.
    token_ids = [int(tok) for tok in input_sequence]
    total_influence = np.sum(importance_scores)

    print(f"\n{'='*60}")
    print("EXPLANATION ANALYSIS")
    print(f"{'='*60}")

    print(f"Input sequence length: {len(token_ids)}")
    print(f"Predicted next token: '{predicted_token_text}' (ID: {predicted_token})")
    print(f"Total importance score: {total_influence:.4f}")

    # Show a portion of the decoded input sequence
    try:
        decoded_input = tokenizer.decode(token_ids[:50])  # First 50 tokens
    except Exception:
        print("Could not decode input sequence")
    else:
        print(f"Input text (first ~50 tokens): '{decoded_input[:200]}{'...' if len(decoded_input) > 200 else ''}'")

    # Positive and negative influences
    positive_influence = np.sum(importance_scores[importance_scores > 0])
    negative_influence = np.sum(importance_scores[importance_scores < 0])

    print(f"Positive influence: {positive_influence:.4f}")
    print(f"Negative influence: {negative_influence:.4f}")

    # Top influential tokens
    sorted_indices = np.argsort(np.abs(importance_scores))[::-1]

    print(f"\nTop {top_k} most influential tokens:")
    print(f"{'Rank':<4} {'Position':<8} {'Token Text':<15} {'Token ID':<8} {'Score':<10} {'Type'}")
    print("-" * 70)

    for rank, idx in enumerate(sorted_indices[:top_k], start=1):
        token_id = token_ids[idx]
        try:
            token_text = tokenizer.decode([token_id])
        except Exception:
            token_text = "<?>"
        else:
            # Escape before truncating so the displayed text never exceeds 12 chars;
            # otherwise the dot-leader width below could go negative and raise ValueError.
            token_text = token_text.replace('\n', '\\n').replace('\t', '\\t')
            if len(token_text) > 12:
                token_text = token_text[:9] + "..."

        score = importance_scores[idx]
        if score > 0:
            influence_type = "Helpful"
        elif score < 0:
            influence_type = "Harmful"
        else:
            influence_type = "Neutral"

        leader = '.' * max(0, 15 - len(token_text) - 2)
        print(f"{rank:<4} {idx:<8} '{token_text}'{leader} {token_id:<8} {score:<10.4f} {influence_type}")

    return {
        'top_indices': sorted_indices[:top_k],
        'positive_influence': positive_influence,
        'negative_influence': negative_influence,
        'total_influence': total_influence
    }

# Driver for the perturbation explanation. It relies on notebook globals from
# earlier cells: `model`, `encoded` (token ids with a .tolist() method) and
# `sp` (SentencePiece processor).

# Clear GPU memory before the many forward passes below
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

print("Starting simple transformer explanation with token decoding...")

# Alternative input: explain a window of the training data instead of `encoded`
# (needs `train_token_ids` and `block_size` from earlier cells).
# sample_idx = 100
# sample_sequence = train_token_ids[sample_idx:sample_idx+block_size].tolist()

sample_sequence = encoded.tolist()
print(f"Sample sequence length: {len(sample_sequence)}")
print(f"First 10 tokens: {sample_sequence[:10]}")

# Show decoded version of first few tokens
try:
    decoded_sample = sp.decode(sample_sequence[:20])
except Exception:
    print("Could not decode sample tokens")
else:
    print(f"First ~20 tokens decoded: '{decoded_sample}'")

# Run explanation (now requires tokenizer parameter)
print("\nComputing token importance scores...")
importance_scores, predicted_token, confidence, attention_weights, predicted_token_text = explain_transformer_simple(
    model,
    sample_sequence,
    sp,  # Pass your SentencePiece tokenizer
    n_perturbations=30,  # Averaged only for method='random'; adjust for patience/accuracy
    method='mask'  # or 'random'
)

print("Explanation complete!")

# Plot results with decoded tokens
plot_explanation_results(sample_sequence, importance_scores, attention_weights,
                         predicted_token, predicted_token_text, confidence, sp, top_k=15)

# Analyze results with decoded tokens
analysis = analyze_explanation(sample_sequence, importance_scores, predicted_token,
                               predicted_token_text, sp, top_k=10)

# Clean up memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

print("\nAnalysis complete!")
print("This method shows which tokens in your input sequence most influence the model's next token prediction.")
print("Positive scores = helpful for the prediction, Negative scores = harmful for the prediction")
print("Now with human-readable token text for better interpretation!")

# SHAP (SHapley Additive exPlanations)

SHAP is a framework for computing (in practice, approximating) Shapley values. It explains the prediction for an instance x by assigning each feature its contribution (its Shapley value) to that prediction. Shapley values come from cooperative game theory: the features are the "players", and the prediction is the "payout" that is shared fairly among them.


## Mathematical Difference from Perturbation Method

Question: "What is the fair contribution of this token across all possible contexts?"

Math: φᵢ = Σ [weighted marginal contributions across all possible token subsets]

<br>

**Perturbation (leave-one-out occlusion) corresponds to a single term of the SHAP sum. Here's why:**
**SHAP considers all possible combinations (coalitions) of the other tokens:**

    -- What if token i joins an empty sequence?
    -- What if token i joins just token j?
    -- What if token i joins tokens j and k?
    -- What if token i joins ALL other tokens?

**Perturbation only looks at the last case - token i joining all other tokens.**



- SHAP systematically tries *every* coalition, measuring how token *i* helps in all possible “teams” of other tokens.
- Perturbation methods only consider a **single coalition**:
  - the “team” containing **all other tokens** except *i*.

For example, consider the token “quick” in the phrase *"the quick brown fox"*:

- **SHAP** will test:
  - Coalition {} (empty): f({quick}) - f({})
  - Coalition {brown}: f({brown, quick}) - f({brown})
  - Coalition {the, fox}: f({the, quick, fox}) - f({the, fox})
  - Coalition {the, brown, fox}: f({the, brown, quick, fox}) - f({the, brown, fox})
  - … and every other subset that does not include feature i (i.e., quick)

- **Perturbation** only looks at the “full context”:
  - f({the, brown, fox, quick}) - f({the, brown, fox})

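Hence perturbation is equivalent to **one marginal contribution**, while SHAP aggregates **all** marginal contributions across all possible coalitions. Each coalition $S$ is weighted by $\frac{|S|!\,(n-|S|-1)!}{n!}$, where $n$ is the number of tokens. This gives a fairer and more theoretically grounded attribution; for example, the attributions sum exactly to $f(x) - f(\varnothing)$. The cost grows exponentially with $n$, which is why SHAP values are approximated in practice.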


**In summary:**
- **SHAP**: *“token’s contribution averaged (with Shapley weights) over every possible subset of teammates”*
- **Perturbation**: *“token’s contribution in just the full team”*

 

# SHAP Approximate

SHAP (SHapley Additive exPlanations) provides a fair way to measure how each token contributes to the model’s prediction, inspired by cooperative game theory.


## 1. Approximate Shapley Value Estimation

We sample random subsets of tokens (coalitions).

- Let **x** be the token sequence of length *n*  
- Let **z** be a binary mask:
  - `z_i = 1` → token present
  - `z_i = 0` → token perturbed

Each coalition defines a perturbed sequence:

$$
\tilde{x}^{(j)} = x \odot z^{(j)} + p \odot (1 - z^{(j)})
$$

where:
- *p* is the perturbation baseline (e.g. masked token)
- $\odot$ is elementwise product

 

## 2. Model Output Collection

For each coalition:

$$
\hat{y}^{(j)} = f(\tilde{x}^{(j)})
$$

Measure change in confidence:

$$
c^{(j)} = \hat{y}^{(j)} - \hat{y}^{\text{baseline}}
$$

where $\hat{y}^{\text{baseline}} = f(x)$ is the original confidence.

 

## 3. Linear Shapley Approximation

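Collect *m* such samples, and fit a least-squares linear regression with an intercept:

$$
c^{(j)} \approx \phi_0 + z^{(j)} \cdot \phi
$$

The intercept $\phi_0$ captures the change in confidence when every token is perturbed. Without it, the fit would wrongly force that change to zero. The remaining coefficients

$$
\phi \in \mathbb{R}^n
$$

form the approximate Shapley value vector. The regression has $n + 1$ unknowns, so it needs at least $m \ge n + 1$ samples. Kernel SHAP additionally weights each coalition by the Shapley kernel; this unweighted version is a simpler approximation.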



## 4. Token Perturbation

If a token is removed $(z_i=0)$, it is replaced with:
- token id 0 (mask)
- or a random token



## 5. Summary

    Sample random coalitions  
    Perturb tokens  
    Measure change in prediction  
    Fit regression to get approximate Shapley values


## 6. Final Takeaway

**Regression-based SHAP** (an unweighted simplification of Kernel SHAP) is a fast approximation:  
it fits a linear model over random coalitions  
to estimate how much each token helps the final confidence.



In [None]:
# SHAP-style approximate explanation: sample random token coalitions and
# fit a linear surrogate model to estimate per-token Shapley values.
# Enhanced version with token decoding

import torch
import numpy as np
import matplotlib.pyplot as plt
import gc

import torch
import numpy as np
import matplotlib.pyplot as plt
import gc

def shap_approx_explanation(model, input_sequence, tokenizer, n_samples=100, method='mask',
                            mask_token_id=0, vocab_size=None, seed=None):
    """
    SHAP-style approximate explanation by sampling coalitions (subsets of tokens).

    Each sample keeps a random subset of tokens and perturbs the rest. A
    least-squares linear model ``f(z) - f(x) ~= phi_0 + z . phi`` is then fit
    over the binary keep-masks ``z``. ``phi`` estimates each token's
    contribution to the confidence of the baseline prediction. This is an
    unweighted simplification of Kernel SHAP; no Shapley kernel weights are
    applied.

    Args:
        model: The transformer model. It is called as ``model(ids)`` with a
            ``(1, seq_len)`` LongTensor and must return a 3-tuple whose first
            element is read as ``logits[0, -1, :]``.
        input_sequence: Token ids (list, 1-D numpy array or 1-D tensor).
        tokenizer: Object with a ``decode(list_of_ids)`` method.
        n_samples: Number of coalitions to sample. The regression has
            ``len(input_sequence) + 1`` unknowns, so use at least that many.
        method: How to perturb dropped tokens: 'mask' (``mask_token_id``) or
            anything else for uniformly random ids in ``[1, vocab_size)``.
        mask_token_id: Id used when masking (default 0).
        vocab_size: Exclusive upper bound for random ids. Defaults to the size of
            the model's output logits.
        seed: Optional seed for reproducible sampling. When None, numpy's global
            random state is used, as before.

    Returns:
        (shap_values, base_pred, base_conf): per-position Shapley estimates
        (numpy array), the predicted token id, and its baseline confidence.
    """
    device = next(model.parameters()).device
    model.eval()
    rng = np.random if seed is None else np.random.RandomState(seed)

    input_sequence = np.asarray([int(tok) for tok in input_sequence], dtype=np.int64)
    num_tokens = input_sequence.shape[0]

    def _next_token_probs(ids):
        """Softmax over the last position's logits for one 1-D id array."""
        # as_tensor + unsqueeze avoids the slow list-of-ndarray tensor construction.
        tensor = torch.as_tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            logits, _, _ = model(tensor)
        return torch.softmax(logits[0, -1, :], dim=-1)

    # Baseline output (all tokens as-is)
    base_probs = _next_token_probs(input_sequence)
    base_pred = torch.argmax(base_probs).item()
    base_conf = base_probs[base_pred].item()
    if vocab_size is None:
        vocab_size = base_probs.shape[-1]

    print(f"Baseline prediction: '{tokenizer.decode([base_pred])}' (token {base_pred}, confidence: {base_conf:.4f})")

    if n_samples < num_tokens + 1:
        print(f"Warning: n_samples={n_samples} is below num_tokens + 1 = {num_tokens + 1}; "
              "the regression is under-determined and the estimates will be unreliable.")

    # Random keep-masks: 1 = keep token, 0 = perturb it
    coalition_matrix = rng.randint(0, 2, size=(n_samples, num_tokens))
    outputs = np.empty(n_samples)
    mask_fill = np.full(num_tokens, mask_token_id, dtype=np.int64)

    for j, coalition in enumerate(coalition_matrix):
        replacement = mask_fill if method == 'mask' else rng.randint(1, vocab_size, size=num_tokens)
        perturbed_seq = np.where(coalition == 1, input_sequence, replacement)
        outputs[j] = _next_token_probs(perturbed_seq)[base_pred].item()

    # Least-squares fit with an intercept. phi_0 absorbs the (nonzero) change
    # when every token is perturbed; dropping it would force that change to 0
    # and bias every phi_i.
    design = np.column_stack([np.ones(n_samples), coalition_matrix])
    y = outputs - base_conf  # deviation from baseline
    coefficients, _, _, _ = np.linalg.lstsq(design, y, rcond=None)
    shap_values = coefficients[1:]

    return shap_values, base_pred, base_conf

# Example usage:
# shap_values, predicted_token, confidence = shap_approx_explanation(model, sample_sequence, sp, n_samples=200, method='mask')

def plot_explanation_results(input_sequence, importance_scores,
                             predicted_token, predicted_token_text, confidence,
                             tokenizer, top_k=15):
    """
    Plot SHAP-style explanation results with decoded tokens (two panels).

    Note: this redefinition replaces the earlier 4-panel version of the same
    name (which takes ``attention_weights`` as its third argument) once this
    cell runs.

    Panels:
        1. Importance score per position (blue = positive, red = negative).
        2. The ``top_k`` positions with the largest absolute score, labelled
           with position, decoded text and token id.

    Args:
        input_sequence: Token ids that were explained.
        importance_scores: One score per position (list or 1-D array).
        predicted_token: Predicted token id (not used by the plot; kept for
            interface compatibility).
        predicted_token_text: Decoded predicted token, shown in the title.
        confidence: Baseline confidence, shown in the title.
        tokenizer: Object with a ``decode(list_of_ids)`` method.
        top_k: Number of tokens shown in the ranking panel.
    """
    importance_scores = np.asarray(importance_scores, dtype=float)
    positions = np.arange(len(importance_scores))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

    # 1. Token importance scores
    colors = ['red' if score < 0 else 'blue' for score in importance_scores]
    ax1.bar(positions, importance_scores, color=colors, alpha=0.7)
    ax1.set_title(f'Token Importance Scores\nPredicted: "{predicted_token_text}" (conf: {confidence:.3f})')
    ax1.set_xlabel('Token Position')
    ax1.set_ylabel('Importance Score')
    ax1.grid(True, alpha=0.3)

    # 2. Top influential tokens with decoded text.
    # Reverse first, then slice, so that top_k=0 selects nothing ([-0:] would select everything).
    top_indices = np.argsort(np.abs(importance_scores))[::-1][:top_k]
    top_scores = importance_scores[top_indices]
    top_colors = ['red' if score < 0 else 'blue' for score in top_scores]

    token_labels = []
    for idx in top_indices:
        token_id = int(input_sequence[idx])
        try:
            decoded_token = tokenizer.decode([token_id])
        except Exception:
            token_labels.append(f"pos_{idx}\n(id_{token_id})")
            continue
        # Escape before truncating so the limit applies to the displayed text.
        decoded_token = decoded_token.replace('\n', '\\n').replace('\t', '\\t')
        if len(decoded_token) > 10:
            decoded_token = decoded_token[:7] + "..."
        token_labels.append(f"pos_{idx}\n'{decoded_token}'\n(id_{token_id})")

    ax2.barh(range(len(top_indices)), top_scores, color=top_colors, alpha=0.7)
    ax2.set_yticks(range(len(top_indices)))
    ax2.set_yticklabels(token_labels, fontsize=8)
    ax2.set_title(f'Top {top_k} Most Influential Tokens')
    ax2.set_xlabel('Importance Score')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def analyze_explanation(input_sequence, importance_scores, predicted_token, predicted_token_text, tokenizer, top_k=10):
    """
    Print a textual analysis of an explanation, with decoded tokens.

    Prints the sequence length, the prediction, a decoded preview of the first
    50 tokens, the summed positive and negative scores, and a table of the
    ``top_k`` positions ranked by absolute score.

    Args:
        input_sequence: Token ids that were explained (list, array or tensor).
        importance_scores: One score per position (list or 1-D array).
        predicted_token: Predicted token id (printed).
        predicted_token_text: Decoded predicted token (printed).
        tokenizer: Object with a ``decode(list_of_ids)`` method.
        top_k: Number of rows in the ranking table.

    Returns:
        dict with keys ``top_indices`` (positions sorted by descending |score|,
        truncated to ``top_k``), ``positive_influence`` and
        ``negative_influence`` (sums of positive / negative scores), and
        ``total_influence`` (sum of all scores).
    """
    importance_scores = np.asarray(importance_scores, dtype=float)
    # Plain Python ints: some tokenizers reject numpy integer types.
    token_ids = [int(tok) for tok in input_sequence]
    total_influence = np.sum(importance_scores)

    print(f"\n{'='*60}")
    print("EXPLANATION ANALYSIS")
    print(f"{'='*60}")

    print(f"Input sequence length: {len(token_ids)}")
    print(f"Predicted next token: '{predicted_token_text}' (ID: {predicted_token})")
    print(f"Total importance score: {total_influence:.4f}")

    # Show a portion of the decoded input sequence
    try:
        decoded_input = tokenizer.decode(token_ids[:50])  # First 50 tokens
    except Exception:
        print("Could not decode input sequence")
    else:
        print(f"Input text (first ~50 tokens): '{decoded_input[:200]}{'...' if len(decoded_input) > 200 else ''}'")

    # Positive and negative influences
    positive_influence = np.sum(importance_scores[importance_scores > 0])
    negative_influence = np.sum(importance_scores[importance_scores < 0])

    print(f"Positive influence: {positive_influence:.4f}")
    print(f"Negative influence: {negative_influence:.4f}")

    # Top influential tokens
    sorted_indices = np.argsort(np.abs(importance_scores))[::-1]

    print(f"\nTop {top_k} most influential tokens:")
    print(f"{'Rank':<4} {'Position':<8} {'Token Text':<15} {'Token ID':<8} {'Score':<10} {'Type'}")
    print("-" * 70)

    for rank, idx in enumerate(sorted_indices[:top_k], start=1):
        token_id = token_ids[idx]
        try:
            token_text = tokenizer.decode([token_id])
        except Exception:
            token_text = "<?>"
        else:
            # Escape before truncating so the displayed text never exceeds 12 chars;
            # otherwise the dot-leader width below could go negative and raise ValueError.
            token_text = token_text.replace('\n', '\\n').replace('\t', '\\t')
            if len(token_text) > 12:
                token_text = token_text[:9] + "..."

        score = importance_scores[idx]
        if score > 0:
            influence_type = "Helpful"
        elif score < 0:
            influence_type = "Harmful"
        else:
            influence_type = "Neutral"

        leader = '.' * max(0, 15 - len(token_text) - 2)
        print(f"{rank:<4} {idx:<8} '{token_text}'{leader} {token_id:<8} {score:<10.4f} {influence_type}")

    return {
        'top_indices': sorted_indices[:top_k],
        'positive_influence': positive_influence,
        'negative_influence': negative_influence,
        'total_influence': total_influence
    }


# Driver for the SHAP-style explanation. It relies on notebook globals from
# earlier cells: `model`, `encoded` (token ids with a .tolist() method) and
# `sp` (SentencePiece processor).

# Clear GPU memory before the many forward passes below
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

print("Starting SHAP-style approximate explanation with token decoding...")

# Alternative input: explain a window of the training data instead of `encoded`
# (needs `train_token_ids` and `block_size` from earlier cells).
# sample_idx = 100
# sample_sequence = train_token_ids[sample_idx:sample_idx+block_size].tolist()

sample_sequence = encoded.tolist()
print(f"Sample sequence length: {len(sample_sequence)}")
print(f"First 10 tokens: {sample_sequence[:10]}")

# Show decoded version of first few tokens
try:
    decoded_sample = sp.decode(sample_sequence[:20])
except Exception:
    print("Could not decode sample tokens")
else:
    print(f"First ~20 tokens decoded: '{decoded_sample}'")

print("\nComputing token importance scores...")

# The coalition regression has roughly one unknown per token (plus an
# intercept), so it needs more samples than tokens. Use twice that, with the
# previous default of 100 as a floor.
n_samples = max(100, 2 * (len(sample_sequence) + 1))

# Compute SHAP-style explanation
importance_scores, predicted_token, confidence = shap_approx_explanation(
    model,
    sample_sequence,
    tokenizer=sp,
    n_samples=n_samples,
    method='mask'
)
predicted_token_text = sp.decode([predicted_token])

# Plot
plot_explanation_results(
    sample_sequence,
    importance_scores,
    predicted_token,
    predicted_token_text,
    confidence,
    tokenizer=sp,
    top_k=15
)

# Analyze
analysis = analyze_explanation(
    sample_sequence,
    importance_scores,
    predicted_token,
    predicted_token_text,
    tokenizer=sp,
    top_k=10
)

# Clean up memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

print("\nAnalysis complete!")
print("This method shows which tokens in your input sequence most influence the model's next token prediction.")
print("Positive scores = helpful for the prediction, Negative scores = harmful for the prediction")
print("Now with human-readable token text for better interpretation!")