## Modifying Residual Connections

Testing code for accessing the representation at a given layer **before applying the residual connection**.


In [67]:
import torch
from transformers import BertModel, BertConfig
from transformers.models.bert.modeling_bert import BertLayer
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions


import torch
import torch.nn as nn
from transformers import BertConfig

from scipy.stats import pearsonr


## ChatGPT Version

In [68]:
#### ChatGPT version
# Define a custom model class
class CustomBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
    
    def forward(self, *args, **kwargs):
        # Call the original forward method to get the standard outputs
        outputs = super().forward(*args, **kwargs)
        
        # Extract hidden states from the original forward pass
        hidden_states = outputs.hidden_states  # List of hidden states at each layer
        
        # Lists to store pre-residual and post-residual outputs of each layer
        pre_residual_outputs = []
        post_residual_outputs = []
        
        # Iterate through each layer to compute pre-residual and post-residual outputs
        for i, layer in enumerate(self.encoder.layer):
            # Compute self-attention output (before adding residual connection)
            self_attention_outputs = layer.attention(hidden_states[i], output_attentions=False)
            attention_output = self_attention_outputs[0]
            pre_residual_attention_output = attention_output  # Save pre-residual attention output
            
            # Apply attention output with residual connection and normalization
            attention_output = layer.attention.output.LayerNorm(attention_output + hidden_states[i])
            
            # Compute feedforward output (before adding residual connection)
            intermediate_output = layer.intermediate(attention_output)
            layer_output = layer.output.dense(intermediate_output)
            pre_residual_layer_output = layer_output  # Save pre-residual layer output
            
            # Apply feedforward output with residual connection and normalization
            layer_output = layer.output.LayerNorm(layer_output + attention_output)
            
            # Store the pre-residual and post-residual outputs
            pre_residual_outputs.append(pre_residual_layer_output)
            post_residual_outputs.append(layer_output)
        
        # Return the standard outputs, pre-residual outputs, and post-residual outputs
        return outputs, pre_residual_outputs, post_residual_outputs

# Load the pre-trained model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = CustomBertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)

# Tokenize the input text
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors='pt')

# Forward pass
outputs, pre_residual_outputs, post_residual_outputs = model(**inputs)

# Extract the post-residual outputs and hidden states for a specific layer
layer_index = 5  # Example layer index
post_residual_layer_output = post_residual_outputs[layer_index][0]
pre_residual_layer_output = pre_residual_outputs[layer_index][0]
hidden_state_layer_output = outputs.hidden_states[layer_index + 1][0]  # hidden_states includes embedding layer as the first element

# Compare the post-residual outputs to the hidden states
print("Post-Residual Output for Layer", layer_index)
print(post_residual_layer_output)

print("\nHidden State Output for Layer", layer_index)
print(hidden_state_layer_output)

# Check if they are equal
print("\nAre the outputs equal?", torch.allclose(post_residual_layer_output, hidden_state_layer_output, atol=1e-6))

# Also print the pre-residual output for clarity
print("\nPre-Residual Output for Layer", layer_index)
print(pre_residual_layer_output)

Post-Residual Output for Layer 5
tensor([[ 0.0367, -0.3417, -0.3950,  ...,  0.0788,  0.1991,  0.3832],
        [ 0.9718,  0.1975,  0.7590,  ...,  0.6820,  0.1647, -0.4499],
        [ 0.1348,  0.4053,  0.0572,  ..., -0.2338, -0.1639,  0.6239],
        ...,
        [-0.2164, -1.4882,  0.8183,  ..., -0.1463, -0.0400,  0.0578],
        [ 0.3149, -0.7261, -0.5962,  ...,  0.0103,  0.0065,  0.6139],
        [ 0.0138, -0.0125, -0.0064,  ...,  0.0026, -0.0205, -0.0142]],
       grad_fn=<SelectBackward0>)

Hidden State Output for Layer 5
tensor([[ 0.0839, -0.9179, -0.5961,  ...,  0.0639,  0.3273,  0.3681],
        [ 0.7318, -0.1368,  1.1241,  ...,  0.6588, -0.2409, -0.5797],
        [-0.1293,  0.5957,  0.2409,  ..., -0.5293, -0.3646,  0.5766],
        ...,
        [-0.3583, -1.9956,  0.9719,  ..., -0.1385,  0.1476, -0.3508],
        [ 0.3562, -1.0461, -0.9223,  ...,  0.0445, -0.2199,  0.5077],
        [ 0.0170, -0.0387, -0.0153,  ...,  0.0079, -0.0167, -0.0436]],
       grad_fn=<SelectBackward0>

In [69]:
# Compute the overall correlation between the two tensors using scipy
post_residual_flat = post_residual_layer_output.flatten().detach().numpy()
hidden_state_flat = hidden_state_layer_output.flatten().detach().numpy()

correlation, _ = pearsonr(post_residual_flat, hidden_state_flat)
print("Pearson Correlation between the post/hidden tensors:", correlation)

pre_residual_flat = pre_residual_layer_output.flatten().detach().numpy()

correlation, _ = pearsonr(pre_residual_flat, hidden_state_flat)
print("Pearson Correlation between the pre/hidden tensors:", correlation)

correlation, _ = pearsonr(pre_residual_flat, post_residual_flat)
print("Pearson Correlation between the pre/post tensors:", correlation)

Pearson Correlation between the post/hidden tensors: 0.9362243522476544
Pearson Correlation between the pre/hidden tensors: 0.281897231078784
Pearson Correlation between the pre/post tensors: 0.2666938847810285


In [70]:
# Check if the two tensors have the same sign for each element
same_sign = torch.sign(post_residual_layer_output) == torch.sign(hidden_state_layer_output)

# Compute the proportion of elements with the same sign
proportion_same_sign = torch.mean(same_sign.float())
print("Proportion of post/hidden elements with Same Sign:", proportion_same_sign.item())

same_sign = torch.sign(pre_residual_layer_output) == torch.sign(hidden_state_layer_output)

# Compute the proportion of elements with the same sign
proportion_same_sign = torch.mean(same_sign.float())
print("Proportion of pre/hidden elements with Same Sign:", proportion_same_sign.item())

same_sign = torch.sign(pre_residual_layer_output) == torch.sign(post_residual_layer_output)

# Compute the proportion of elements with the same sign
proportion_same_sign = torch.mean(same_sign.float())
print("Proportion of pre/post elements with Same Sign:", proportion_same_sign.item())

Proportion of post/hidden elements with Same Sign: 0.8889973759651184
Proportion of pre/hidden elements with Same Sign: 0.60009765625
Proportion of pre/post elements with Same Sign: 0.6072590947151184


In [73]:
print("Post residual shape:", post_residual_layer_output.shape)
print("Pre residual shape:", pre_residual_layer_output.shape)
print("Hidden state shape:", hidden_state_layer_output.shape)

Post residual shape: torch.Size([8, 768])
Pre residual shape: torch.Size([8, 768])
Hidden state shape: torch.Size([8, 768])


## Claude Version

In [86]:
class CustomBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
    
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        # Call the original forward method to get the standard outputs
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_hidden_states=True,  # Enable output of all hidden states
        )
        
        # Extract hidden states from the original forward pass
        hidden_states = outputs.hidden_states  # List of hidden states at each layer
        
        # Lists to store pre-residual and post-residual outputs of each layer
        pre_residual_outputs = []
        post_residual_outputs = []
        
        # Iterate through each layer to compute pre-residual and post-residual outputs
        for i, layer in enumerate(self.encoder.layer):
            # Compute self-attention output (before adding residual connection)
            self_attention_outputs = layer.attention(hidden_states[i], attention_mask, head_mask[i] if head_mask is not None else None, output_attentions=False)
            attention_output = self_attention_outputs[0]
            
            # Apply attention output with residual connection and normalization
            attention_output = layer.attention.output.LayerNorm(attention_output + hidden_states[i])
            
            # Compute feedforward output (before adding residual connection)
            intermediate_output = layer.intermediate(attention_output)
            layer_output = layer.output.dense(intermediate_output)
            pre_residual_layer_output = layer_output  # Save pre-residual layer output
            
            # Apply feedforward output with residual connection and normalization
            layer_output = layer.output.LayerNorm(layer_output + attention_output)
            
            # Store the pre-residual and post-residual outputs
            pre_residual_outputs.append(pre_residual_layer_output)
            post_residual_outputs.append(layer_output)
        
        # Create a new output object with all the original outputs plus our new pre_residual and post_residual outputs
        new_outputs = BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=outputs.pooler_output,
            hidden_states=outputs.hidden_states,
            past_key_values=outputs.past_key_values,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
        
        new_outputs.pre_residual_outputs = pre_residual_outputs
        new_outputs.post_residual_outputs = post_residual_outputs

        return new_outputs

# Usage example
config = BertConfig.from_pretrained('bert-base-uncased')
model = CustomBertModel.from_pretrained('bert-base-uncased', config=config)

# Forward pass
input_ids = torch.tensor([[1, 2, 3, 4, 5]])
outputs = model(input_ids)

# Compare outputs
def compare_outputs(layer_idx):
    pre_residual = outputs.pre_residual_outputs[layer_idx]
    post_residual = outputs.post_residual_outputs[layer_idx]
    current_hidden_state = outputs.hidden_states[layer_idx]  # This is the input to the current layer
    next_hidden_state = outputs.hidden_states[layer_idx + 1]  # This is the output of the current layer

    print(f"Layer {layer_idx}:")
    print(f"Pre-residual shape: {pre_residual.shape}")
    print(f"Post-residual shape: {post_residual.shape}")
    print(f"Current hidden state shape: {current_hidden_state.shape}")
    print(f"Next hidden state shape: {next_hidden_state.shape}")
    
    print("\nComparisons:")
    print(f"Post-residual == Next hidden state: {torch.allclose(post_residual, next_hidden_state, atol=1e-5)}")
    print(f"Current hidden state == Input to layer: {torch.allclose(current_hidden_state, outputs.hidden_states[layer_idx], atol=1e-5)}")
    print(f"Pre-residual == Post-residual: {torch.allclose(pre_residual, post_residual, atol=1e-5)}")
    print(f"Mean difference (Post - Pre): {(post_residual - pre_residual).mean().item():.4f}")
    print(f"Std of difference (Post - Pre): {(post_residual - pre_residual).std().item():.4f}")
    print("\n")


    print("\nCorrelations:")
    # Compute the overall correlation between the two tensors using scipy
    post_residual_flat = post_residual.flatten().detach().numpy()
    next_hidden_state_flat = next_hidden_state.flatten().detach().numpy()
    current_hidden_state_flat = current_hidden_state.flatten().detach().numpy()
    pre_residual_flat = pre_residual.flatten().detach().numpy()
    
    post_next_hidden_correlation, _ = pearsonr(post_residual_flat, next_hidden_state_flat)
    post_current_hidden_correlation, _ = pearsonr(post_residual_flat, current_hidden_state_flat)
    pre_post_correlation, _ = pearsonr(post_residual_flat, pre_residual_flat)
    pre_next_hidden_correlation, _ = pearsonr(pre_residual_flat, next_hidden_state_flat)
    pre_current_hidden_correlation, _ = pearsonr(pre_residual_flat, current_hidden_state_flat)
    
    print("Pearson Correlation between the post/next-hidden tensors:", post_next_hidden_correlation)
    print("Pearson Correlation between the post/current-hidden tensors:", post_current_hidden_correlation)
    print("Pearson Correlation between the pre/next-hidden tensors:", pre_next_hidden_correlation)
    print("Pearson Correlation between the pre/current-hidden tensors:", pre_current_hidden_correlation)
    print("Pearson Correlation between the pre/post tensors:", pre_post_correlation)
    print("\n")

# Compare for first and last layer
compare_outputs(0)  # First layer
compare_outputs(len(outputs.hidden_states) - 2) 

Layer 0:
Pre-residual shape: torch.Size([1, 5, 768])
Post-residual shape: torch.Size([1, 5, 768])
Current hidden state shape: torch.Size([1, 5, 768])
Next hidden state shape: torch.Size([1, 5, 768])

Comparisons:
Post-residual == Next hidden state: False
Current hidden state == Input to layer: True
Pre-residual == Post-residual: False
Mean difference (Post - Pre): 0.0108
Std of difference (Post - Pre): 0.9014



Correlations:
Pearson Correlation between the post/next-hidden tensors: 0.903245657406737
Pearson Correlation between the post/current-hidden tensors: 0.6284135594108322
Pearson Correlation between the pre/next-hidden tensors: 0.42370215965868574
Pearson Correlation between the pre/current-hidden tensors: -0.13993228592411655
Pearson Correlation between the pre/post tensors: 0.3484587061685561


Layer 11:
Pre-residual shape: torch.Size([1, 5, 768])
Post-residual shape: torch.Size([1, 5, 768])
Current hidden state shape: torch.Size([1, 5, 768])
Next hidden state shape: torch.Siz

In [75]:

# Usage example
config = BertConfig.from_pretrained('bert-base-uncased')
model = CustomBertModel.from_pretrained('bert-base-uncased', config=config)

# Forward pass
input_ids = torch.tensor([[1, 2, 3, 4, 5]])
outputs = model(input_ids)



In [76]:
# Compare outputs
def compare_outputs(layer_idx):
    intermediate_output = outputs.intermediate_outputs[layer_idx]
    layer_output = outputs.layer_outputs[layer_idx]
    hidden_state = outputs.hidden_states[layer_idx + 1]  # +1 because hidden_states includes the embedding layer
    attention_output = outputs.attention_outputs[layer_idx]

    print(f"Layer {layer_idx}:")
    print(f"Intermediate (FFN) output shape: {intermediate_output.shape}")
    print(f"Layer output shape: {layer_output.shape}")
    print(f"Hidden state shape: {hidden_state.shape}")
    print(f"Attention output shape: {attention_output.shape}")
    
    print("\nComparisons:")
    print(f"Layer output == Hidden state: {torch.allclose(layer_output, hidden_state)}")
    print(f"Layer output == Attention output: {torch.allclose(layer_output, attention_output)}")
    print(f"Intermediate output size vs Layer output size: {intermediate_output.shape[2]} vs {layer_output.shape[2]}")
    print("\n")

# Compare for first and last layer
compare_outputs(0)  # First layer
compare_outputs(-1)  # Last layer


Layer 0:
Intermediate (FFN) output shape: torch.Size([1, 5, 3072])
Layer output shape: torch.Size([1, 5, 768])
Hidden state shape: torch.Size([1, 5, 768])
Attention output shape: torch.Size([1, 5, 768])

Comparisons:
Layer output == Hidden state: False
Layer output == Attention output: False
Intermediate output size vs Layer output size: 3072 vs 768


Layer -1:
Intermediate (FFN) output shape: torch.Size([1, 5, 3072])
Layer output shape: torch.Size([1, 5, 768])
Hidden state shape: torch.Size([1, 5, 768])
Attention output shape: torch.Size([1, 5, 768])

Comparisons:
Layer output == Hidden state: False
Layer output == Attention output: False
Intermediate output size vs Layer output size: 3072 vs 768




In [77]:
# Additional analysis
def analyze_layer_transformation(layer_idx):
    intermediate_output = outputs.intermediate_outputs[layer_idx]
    layer_output = outputs.layer_outputs[layer_idx]
    attention_output = outputs.attention_outputs[layer_idx]

    print(f"Layer {layer_idx} Transformation Analysis:")
    print(f"Mean of intermediate output: {intermediate_output.mean().item():.4f}")
    print(f"Std of intermediate output: {intermediate_output.std().item():.4f}")
    print(f"Mean of layer output: {layer_output.mean().item():.4f}")
    print(f"Std of layer output: {layer_output.std().item():.4f}")
    print(f"Mean difference (layer - attention): {(layer_output - attention_output).mean().item():.4f}")
    print(f"Std of difference (layer - attention): {(layer_output - attention_output).std().item():.4f}")
    print("\n")

# Analyze transformations for first and last layer
analyze_layer_transformation(0)  # First layer
analyze_layer_transformation(-1)  # Last layer

Layer 0 Transformation Analysis:
Mean of intermediate output: -0.0204
Std of intermediate output: 0.1660
Mean of layer output: -0.0235
Std of layer output: 0.6593
Mean difference (layer - attention): 0.0065
Std of difference (layer - attention): 0.7091


Layer -1 Transformation Analysis:
Mean of intermediate output: -0.0853
Std of intermediate output: 0.1275
Mean of layer output: -0.0082
Std of layer output: 0.4493
Mean difference (layer - attention): 0.0426
Std of difference (layer - attention): 0.8514


