# Looking into the downstream impacts of L5H1

Context: I was looking into the aggregate size of the compensatory responses that downstream heads produce for upstream heads, and it turns out that not only does L5H1 (layer 5, head 1) not get backed up, but it can actively make lower heads worse at what they are doing. See image.png.



In [3]:
from imports import *

In [4]:
# Model under study. The weight-processing flags are spelled out explicitly so
# the configuration behind every result in this notebook is visible here.
model_name = "gpt2-small"
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=False,
)
# Presumably required so that per-head outputs are cached for the
# `cache.stack_head_results(...)` call used later in this file — TODO confirm.
model.set_use_attn_result(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


Get the dataset

In [5]:
BATCH_SIZE = 50
PROMPT_LEN = 100
owt_dataset = utils.get_dataset("owt")

# Tokenize 2 * BATCH_SIZE documents in one call: the first BATCH_SIZE rows are
# the clean prompts, the next BATCH_SIZE rows are the corrupted prompts used
# for sample ablation. Every prompt is truncated to PROMPT_LEN tokens.
all_owt_tokens = model.to_tokens(owt_dataset[0:2 * BATCH_SIZE]["text"]).to(device)
owt_tokens = all_owt_tokens[:BATCH_SIZE, :PROMPT_LEN]
corrupted_owt_tokens = all_owt_tokens[BATCH_SIZE:2 * BATCH_SIZE, :PROMPT_LEN]

# Both batches must have exactly the same (batch, position) layout.
assert owt_tokens.shape == corrupted_owt_tokens.shape
assert corrupted_owt_tokens.shape == (BATCH_SIZE, PROMPT_LEN)

Found cached dataset openwebtext-10k (/data/cody_rushing/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)


In [6]:
# Clean forward pass. `cache` is reused as a global by the functions below.
logits, cache = model.run_with_cache(owt_tokens)

# Report accuracy first, then cross-entropy loss, one value per line.
print(
    utils.lm_accuracy(logits, owt_tokens),
    utils.lm_cross_entropy_loss(logits, owt_tokens),
    sep="\n",
)

tensor(0.3855, device='cuda:0')
tensor(3.3547, device='cuda:0')


In [8]:
def topk_of_Nd_tensor(tensor: Float[Tensor, "rows cols"], k: int):
    '''
    Return the coordinates of the k largest entries of a multi-dimensional tensor.

    The tensor is flattened, torch.topk picks the k largest values, and the flat
    positions are mapped back to per-dimension coordinates using tensor.shape.

    Returns a nested list of shape [k, tensor.ndim], in the order torch.topk
    returned the entries.

    Example: for a [layer, head] tensor of scores, each entry of the result is a
    [layer, head] pair.
    '''
    flat_positions = torch.topk(tensor.flatten(), k).indices
    # One coordinate array per dimension, each of length k.
    coords_per_dim = np.unravel_index(utils.to_numpy(flat_positions), tensor.shape)
    # Transpose [ndim, k] -> [k, ndim] so that each row is one full coordinate.
    return np.transpose(np.asarray(coords_per_dim)).tolist()

def residual_stack_to_direct_effect(
    residual_stack: Union[Float[Tensor, "... batch d_model"], Float[Tensor, "... batch pos d_model"]],
    cache: ActivationCache,
    effect_directions: Union[Float[Tensor, "batch d_model"], Float[Tensor, "batch pos d_model"]],
    batch_pos_dmodel = False,
    average_across_batch = True,
    apply_ln = False, # original author's note: "this is broken rn idk" - verify before relying on it
    use_clean_cache_for_ln = True,
    clean_cache = cache
) -> Float[Tensor, "..."]:
    '''
    Projects a stack of residual-stream components onto per-example effect
    directions (a dot product over d_model) to get each component's direct effect.

    NOTE: when batch_pos_dmodel is True, the last position of the stack and the
    first position of effect_directions are dropped so that the output at
    position i is compared against the direction of the token at position i + 1.
    The returned pos axis is therefore one shorter than the input.

    residual_stack: components of d_model vectors to get direct effect from;
        the batch dimension goes in front of the pos dimension
    cache: cache of activations from the model; only used for layer norm when
        apply_ln is True and use_clean_cache_for_ln is False
    effect_directions: [batch, d_model] or [batch, pos, d_model] direction vectors
    batch_pos_dmodel: True if the trailing dims are [batch, pos, d_model],
        False if they are [batch, d_model]; if True, pos is kept in the output
    average_across_batch: if True, sum over batch and divide by the batch size;
        if False, batch is kept in the output (in front of pos)
    apply_ln: if True, scale the stack with apply_ln_to_stack(..., layer=-1);
        if False, the stack is used unscaled and "Not applying LN" is printed
    use_clean_cache_for_ln: if True, clean_cache (not cache) does the scaling
    clean_cache: cache used for scaling when use_clean_cache_for_ln is True.
        The default is bound once, to the module-level `cache` that exists when
        this function is defined, not to the `cache` argument of the call.
    '''
    batch_size = residual_stack.size(-3) if batch_pos_dmodel else residual_stack.size(-2)

    if apply_ln:
        cache_to_use = clean_cache if use_clean_cache_for_ln else cache

        if batch_pos_dmodel:
            scaled_residual_stack = cache_to_use.apply_ln_to_stack(residual_stack, layer=-1)
        else:
            scaled_residual_stack = cache_to_use.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    else:
        print("Not applying LN")
        scaled_residual_stack = residual_stack

    if batch_pos_dmodel:
        # Align predictions with targets: drop the last prediction (no known
        # next token) and the first target token (nothing predicts it).
        # Only valid when a pos axis exists; without this guard the
        # [batch, d_model] form would be indexed along a missing axis.
        scaled_residual_stack = scaled_residual_stack[..., :-1, :]
        effect_directions = effect_directions[:, 1:, :]

    if average_across_batch:
        if batch_pos_dmodel:
            return einops.einsum(
                scaled_residual_stack, effect_directions,
                "... batch pos d_model, batch pos d_model -> ... pos"
            ) / batch_size
        return einops.einsum(
            scaled_residual_stack, effect_directions,
            "... batch d_model, batch d_model -> ..."
        ) / batch_size

    if batch_pos_dmodel:
        return einops.einsum(
            scaled_residual_stack, effect_directions,
            "... batch pos d_model, batch pos d_model -> ... batch pos"
        )
    return einops.einsum(
        scaled_residual_stack, effect_directions,
        "... batch d_model, batch d_model -> ... batch"
    )
            
def collect_direct_effect(cache: ActivationCache, correct_tokens: Float[Tensor, "batch seq_len"],
                           title = "Direct Effect of Heads", display = True) -> Float[Tensor, "heads batch pos"]:
    """
    Computes the direct effect of every attention head on the correct next token.

    Uses the global `model` to turn tokens into residual directions and to read
    the number of layers and heads.

    returns a [heads, batch, pos - 1] tensor of direct effects, regardless of
    the value of `display`

    cache: cache of activations to read the per-head outputs from
    correct_tokens: [batch, seq_len] tensor of correct tokens
    title: title of the plot (relevant if display == True)
    display: if True, additionally builds a layer-by-head heatmap of the direct
        effect averaged over batch and position
    """
    # Per-head outputs for all layers, left unscaled here; layer norm is
    # applied inside residual_stack_to_direct_effect (apply_ln = True).
    head_outputs = cache.stack_head_results(layer = -1, return_labels = False, apply_ln = False)
    answer_directions = model.tokens_to_residual_directions(correct_tokens)

    # Keep both batch and pos so callers can inspect individual predictions.
    per_head_direct_effect = residual_stack_to_direct_effect(
        head_outputs, cache, answer_directions,
        batch_pos_dmodel = True, average_across_batch = False,
        apply_ln = True,
    )

    if display:
        layer_by_head_mean = einops.rearrange(
            per_head_direct_effect.mean(dim = (1,2)),
            "(n_layer n_heads_per_layer) -> n_layer n_heads_per_layer",
            n_layer = model.cfg.n_layers, n_heads_per_layer = model.cfg.n_heads,
        )
        # NOTE(review): return_fig = True is passed and the result is never
        # shown or returned, so this may not render anything - confirm against
        # the imshow helper and call .show() on the figure if that is the case.
        imshow(
            torch.stack([layer_by_head_mean]),
            return_fig = True,
            facet_col = 0,
            facet_labels = ["Direct Effect of Heads"],
            title=title,
            labels={"x": "Head", "y": "Layer", "color": "Logit Contribution"},
            border=True,
            width=500,
            margin={"r": 100, "l": 100}
        )

    return per_head_direct_effect
    
def return_item(item):
    """Identity function: hands back exactly the object it was given."""
    result = item
    return result

# Direct effect of every head on the clean run, with the flat head axis split
# into (layer, head) so it can be indexed as per_head_direct_effect[layer, head].
per_head_direct_effect: Float[Tensor, "n_layer n_head batch pos"] = einops.rearrange(
    collect_direct_effect(cache, owt_tokens, display = True),
    "(n_layer n_head) batch pos -> n_layer n_head batch pos",
    n_layer = model.cfg.n_layers,
    n_head = model.cfg.n_heads,
)


In [10]:
def dir_effects_from_sample_ablating_head(layer, head):
    """
    Returns the direct effect of every head after sample ablating one head.

    Relies on the globals `model`, `owt_tokens` and `corrupted_owt_tokens`.

    layer, head: location of the head whose "z" node is patched
    returns: [n_layer, n_heads_per_layer, batch, pos] tensor of direct effects
    """
    patched_nodes = [Node("z", layer, head)]
    # return_item together with apply_metric_to_cache presumably makes act_patch
    # hand back the patched run's activation cache - TODO confirm.
    ablated_cache = act_patch(model, owt_tokens, patched_nodes,
                              return_item, corrupted_owt_tokens, apply_metric_to_cache= True)

    flat_head_effects = collect_direct_effect(ablated_cache, owt_tokens, display = False)
    # Split the flat head axis into separate layer and head axes.
    return einops.rearrange(
        flat_head_effects,
        "(n_layer n_heads_per_layer) batch pos -> n_layer n_heads_per_layer batch pos",
        n_layer = model.cfg.n_layers, n_heads_per_layer = model.cfg.n_heads,
    )
def create_scatter_of_backup_of_head(layer, head):
    """
    Plots the clean direct effect of a head against its accumulated backup.

    this function:
    1) gets the direct effect of all the heads when sample ablating the input head
    2) gets the total accumulated backup of the head for each prompt and position,
       i.e. the change in direct effect summed over all heads in later layers
    3) plots the clean direct effect vs accumulated backup

    Relies on the globals `per_head_direct_effect` and `model_name`.
    """
    ablated_per_head_batch_direct_effect = dir_effects_from_sample_ablating_head(layer, head)

    # 2) gets the total accumulated backup of the head for each prompt and position
    downstream_change_in_logit_diff = ablated_per_head_batch_direct_effect - per_head_direct_effect
    # Heads in earlier layers cannot be affected by the ablation.
    assert downstream_change_in_logit_diff[0:layer].sum((0,1,2,3)).item() == 0
    sum_across_all_downstream_heads = downstream_change_in_logit_diff[(layer+1):].sum((0,1))

    # Build hover labels from the actual tensor shape. The pos axis is one
    # shorter than the prompt length, so labelling with the prompt length
    # would shift every label after the first row.
    n_batch, n_pos = sum_across_all_downstream_heads.shape
    hover_labels = [
        f"Batch {batch_idx}, Pos {pos_idx}"
        for batch_idx, pos_idx in itertools.product(range(n_batch), range(n_pos))
    ]

    #  3) plots the clean direct effect vs accumulated backup
    fig = go.Figure()
    scatter_plot = go.Scatter(
        x=per_head_direct_effect[layer, head].flatten().cpu(),
        y=sum_across_all_downstream_heads.flatten().cpu(),
        text=hover_labels,  # Set the hover labels to the text attribute
        mode='markers',
        marker=dict(size=2, opacity=0.8),
    )
    fig.add_trace(scatter_plot)
    fig.update_layout(
        title=f"Total Accumulated Backup of {layer}.{head} in {model_name} for each Position and Batch",
    )
    fig.update_xaxes(title = "Direct Effect of Head")
    fig.update_yaxes(title = "Total Accumulated Backup")
    fig.update_layout(width=700, height=400)
    fig.show()



# Example: backup scatter for layer 9, head 6.
create_scatter_of_backup_of_head(layer=9, head=6)

In [11]:
def get_slope_of_best_fit_line(layer, head, graph = True):
    """
    Fits a least-squares line (slope and intercept) of accumulated backup
    against the clean direct effect of a head, over every prompt and position.

    this function:
    1) gets the change in direct effect of all later-layer heads when the input
       head is sample ablated, summed per prompt and position
    2) fits backup = slope * direct_effect + intercept over ALL points
       (no threshold on the direct effect is applied)
    3) if graph is True, plots the points and the fitted line

    Relies on the globals `per_head_direct_effect` and `model_name`.

    returns: (slope, intercept); both are 0 if there are no points
    """
    ablated_per_head_batch_direct_effect = dir_effects_from_sample_ablating_head(layer, head)

    # gets the total accumulated backup of the head for each prompt and position
    downstream_change_in_logit_diff = ablated_per_head_batch_direct_effect - per_head_direct_effect
    # Heads in earlier layers cannot be affected by the ablation.
    assert downstream_change_in_logit_diff[0:layer].sum((0,1,2,3)).item() == 0 # layer 'layer' will not be 0
    sum_across_all_downstream_heads = downstream_change_in_logit_diff[(layer+1):].sum((0,1))

    # Convert once to flat NumPy arrays, not element by element.
    direct_effects = per_head_direct_effect[layer, head].flatten().detach().cpu().numpy()
    backup_amounts = sum_across_all_downstream_heads.flatten().detach().cpu().numpy()

    if direct_effects.size == 0:
        slope = 0
        intercept = 0
    else:
        # Design matrix [x, 1]: the column of ones fits a free intercept, so
        # the line is NOT forced through the origin.
        design = np.column_stack([direct_effects, np.ones(len(direct_effects))])
        slope, intercept = np.linalg.lstsq(design, backup_amounts, rcond=None)[0]

    if graph:
        # Build hover labels from the actual tensor shape. The pos axis is one
        # shorter than the prompt length, so labelling with the prompt length
        # would shift every label after the first row.
        n_batch, n_pos = sum_across_all_downstream_heads.shape
        hover_labels = [
            f"Batch {batch_idx}, Pos {pos_idx}"
            for batch_idx, pos_idx in itertools.product(range(n_batch), range(n_pos))
        ]

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=direct_effects,
            y=backup_amounts,
            text=hover_labels,  # Set the hover labels to the text attribute
            mode='markers',
            marker=dict(size=2, opacity=0.8),
        ))

        # The fitted line is drawn from 0 up to the largest direct effect.
        max_x = float(direct_effects.max()) if direct_effects.size else 0.0
        line_x = np.linspace(0, max_x, 100)
        fig.add_trace(go.Scatter(
            x=line_x,
            y=line_x * slope + intercept,
            mode='lines',
            name='lines'
        ))
        fig.update_layout(
            title=f"Total Accumulated Backup of {layer}.{head} in {model_name} for each Position and Batch",
        )
        fig.update_xaxes(title = "Direct Effect of Head")
        fig.update_yaxes(title = "Total Accumulated Backup")
        fig.update_layout(width=700, height=400)
        fig.show()

    return slope, intercept

In [13]:
# Fit (and plot) the backup line for L5H1, the head this notebook is about.
get_slope_of_best_fit_line(layer=5, head=1)

(-0.030463708172339534, 0.0011252375062222833)

In [15]:
# Backup slope for every head. The grid is sized from the model config so it
# always matches the loop bounds below.
slopes_of_head_backup = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))
for layer in tqdm(range(model.cfg.n_layers)):
    for head in range(model.cfg.n_heads):
        # Only the slope is kept; plotting is disabled for the sweep.
        slopes_of_head_backup[layer, head] = get_slope_of_best_fit_line(layer, head, graph = False)[0]

imshow(slopes_of_head_backup, title = "Slopes of Head Backup",
       text_auto = True, width = 800, height = 800)  # text_auto: show a number above each square


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:30<00:00,  2.58s/it]
