In [None]:
# Helper that returns the per-layer attention maps of a BERT-style model
def extract_attention(model, encoded_input):
    """Return the attention weights produced by a single forward pass.

    Args:
        model: Model invoked as ``model(**encoded_input, output_attentions=True)``;
            its output must expose an ``attentions`` attribute.
        encoded_input: Mapping of keyword arguments for the model, e.g. the
            output of the tokenizer.

    Returns:
        ``outputs.attentions`` unchanged -- in this notebook, a tuple with one
        tensor per layer.
    """
    # Inspection only, so disable autograd tracking for the forward pass.
    with torch.no_grad():
        model_output = model(output_attentions=True, **encoded_input)
    return model_output.attentions

# Run the sample sentence through the model and keep every layer's attention
attention_weights = extract_attention(model, sample_input)

# Report how many layers were captured and the shape of a single layer's tensor
print(
    f"Extracted attention weights from {len(attention_weights)} layers",
    f"Shape of attention weights for first layer: {attention_weights[0].shape}",
    sep="\n",
)

## 6. Visualize attention patterns for a specific layer

In [None]:
# Select a layer and a head to visualize (e.g., the first layer, first head)
layer_idx = 0
head_idx = 0
# Attention of the first (only) sequence in the batch, indexed below as
# [head][query position, key position]. .cpu() is a no-op for CPU tensors and
# lets this cell work when the model lives on a GPU.
layer_attention = attention_weights[layer_idx][0].detach().cpu().numpy()

# Get the tokens for visualization
tokens = tokenizer.convert_ids_to_tokens(sample_input['input_ids'][0])

# Limit to the first 15 tokens for readability. The matrix is cropped to the
# same window as the tick labels so that every row/column drawn is labelled.
max_tokens = min(15, len(tokens))

# Plot the attention map
plt.figure(figsize=(10, 8))
plt.imshow(layer_attention[head_idx][:max_tokens, :max_tokens], cmap='viridis')
plt.colorbar(label='Attention Weight')
plt.title(f'Attention Map for Layer {layer_idx}, Head {head_idx}')
plt.xticks(range(max_tokens), tokens[:max_tokens], rotation=90)
plt.yticks(range(max_tokens), tokens[:max_tokens])
plt.xlabel('Attended-to token (key)')
plt.ylabel('Attending token (query)')

plt.tight_layout()
plt.show()

## 7. Compare attention patterns across different heads

In [None]:
# Compare attention patterns across all heads of the selected layer.
# Use 4 columns and as many rows as the head count requires, so no head is
# silently left out (a fixed 2x4 grid only ever showed the first 8 heads).
n_heads = layer_attention.shape[0]
n_cols = 4
n_rows = max(1, -(-n_heads // n_cols))  # ceiling division without importing math
# squeeze=False keeps `axes` 2-D even for a single row, so flatten() always works
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < n_heads:
        ax.imshow(layer_attention[i][:max_tokens, :max_tokens], cmap='viridis')
        ax.set_title(f'Head {i}')

        # Only show token labels for the first head to avoid clutter
        if i == 0:
            ax.set_xticks(range(max_tokens))
            ax.set_yticks(range(max_tokens))
            ax.set_xticklabels(tokens[:max_tokens], rotation=90)
            ax.set_yticklabels(tokens[:max_tokens])
        else:
            ax.set_xticks([])
            ax.set_yticks([])
    else:
        # Hide unused cells in the last row of the grid
        ax.axis('off')

plt.tight_layout()
plt.suptitle(f'Attention Patterns Across Different Heads in Layer {layer_idx}', y=1.02)
plt.show()

## 8. Build a surrogate model for the classifier layer

In [None]:
# A surrogate cannot be fit from one or two examples, so add a few extra
# sentences to the analysis set.
additional_texts = [
    "This movie exceeded all my expectations.",
    "I was disappointed by the quality of the product.",
    "The service was satisfactory, could be better.",
    "An amazing experience from start to finish!",
    "I regret purchasing this item, it broke quickly."
]

# Combine the original samples with the extra ones and encode each sentence
# on its own, padded/truncated to 128 tokens and returned as PyTorch tensors.
all_texts = sample_texts + additional_texts
all_encoded = [
    tokenizer(
        sentence,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    for sentence in all_texts
]

# Extract pooler output and predictions for all samples
from layerlens.core.layer_extractor import LayerExtractor

# Create a custom extractor function for BERT
def extract_bert_outputs(model, inputs_list):
    """Collect pooler activations and classifier logits for several inputs.

    Each element of ``inputs_list`` is passed separately through ``model.bert``.
    The pooled output of that pass is then fed to ``model.classifier``. Any
    dropout the full model might apply between the two is not applied here.

    Args:
        model: Sequence-classification model exposing ``.bert`` (whose output
            has ``pooler_output``) and ``.classifier`` submodules.
        inputs_list: Non-empty list of tokenizer outputs (dicts of tensors),
            one per text.

    Returns:
        dict with keys ``'bert.pooler'`` and ``'classifier'``. Each maps to a
        NumPy array built by ``np.vstack`` over the per-input results.

    Raises:
        ValueError: If ``inputs_list`` is empty.
    """
    if len(inputs_list) == 0:
        raise ValueError("inputs_list must contain at least one encoded input")

    pooler_outputs = []
    predictions = []

    with torch.no_grad():
        for encoded_input in inputs_list:
            bert_output = model.bert(**encoded_input)
            # The BERT forward pass already applies its pooler to the final
            # hidden states; reuse that instead of running the pooler again.
            pooler_output = bert_output.pooler_output
            logits = model.classifier(pooler_output)

            # .cpu() is a no-op for CPU tensors and avoids a TypeError from
            # .numpy() when the model lives on a GPU.
            pooler_outputs.append(pooler_output.cpu().numpy())
            predictions.append(logits.cpu().numpy())

    return {
        'bert.pooler': np.vstack(pooler_outputs),
        'classifier': np.vstack(predictions),
    }

# Fit a linear surrogate that maps the pooled BERT representation to the
# classifier's logits, using activations collected from every encoded sentence.
from layerlens.core.surrogate_builder import SurrogateBuilder

# Pooler activations (surrogate inputs) and classifier logits (surrogate targets)
layer_outputs = extract_bert_outputs(model, all_encoded)

surrogate_builder = SurrogateBuilder(surrogate_type='linear')
classifier_surrogate = surrogate_builder.fit(
    'classifier',
    layer_outputs['bert.pooler'],
    layer_outputs['classifier'],
)

print("Built surrogate model for the classifier layer")

## 9. Analyze feature importance from the surrogate model

In [None]:
# Inspect which pooler features the surrogate relies on most.
if hasattr(classifier_surrogate, 'coef_'):
    from layerlens.utils.plot_utils import plot_feature_importance

    # For a linear surrogate, coefficient magnitudes act as feature importances.
    # Row 1 is assumed to correspond to the positive-sentiment logit
    # (classifier output index 1) -- confirm against the model's label mapping.
    positive_coefs = classifier_surrogate.coef_[1]

    importance_fig = plot_feature_importance(
        np.abs(positive_coefs),  # Use absolute value of coefficients
        feature_names=None,
        top_n=20
    )
    plt.title('Feature Importance for Positive Sentiment')
    plt.show()
else:
    # Make the skip visible instead of the cell silently producing no output.
    print(
        f"Surrogate of type {type(classifier_surrogate).__name__} has no 'coef_' "
        "attribute; skipping coefficient-based feature importance."
    )

## 10. Analyze token contributions to the prediction

In [None]:
# Create a simplified token attribution analysis
def analyze_token_importance(model, tokenizer, text, layer_idx=11):
    """Score tokens by how much attention the [CLS] position pays to them.

    This is a heuristic: head-averaged attention from [CLS] in one layer is
    used as a proxy for token influence. It is not a faithful attribution of
    the prediction.

    Args:
        model: Model called as ``model(**encoded, output_attentions=True)``;
            its output must expose ``.attentions`` (one tensor per layer).
        tokenizer: Tokenizer used to encode ``text`` and map ids back to tokens.
        text: Input sentence to analyze.
        layer_idx: Index into the per-layer attentions. The default 11 is
            presumably meant as the last layer of a 12-layer BERT-base model.

    Returns:
        (token_importance, tokens, cls_attention):
        - token_importance: dict mapping each token string (excluding '[PAD]',
          '[CLS]', '[SEP]') to its [CLS] attention. Tokens occurring several
          times accumulate the weight of every occurrence.
        - tokens: all tokens of the padded (max_length=128) sequence.
        - cls_attention: head-averaged attention row of position 0, aligned
          with ``tokens``.
    """
    # Tokenize the text
    encoded = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )

    # Extract attention weights
    with torch.no_grad():
        outputs = model(**encoded, output_attentions=True)
        attentions = outputs.attentions

    # Average over heads for the first (only) sequence in the batch; the result
    # is indexed [query position, key position]. .cpu() makes this GPU-safe.
    layer_attention = attentions[layer_idx][0].mean(dim=0).detach().cpu().numpy()

    # Row 0 is the attention distribution of the first position ([CLS])
    cls_attention = layer_attention[0]

    # Get tokens
    tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])

    # Sum over repeated tokens so a word that appears twice is not silently
    # reduced to the weight of its last occurrence.
    special_tokens = {'[PAD]', '[CLS]', '[SEP]'}
    token_importance = {}
    for position, token in enumerate(tokens):
        if token in special_tokens:
            continue
        if token in token_importance:
            token_importance[token] += cls_attention[position]
        else:
            token_importance[token] = cls_attention[position]

    return token_importance, tokens, cls_attention

# Analyze our first sample text
token_importance, tokens, cls_attention = analyze_token_importance(
    model, tokenizer, sample_texts[0]
)

# Plot only real tokens. The input is padded to 128 positions, and plotting the
# padding would clutter the chart with [PAD] bars and labels.
visible_idx = [i for i, token in enumerate(tokens) if token != '[PAD]']

plt.figure(figsize=(12, 6))
plt.bar(range(len(visible_idx)), cls_attention[visible_idx])
plt.xticks(range(len(visible_idx)), [tokens[i] for i in visible_idx], rotation=90)
# The values are row 0 of the attention matrix, i.e. attention *from* [CLS]
plt.title('Token Importance based on Attention from the [CLS] Token')
plt.xlabel('Token')
plt.ylabel('Attention Weight')
plt.tight_layout()
plt.show()

# Print top important tokens
print("Top 10 important tokens:")
sorted_tokens = sorted(token_importance.items(), key=lambda x: x[1], reverse=True)
for token, importance in sorted_tokens[:10]:
    print(f"{token}: {importance:.4f}")

## 11. Compare different text samples

In [None]:
# Compare predictions and explanations across different samples
predictions = []
top_tokens = []

for text in sample_texts:
    # Get prediction
    encoded = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )

    with torch.no_grad():
        outputs = model(**encoded)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # Class index 1 is treated as positive sentiment throughout this notebook
    predictions.append(probabilities[0, 1].item())

    # Get the five tokens [CLS] attends to most
    token_importance, _, _ = analyze_token_importance(model, tokenizer, text)
    top_tokens.append(
        sorted(token_importance.items(), key=lambda item: item[1], reverse=True)[:5]
    )

# Plot comparison
plt.figure(figsize=(10, 6))
bars = plt.bar(range(len(sample_texts)), predictions)

# Color-code by sentiment
for bar, prob in zip(bars, predictions):
    bar.set_color('green' if prob > 0.5 else 'red')

plt.xticks(range(len(sample_texts)), [f"Sample {i+1}" for i in range(len(sample_texts))])
plt.ylabel('Positive Sentiment Probability')
plt.title('Sentiment Predictions Across Samples')
plt.axhline(y=0.5, color='black', linestyle='--')
plt.ylim(0, 1)

for i, prob in enumerate(predictions):
    # Keep the label inside the [0, 1] y-range for very confident predictions
    label_y = prob + 0.05 if prob <= 0.9 else prob - 0.08
    plt.text(i, label_y, f"{prob:.2f}", ha='center')

plt.tight_layout()
plt.show()

# Display top tokens for each sample. A dedicated loop variable is used so the
# global `tokens` list (used by the attention heatmap cells) is not overwritten.
for i, (text, sample_top_tokens) in enumerate(zip(sample_texts, top_tokens)):
    print(f"\nSample {i+1}: {text}")
    print("Top influential tokens:")
    for token, importance in sample_top_tokens:
        print(f"  {token}: {importance:.4f}")

## Conclusion

In this notebook, we've demonstrated how to use LayerLens to explain a transformer-based NLP model (BERT). We've:

1. Extracted attention weights from every layer and visualized the attention maps of individual heads in a selected layer
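2. Built a linear surrogate model to explain the classifier layer (fit on only a handful of sentences with high-dimensional pooler features, so it is illustrative rather than statistically reliable)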
3. Analyzed feature importance for sentiment predictions
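4. Used the [CLS] token's attention (layer 11 by default) as a heuristic indicator of which tokens contribute most to predictions; attention weights alone are not a faithful attribution method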

LayerLens helps us understand how transformer models process text and make predictions by providing layer-by-layer insights into their behavior.