In [1]:
# ! pip install transformers

In [None]:
'''
If you have read our implementation for gpt2 and want to see how it differs from this code,
or you want to study it to understand how to do the same with other models, please
follow the NOTEs we left throughout this notebook.
'''

In [2]:
# general imports
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTJForCausalLM
import torch
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import copy

In [3]:
# plotly express imports
import plotly.graph_objects as go
import matplotlib as mpl
import plotly.io as pio

# choose where plotly shows its figures: inside the notebook on colab, in the browser otherwise
try:
    if 'google.colab' in str(get_ipython()):
        print('running on colab. plot will be presented in notebook')
    else:
        # change to "browser" if you want to see the plots in your browser, else omit this line
        pio.renderers.default = "browser"
except Exception:
    # best effort only (for example, get_ipython is not defined when this runs outside IPython).
    # `Exception` instead of a bare except so KeyboardInterrupt / SystemExit still propagate
    print('Warning: pio.renderers.default = "browser" failed. going to use default renderer')

In [4]:
# another code we wrote
import utils

In [5]:
# run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

device(type='cpu')

In [7]:
# # example 1
# line = 'When Mary and John went to the store, John gave a drink to'
# target_token = ' Mary' # notice the token includes the space before it

In [8]:
# example 2
# the prompt that is fed to the model
line = "The capital of Japan is the city of"
# the token we expect next. notice the token includes the space before it
target_token = " Tokyo"

In [9]:
# # example 3 (from Counterfact)
# line = 'Michel Denisot spoke the language of'
# target_token = ' French'

In [10]:
try:
    # delete the model to free up memory (if more than one model is used)
    del model
except NameError:
    # first run of this cell in the kernel: there is no previous model to delete
    pass

model_name = "EleutherAI/gpt-j-6B"
model = GPTJForCausalLM.from_pretrained(
    model_name,
    revision="float16",  # use 16-bit floats to save memory
).to(device)

# the notebook only runs inference, so gradients are disabled for all the parameters
model.requires_grad_(False)
model_config = model.config

# collect the hidden states before and after each of those layers (modules)
# <-- NOTE: here we adjust the relevant layers to check. we leave the layers we used for gpt2 for comparison
# the wrapping must happen before the forward pass below, otherwise there is nothing to collect
hs_collector = utils.wrap_model(model, layers_to_check = [
    '.mlp', '.mlp.fc_in', '.mlp.fc_out', 
    '.attn', '.attn.q_proj', '.attn.k_proj', '.attn.v_proj', '.attn.out_proj', 
    '.ln_1', '']) # '' stands for wrapping transformer.h.<layer_index> in gpt2 and gpt-j

# hs_collector = utils.wrap_model(model, layers_to_check = [  # gpt2
#     '.mlp', '.mlp.c_fc', '.mlp.c_proj',  
#     '.attn', '.attn.c_attn', '.attn.c_proj', 
#     '.ln_1', '.ln_2', '',])  


# add extra functions to the model (like logit lens adjust to the model decoding matrix)
model_aux = utils.model_extra(model=model, device='cpu')
# NOTE for future developers: model_extra obj copy the final layer norm, ln_f, and the decoding matrix, lm_head.
# luckily, gpt2 and gpt-j share the same names for these layers
# maybe future models will have different names, so we need to check that

tokenizer = AutoTokenizer.from_pretrained(model_name)
try:
    os.environ["TOKENIZERS_PARALLELISM"] = "true"  # not blocking, just to prevent warnings and faster tokenization
except Exception:
    # best effort: failing to set the environment variable must not stop the notebook
    pass

# single forward pass over the prompt. the wrapped layers fill hs_collector while it runs
encoded_line = tokenizer.encode(line.rstrip(), return_tensors='pt').to(device)
output_and_cache = model(encoded_line, output_hidden_states=True, output_attentions=True, use_cache=True)

hs_collector['past_key_values'] = output_and_cache.past_key_values  # the "attention memory"
hs_collector['attentions'] = output_and_cache.attentions

# the final answer token is (greedy choice: the arg-max of the logits at the last position):
model_answer = tokenizer.decode(output_and_cache.logits[0, -1, :].argmax().item())
print(f'model_answer: "{model_answer}"')

In [None]:
# show the module structure of the loaded model (its printed representation)
print(model)

In [None]:
# const definitions
# all the "factor" values are used to scale the weights of the links between the neurons in the different layers
# you are more than welcome to change these values for your own experiments

# number of digits the norms, activations and probabilities are rounded to (see get_norm_layer and get_node_customdata)
round_digits = 3

# the numbers of top and bottom neurons to show at the mlp matrices (gpt2: mlp.c_fc, mlp.c_proj. gpt-j: mlp.fc_in, mlp.fc_out) and the attention projection matrix (W_O, attn.c_proj)
# layer_mlp_to_graph picks the neurons with the largest (top) and the smallest (bottom) activation*norm score
number_of_top_neurons = 10  
number_of_bottom_neurons = 5
# fixed weight for links that have no natural magnitude (for example the links from mlp_input to each "key" neuron)
defualt_weight = 7
# multiplies |activation| to get the weight of the link between a "key" neuron and its "value" neuron
factor_weight_mlp_key_value_link = 1.5
n_values_per_head = 2  # number of ki, vi to show for each head. great to examine with gpt2-small/medium but for gpt2-large/xl you might want to reduce this number
# the next factors are not used by the code visible in this part of the notebook.
# presumably they scale the links of the attention sub-graph (layer_attn_to_graph) - TODO confirm
factor_attn_score = 10
factor_head_output = 1.8
factor_head_norm = 0.2
factor_for_memory_to_head_values = 1.3

In [None]:
# color definitions
# color map for nodes, indexed by (rank of the target token) / vocab_size in get_node_customdata:
# green-yellow shades at the low end of the range, grey in the middle, red shades at the high end
cmap_node_rank_target = mpl.colors.LinearSegmentedColormap.from_list('cmap_node_rank_target', ['lime', 'greenyellow', 'yellow' , 'yellow', 'yellow'] + ['dimgrey']*53 + ['orangered']*3 + ['red']*4 + ['darkred']*5)
# not used by the code visible in this part of the notebook. presumably colors attention-score links - TODO confirm
cmap_attn_score = mpl.colors.LinearSegmentedColormap.from_list('cmap_attn_score', ['khaki', 'yellow', 'green'])
# color map for links, indexed by entropy / max_val in get_color_according_to_entropy
cmap_entropy = mpl.colors.LinearSegmentedColormap.from_list('cmap_entropy', ['darkslategrey', 'lightgrey', 'lightgrey'])

# background of the figure (plot and paper) and outline color of the nodes, see plot_graph_aux
backgroud_color = 'black'
# links drawn in the background color cannot be seen. used for links that only exist to position nodes
invisible_link_color = backgroud_color
# not used by the code visible in this part of the notebook - TODO confirm its usage
color_for_abstract = 'white'
# node color when no target-word based color is computed (see get_node_customdata)
default_color = 'white'
# link from the layer-norm output into a sub-block (see layer_mlp_to_graph)
link_with_normalizer = 'darkviolet'
# nodes and links that represent bias vectors
color_for_bias_vector = 'pink'
# key -> value links of the neurons picked by the largest / the smallest scores (see layer_mlp_to_graph)
positive_activation = 'blue'
negative_activation = 'red'

In [None]:
def get_norm_layer(x, debug_msg=None, round_digits=round_digits):
    '''
    return the norm of x (torch.norm with its default arguments) as a python float

    @ x: a torch.Tensor, or any input that torch.Tensor(x) can be built from
    @ debug_msg: if given (and not empty), it is printed together with len(x) and the norm before rounding
    @ round_digits: the number of digits to round the norm to. if it is not positive the norm is returned as is
    '''
    # exact type check: anything that is not a plain tensor is converted first
    vector = x if type(x) == torch.Tensor else torch.Tensor(x)
    norm_value = torch.norm(vector).item()

    if debug_msg:
        print(f'{debug_msg}: len(x): {len(x)}, norm: {norm_value}')

    return round(norm_value, round_digits) if round_digits > 0 else norm_value

In [None]:
def plot_graph_aux(graph_data, title='Flow-Graph', save_html=False):
    '''
    A wrapper for graph plotting by plotly (the graph is drawn as a Sankey diagram)

    @ graph_data: the graph data. see the function @ gen_basic_graph for more details.
        the keys read here are: 'sources', 'targets', 'weights', 'colors_nodes', 'colors_links', 'labels', 'line_explained', 'customdata'
    @ title: the title of the graph
    @ save_html: if True, the graph will be saved as an html file to {title}.html. if @ save_html is a non empty string, the graph will be saved as an html file to {save_html}.html
    '''
    # the weights are unitless norms / activations that are often smaller than 1,
    # so the hover text shows them with the notebook rounding precision and without any unit suffix
    hover_decimals = max(int(round_digits), 0)

    sankey = go.Sankey(
        valueformat=f'.{hover_decimals}f',
        node=dict(
            pad=15,
            thickness=15,
            line=dict(color=backgroud_color, width=0.5),
            label=graph_data['labels'],
            color=graph_data['colors_nodes'],
            customdata=graph_data['customdata'],
            hovertemplate='Node: %{customdata}. %{value}<extra></extra>',
        ),
        link=dict(
            source=graph_data['sources'],
            target=graph_data['targets'],
            value=graph_data['weights'],
            color=graph_data['colors_links'],
            customdata=graph_data['line_explained'],
            hovertemplate='%{source.customdata}<br />' + ' '*50 + '----[%{customdata},  %{value}]----><br />%{target.customdata}<extra></extra>',
        ),
    )
    fig = go.Figure(data=[sankey])

    fig.update_layout(
        hovermode='x',
        title_text=title,
        font=dict(size=10, color='white'),
        plot_bgcolor=backgroud_color,
        paper_bgcolor=backgroud_color,
    )

    fig.show()

    # save to html
    if save_html != False:
        # only a non empty string overrides the file name. any other value (like True) falls back to the title
        has_custom_name = type(save_html) == str and save_html != ''
        path_out = f'{save_html}.html' if has_custom_name else f'{title}.html'
        fig.write_html(path_out)

In [None]:
# private copies of the final layer norm and of the decoding layer (lm_head), moved to `device` with gradients disabled.
# logit_status uses them to project a hidden state to vocabulary scores
ln_f = copy.deepcopy(utils.rgetattr(model, 'transformer.ln_f')).to(device).requires_grad_(False)
decoding_matrix = copy.deepcopy(utils.rgetattr(model, 'lm_head')).to(device).requires_grad_(False)
# the hidden size, as stated in the model config (n_embd)
hidden_d = model.config.n_embd

def logit_status(hs, wanted_idx):
    '''
    project a hidden state to the vocabulary (final layer norm, then the decoding matrix)
    and report how a wanted token scores

    @ hs: the hidden state. a torch.Tensor or any input that torch.tensor(hs) accepts
    @ wanted_idx: a token id, or a string (then the first token of its encoding is used).
        an empty string means "no target", and (-1, -1) is returned
    returns (prob, ranking): the softmax probability of the token and its rank.
        rank 1 -> most probable, rank #vocab_size -> least probable
    '''
    if wanted_idx == '':
        return -1, -1
    if type(wanted_idx) == str:
        wanted_idx = tokenizer.encode(wanted_idx, add_special_tokens=False)[0]

    # work on a private copy of the hidden state, placed on the same device as ln_f and the decoding matrix
    if type(hs) == torch.Tensor:
        hs_on_device = hs.clone().detach().to(device)
    else:
        hs_on_device = torch.tensor(hs).to(device)

    # logit lens: final layer norm first, then hidden state -> vocabulary scores
    vocab_scores = decoding_matrix(ln_f(hs_on_device))
    # scores -> probabilities
    vocab_probs = torch.nn.functional.softmax(vocab_scores, dim=0)

    prob = vocab_probs[wanted_idx].item()
    # the rank is one more than the number of tokens that score strictly higher than the wanted token
    n_higher = (vocab_scores[wanted_idx] < vocab_scores).sum().item()
    return prob, n_higher + 1

In [None]:
def entropy(probabilities):
    '''
    calculates the entropy (in bits, log base 2) of a probability distribution

    @ probabilities: a torch.Tensor. if its length is not the vocabulary size it is first passed through
        model_aux.hs_to_probs (presumably it is then a hidden state that is projected to vocabulary probabilities - TODO confirm)
    returns the entropy as a numpy scalar
    '''
    if len(probabilities) != model_config.vocab_size:
        probabilities = model_aux.hs_to_probs(probabilities)
    # convert the probabilities to a numpy array. the tensor is moved to the cpu first,
    # since a tensor that lives on the gpu cannot be converted to numpy directly
    probs = np.array(probabilities.detach().cpu())
    # filter out 0 probabilities (to avoid issues with log(0))
    non_zero_probs = probs[probs != 0]
    return -np.sum(non_zero_probs * np.log2(non_zero_probs))

In [None]:
def get_color_according_to_entropy(hs, max_val=30):
    '''
    return a plotly color string, 'rgba(r, g, b, a)', that reflects the entropy of hs

    @ hs: the input for @ entropy (see that function for the accepted inputs)
    @ max_val: the entropy value that is mapped to the end of the color map. higher values are clipped to it
    '''
    entropy_score = entropy(hs)
    # position on the color map, as a float in [0.0, 1.0].
    # it must stay a float: matplotlib treats an int input as a direct index into the color table
    color_idx = max(min(float(entropy_score) / max_val, 1.0), 0.0)
    # matplotlib returns the channels as floats in [0, 1], while the rgb channels of a css color are in [0, 255].
    # building the string by hand also keeps numpy scalar reprs (like "np.float64(0.5)") out of it
    red, green, blue, alpha = (float(channel) for channel in cmap_entropy(color_idx))
    color = f'rgba({round(red * 255)}, {round(green * 255)}, {round(blue * 255)}, {alpha})'
    return color

In [None]:
def get_node_customdata(hs, prefix='', top_or_bottom='top_k', layer_idx=None, target_word=None, color_flag=True):
    '''
    return metadata that is used in the graph plot
    text: the text that will be displayed in the node when hovering over it
    color: the color of the node according to the ranking of the target word
    (green if very probable, red if very improbable and grey otherwise)

    @ hs: the vector the node represents
    @ prefix: optional text to put before the tokens
    @ top_or_bottom: which entry of model_aux.hs_to_token_top_k to show ('top_k' or 'bottom_k')
    @ layer_idx: optional layer index to put at the start of the text
    @ target_word: if given, its probability and ranking (see @ logit_status) are added to the text
    @ color_flag: if True (and @ target_word is given) the color is picked according to the ranking, otherwise the default color is used
    returns (text, color)
    '''

    color = default_color 
    hs_meaning = model_aux.hs_to_token_top_k(hs, k_top=5, k_bottom=5)
    res = f'{hs_meaning[top_or_bottom]}'

    if prefix != '':
        res = f'{prefix}: {res}'
    if layer_idx is not None:
        res = f'{layer_idx}) {res}'
    if target_word is not None:
        prob, ranking = logit_status(hs, target_word)
        prob = round(prob*100, round_digits)  # probs [0,1] as percentage [0,100]
        res = f'{res} [status: "{target_word}": prob:{prob}%, rank: {ranking})]'
        if color_flag:
            color_idx = ranking/model_config.vocab_size  # color index [0,1] (according to ranking)
            # matplotlib returns the channels as floats in [0, 1], while the rgb channels of a css color are in [0, 255].
            # building the string by hand also keeps numpy scalar reprs (like "np.float64(0.5)") out of it
            red, green, blue, alpha = (float(channel) for channel in cmap_node_rank_target(color_idx))
            color = f'rgba({round(red * 255)}, {round(green * 255)}, {round(blue * 255)}, {alpha})'
    return res, color

In [None]:
'''
NOTE:
The main difference between gpt2 and gpt-j is that in gpt2 each block is an attention sub-block followed by an mlp (feed-forward, FF)
sub-block, whereas in gpt-j the attention and the mlp sub-blocks are parallel to each other. Also:
* where in gpt2 we have a residual for each of the attention and mlp sub-blocks, in gpt-j we have only a single
residual for the whole block.
* gpt-j adds the position embedding in each block, while gpt2 does that only before the first block. we did not add it to the graph
to keep it similar to the gpt2 graph.

except for the differences mentioned above, the structure of the models and their graphs is mostly the same.
we use the same code we used for gpt2 except for the following changes:
- adjusting the names of the model parameters (for example, in gpt2 the mlp matrices are called 'c_fc' and 'c_proj' while in gpt-j they are called 'fc_in' and 'fc_out')
- adding an extra function for the creation of the sub-graphs of the attention and mlp sub-blocks: pre_connect_mlp_and_attn_for_gptj
this function creates the input for each sub-block and the common output. each sub-block connects to the input and the common output.
this structure helps us to use mostly the same code for both gpt2 and gpt-j.
'''
def pre_connect_mlp_and_attn_for_gptj(layer_idx: int, graph_data, model, hs_collector, row_idx=-1, model_aux=model_aux, target_word=None):
    '''
    create the nodes that the attention and the mlp sub-graphs of layer_idx share:
    the residual (the input of the block), the output of ln_1 (the input of both sub-blocks) and the output of the block.
    the indices and the norms of these nodes are stored in graph_data, so the functions that create
    the sub-graphs can connect to them

    @ layer_idx: the index of the layer in the model
    @ graph_data: the graph data dictionary. its lists are extended in place
    @ model: the model (not used by this function)
    @ hs_collector: the hs_collector dictionary (created from wrapping the model)
    @ row_idx: the index of the row in the hs_collector (-1 for the last token)
    @ model_aux: the model_aux class (more functions for the original model)
    @ target_word: the target word for extracting the status of the nodes (ranking and probability)
    '''
    node_labels = graph_data['labels']
    node_hover_texts = graph_data['customdata']
    node_colors = graph_data['colors_nodes']
    residual_link_color = 'rgba(102,0,204,0.3)'  # unique color for the residual

    def add_node(hs, prefix):
        # a node is labeled by the top token of its vector. returns the index of the new node
        top_token = model_aux.hs_to_token_top_k(hs, k_top=1, k_bottom=0)['top_k'][0]
        node_labels.append(top_token)
        hover_text, node_color = get_node_customdata(hs, prefix=prefix, top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
        node_hover_texts.append(hover_text)
        node_colors.append(node_color)
        return len(node_labels) - 1

    def add_residual_link(source_idx, target_idx, weight, explanation):
        # every link that is created here leaves the residual node, so they all share one color
        graph_data['sources'].append(source_idx)
        graph_data['targets'].append(target_idx)
        graph_data['weights'].append(weight)
        graph_data['colors_links'].append(residual_link_color)
        graph_data['line_explained'].append(explanation)

    if 'idx_previous_block' not in graph_data:
        # first block in the graph: create the node of the residual
        block_residual = hs_collector[layer_idx]['ln_1']['input'][row_idx]  # block_residual == previous layer output
        residual_idx = add_node(block_residual, 'mlp_input')
        residual_norm = get_norm_layer(block_residual)
    else:
        # the output node of the previous block is the residual of this block
        residual_idx = graph_data['idx_previous_block']
        residual_norm = graph_data['previous_block_norm']

    # the output of the whole block: residual + mlp_out + attn_out
    block_out = hs_collector[layer_idx]['']['output'][row_idx]
    block_out_idx = add_node(block_out, 'block_output')
    block_out_norm = get_norm_layer(block_out)

    # connect the residual to the block output
    add_residual_link(residual_idx, block_out_idx, residual_norm, f'residual + mlp_out + attn_out, norm: {residual_norm}')

    # the input to the attn and mlp sub-blocks (the layer norm of the residual)
    block_input = hs_collector[layer_idx]['ln_1']['output'][row_idx]
    blocks_input_idx = add_node(block_input, 'blocks_input (after ln_1)')
    blocks_input_norm = get_norm_layer(block_input)

    # connect the residual to the blocks input
    add_residual_link(residual_idx, blocks_input_idx, residual_norm, f'residual into layer norm, norm: {residual_norm}')

    # the output of this block is the residual of the next one
    # (read by the call of this function with {layer_idx+1})
    graph_data['idx_previous_block'] = block_out_idx

    # the common output and the common input of the sub-blocks,
    # so each function for the creation of the sub-blocks can connect to them
    graph_data['idx_curr_block_out'] = block_out_idx
    graph_data['curr_block_out_norm'] = block_out_norm
    graph_data['idx_curr_blocks_input'] = blocks_input_idx
    graph_data['curr_blocks_input_norm'] = blocks_input_norm

    graph_data['previous_block_norm'] = block_out_norm

In [None]:
def layer_mlp_to_graph(layer_idx: int, graph_data, model, hs_collector, row_idx=-1, model_aux=model_aux, target_word=None):
    '''
    create a subgraph of the feed-forward (FF, MLP) part of the model at layer_idx
    the subgraph is a graph of the neurons in the FF part of the model
    the nodes are the most active neurons in the FF (some of positive and some of negative)
    the links are the connections between the neurons (summation of the neurons or when one neuron creates the coefficient of another neuron)
    the new nodes and links are appended to the lists of the graph_data dictionary

    NOTE: the common input and output nodes of the block must already exist, meaning
    pre_connect_mlp_and_attn_for_gptj was called for this layer_idx before this function

    @ layer_idx: the index of the layer in the model
    @ graph_data: the graph data dictionary (already holding the keys pre_connect_mlp_and_attn_for_gptj creates)
    @ model: the model (here: gpt-j)
    @ hs_collector: the hs_collector dictionary (created from wrapping the model with the hs_collector class)
    @ row_idx: the index of the row in the hs_collector which correspond to the inference of the #row_idx token. use -1 for the last token (Note: currently not supported any other value than -1)
    @ model_aux: the model_aux class (more functions for the original model)
    @ target_word: the target word for extracting the status of the neurons (ranking and probability)    
    '''
    sources = graph_data['sources']
    targets = graph_data['targets']
    weights = graph_data['weights']
    colors_nodes = graph_data['colors_nodes']
    colors_links = graph_data['colors_links']
    labels = graph_data['labels']
    line_explained = graph_data['line_explained']
    customdata = graph_data['customdata']

    def add_node(hs, prefix, case='top_k', k_bottom=0, color=None):
        # append one node that is labeled by the first `case` token of hs. returns the index of the node.
        # the hover text always comes from get_node_customdata. `color`, if given, replaces the color it suggests
        meaning = model_aux.hs_to_token_top_k(hs, k_top=1, k_bottom=k_bottom)
        labels.append(meaning[case][0])
        metadata, suggested_color = get_node_customdata(hs, prefix=prefix, top_or_bottom=case, layer_idx=layer_idx, target_word=target_word)
        customdata.append(metadata)
        colors_nodes.append(suggested_color if color is None else color)
        return len(labels) - 1

    def add_link(source, target, weight, color, explanation):
        # append one link. `explanation` is the text that is shown when hovering over the link
        sources.append(source)
        targets.append(target)
        weights.append(weight)
        colors_links.append(color)
        line_explained.append(explanation)

    # <-- NOTE: assuming the input and output of this block was just created using pre_connect_mlp_and_attn_for_gptj
    # in gpt-j the mlp and attention are parallel and not sequential as in gpt2
    # for this case, those nodes are created separately in the pre_connect_mlp_and_attn_for_gptj function
    # and now we access their indices from the graph_data dictionary
    idx_curr_block_out = graph_data['idx_curr_block_out']
    idx_curr_blocks_input = graph_data['idx_curr_blocks_input']
    curr_blocks_input_norm = graph_data['curr_blocks_input_norm']
    # mlp_residual = hs_collector[layer_idx]['ln_1']['input'][row_idx]  # gpt2. left for reference

    mlp_input = hs_collector[layer_idx]['mlp']['input'][row_idx]  # mlp_input == fc_in input
    mlp_out = hs_collector[layer_idx]['mlp']['output'][row_idx]  # only the output of the mlp

    idx_mlp_input = add_node(mlp_input, 'mlp_input')
    idx_mlp_out = add_node(mlp_out, 'mlp_out')

    # the common input of the sub-blocks (the output of ln_1) feeds the mlp
    add_link(idx_curr_blocks_input, idx_mlp_input, curr_blocks_input_norm, link_with_normalizer,
             f'mlp input after norm:{curr_blocks_input_norm}')

    # the mlp output is summed into the common output of the block
    mlp_out_norm = get_norm_layer(mlp_out)
    add_link(idx_mlp_out, idx_curr_block_out, mlp_out_norm, get_color_according_to_entropy(mlp_out),
             f'norm:{mlp_out_norm}')

    # get the first and second matrices of the mlp
    # both are transposed, so c_fc.T[i] and c_proj[i] are the vectors of neuron i
    # (its "key" and its "value". both are passed to the same vocabulary projection helpers as the hidden states)
    c_fc = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.fc_in.weight").clone().detach().cpu().T  # <-- NOTE: adjust this layer name
    c_proj = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.fc_out.weight").clone().detach().cpu().T  # <-- NOTE: adjust this layer name
    
    # NOTE: for better understanding, we left the code we used for GPT-2
    # c_fc = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.c_fc.weight").clone().detach().cpu()
    # c_proj = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.c_proj.weight").clone().detach().cpu()

    values_norm = c_proj.norm(dim=1)
    hs = hs_collector[layer_idx]['mlp.fc_out']['input'][row_idx]  # value activation. mid results between the key and value matrix
    hs_mul_norm = hs * values_norm  # this is our metric for the importance of the value activation (value*norm of the second matrix)

    # pick the most active neurons (according to activation sign)
    # you can change number_of_top_neurons / number_of_bottom_neurons according to your needs (like screen size)
    for case, n_top, is_largest in [('top_k', number_of_top_neurons, True), ('bottom_k', number_of_bottom_neurons, False)]:
        tops = torch.topk(hs_mul_norm, k=n_top, largest=is_largest)
        key_value_link_color = positive_activation if is_largest else negative_activation
        for entry_idx, activision_value_mul_norm in zip(tops.indices, tops.values):
            entry_idx = entry_idx.item()
            activision_value_mul_norm = round(activision_value_mul_norm.item(), round_digits)
            activision_value = round(hs[entry_idx].item(), round_digits)
            key_vector = c_fc.T[entry_idx]
            value_vector = c_proj[entry_idx]

            # create a node for the "key" neuron (the first matrix)
            idx_key = add_node(key_vector, f'key{entry_idx}')

            # create a node for the "value" neuron (the second matrix)
            # its label is chosen according to the activation sign. supposed to reflect the meaning it is adding to the output
            idx_value = add_node(value_vector, f'value{entry_idx}', case=case, k_bottom=1)

            # connect the "key" and the "value"
            add_link(idx_key, idx_value, abs(activision_value)*factor_weight_mlp_key_value_link, key_value_link_color,
                     f'activision:{activision_value}')

            # add a link between the mlp_input and the "key" neuron
            add_link(idx_mlp_input, idx_key, defualt_weight, get_color_according_to_entropy(key_vector), '')

            # add a link between the "value" neuron and the mlp_out
            add_link(idx_value, idx_mlp_out, abs(activision_value_mul_norm), get_color_according_to_entropy(value_vector),
                     f'activision*norm:{abs(activision_value_mul_norm)}')
    
    # we also add neurons representing the matrices bias vectors
    # we create nodes for each matrix bias vector then connect them to the flow
    # the bias of the first matrix is multiplied by the second matrix before it is projected to tokens
    c_fc_bias = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.fc_in.bias").clone().detach().cpu() @ c_proj  # NOTE <-- the first mlp matrix bias name
    idx_c_fc_bias = add_node(c_fc_bias, 'fc_in_bias', color=color_for_bias_vector)  # special color for bias

    c_proj_bias = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.fc_out.bias").clone().detach().cpu()  # NOTE <-- the second mlp matrix bias name
    idx_c_proj_bias = add_node(c_proj_bias, 'fc_out_bias', color=color_for_bias_vector)

    # connect mlp_input to c_fc_bias (fixed weight. the hover text reports the norm of mlp_input)
    mlp_input_norm = get_norm_layer(mlp_input)
    add_link(idx_mlp_input, idx_c_fc_bias, defualt_weight, color_for_bias_vector, f'norm:{mlp_input_norm}')

    # connect c_fc_bias to c_proj_bias, although it is not really a link (this is why its color is invisible_link_color)
    # the tiny weight keeps it barely visible (trick to make it blend with the background)
    add_link(idx_c_fc_bias, idx_c_proj_bias, 0.05, invisible_link_color, f'norm:{mlp_input_norm}')

    # connect c_proj_bias to mlp_out
    c_proj_bias_norm = get_norm_layer(c_proj_bias)
    add_link(idx_c_proj_bias, idx_mlp_out, c_proj_bias_norm, color_for_bias_vector, f'norm:{c_proj_bias_norm}')


In [None]:
def layer_attn_to_graph(layer_idx, graph_data, hs_collector, model, row_idx=-1, model_aux=model_aux, target_word=None, line=None):
    '''
    Add the sub-graph of the attention (attn) block of layer @layer_idx to the flow graph.

    The sub-graph describes the Q, K, V and O matrices, mostly aggregated into heads:
    - the nodes are single neurons or small groups of neurons (when they are aggregated into heads or sub-heads)
    - the links are the connections between them (a summation of neurons, or one neuron that creates the
      coefficient of another neuron)
    All the new nodes and links are appended in place to the lists stored in @graph_data and are connected
    to the block input/output nodes whose indices are already stored there. Nothing is returned.

    @ layer_idx: the index of the layer in the model
    @ graph_data: the graph data dictionary. besides the node/link lists it must already hold the keys
                  'idx_curr_block_out', 'idx_curr_blocks_input' and 'curr_blocks_input_norm'
    @ hs_collector: the hs_collector dictionary (created from wrapping the model with the hs_collector class)
    @ model: the model (here: gpt-j)
    @ row_idx: the index of the row in the hs_collector which corresponds to the inference of the #row_idx token.
               use -1 for the last token (Note: currently no other value than -1 is supported)
    @ model_aux: the model_aux class (more functions for the original model)
    @ target_word: the target word for extracting the status of the neurons (ranking and probability)
    @ line: the line that was used to generate the data in hs_collector (the prompt to the model).
            when given, it is tokenized with the notebook-level `tokenizer` so every key/value node can show
            the token that created it
    '''
    sources = graph_data['sources']
    targets = graph_data['targets']
    weights = graph_data['weights']
    colors_nodes = graph_data['colors_nodes']
    colors_links = graph_data['colors_links']
    labels = graph_data['labels']
    line_explained = graph_data['line_explained']
    customdata = graph_data['customdata']

    # <-- NOTE: assuming the input and output of this block was just created using pre_connect_mlp_and_attn_for_gptj
    idx_curr_block_out = graph_data['idx_curr_block_out']
    idx_curr_blocks_input = graph_data['idx_curr_blocks_input']
    curr_blocks_input_norm = graph_data['curr_blocks_input_norm']

    # used to show which token generated each entry of the attention memory (previous keys and values):
    # the i-th key and i-th value were generated by the i-th token
    # if @line is not given (None) this information is not shown
    parsed_line = None
    if line is not None:
        token_ids = tokenizer.encode(line, return_tensors='pt')
        # keep one decoded string per prompt token
        parsed_line = [tokenizer.decode(token_id.item()) for token_id in token_ids[0]]

    n_embd = model.config.n_embd
    dim_head = n_embd // model.config.n_head

    def place_in_head_slot(head_vec, head_idx):
        # zero vector of the full hidden width with @head_vec written into the slice that belongs to head #head_idx,
        # so one head can be multiplied by the full-width projection matrices
        full_width = torch.zeros(n_embd)
        full_width[dim_head * head_idx: dim_head * (head_idx + 1)] = head_vec
        return full_width

    # create nodes for the attention input (after layer norm) and for the attention output
    attn_input = hs_collector[layer_idx]['attn']['input'][row_idx]
    attn_input_meaning = model_aux.hs_to_token_top_k(attn_input, k_top=1, k_bottom=0)
    labels.append(attn_input_meaning['top_k'][0])
    curr_metadata, curr_color = get_node_customdata(attn_input, prefix=f'attn_input', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
    customdata.append(curr_metadata)
    colors_nodes.append(curr_color)
    idx_attn_input = len(labels) - 1

    # attn output
    attn_out = hs_collector[layer_idx]['attn']['output'][row_idx]
    attn_out_meaning = model_aux.hs_to_token_top_k(attn_out, k_top=1, k_bottom=0)
    labels.append(attn_out_meaning['top_k'][0])
    curr_metadata, curr_color = get_node_customdata(attn_out, prefix=f'attn_out', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
    customdata.append(curr_metadata)
    colors_nodes.append(curr_color)
    idx_attn_out = len(labels) - 1

    # link between the common input and attn_input
    sources.append(idx_curr_blocks_input)
    targets.append(idx_attn_input)
    weights.append(curr_blocks_input_norm)
    colors_links.append(link_with_normalizer)
    line_explained.append(f'attn input after norm:{curr_blocks_input_norm}')

    # link between attn_output and idx_curr_block_out
    sources.append(idx_attn_out)
    targets.append(idx_curr_block_out)
    curr_weight = get_norm_layer(attn_out)
    weights.append(curr_weight)
    colors_links.append(get_color_according_to_entropy(attn_out))
    line_explained.append(f'norm:{curr_weight}')

    # <-- NOTE: adjust the access to the attention matrices
    Wq = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.q_proj.weight").clone().detach().cpu()
    Wk = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.k_proj.weight").clone().detach().cpu()
    # Wv is not used by the projections below; it is loaded from v_proj (not k_proj) so the name matches its content
    Wv = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.v_proj.weight").clone().detach().cpu()
    # transposed once here, so `hs @ c_proj` below applies the output projection
    c_proj = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.out_proj.weight").clone().detach().cpu().T

    # NOTE: we leave the code for the original gpt2 model (there, Q,K,V matrices are actually concatenated into one matrix)
    # c_attn = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.c_attn.weight").clone().detach().cpu() # W_QKV (the QKV matrix)
    # c_proj = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.c_proj.weight").clone().detach().cpu() # W_O (the Output/projection matrix)
    # # in gpt2 c_attn is the concatenation of Wq, Wk, Wv (Q, K, V)
    # Wq = c_attn[:, :hidden_d]
    # Wk = c_attn[:, hidden_d:2*hidden_d]
    # Wv = c_attn[:, 2*hidden_d:]

    # <-- NOTE: adjust the access to the hidden states
    # NOTE: we leave the code for the original gpt2 model under each of the q,k,v hidden states
    q = hs_collector[layer_idx]['attn.q_proj']['output'][row_idx].cpu()  # <-- NOTE (this layer query)
    # q = c_attn_output[ :hidden_d]  # gpt2
    k = hs_collector[layer_idx]['attn.k_proj']['output'][row_idx].cpu()  # <-- NOTE (this layer key. it is added to the attention memory (to "past_key_values" so also the next tokens can use it) )
    # k = c_attn_output[hidden_d:2*hidden_d]  # gpt2
    v = hs_collector[layer_idx]['attn.v_proj']['output'][row_idx].cpu()  # <-- NOTE (this layer value. like the key, it is added to the attention memory )
    # v = c_attn_output[2*hidden_d:] # gpt2

    # projection using the QK circuit
    # the gpt-j projections store their weight as (out_features, in_features), i.e. q = Wq @ x and k = Wk @ x,
    # so the score q.k = x_q^T Wq^T Wk x_k and a query is mapped back to the model space by q @ Wk
    def pre_project_q(hs_q):
        return hs_q @ Wk
        # return Wk @ hs_q  # gpt2

    # projection using the QK circuit
    # symmetric to pre_project_q: a key is mapped back to the model space by k @ Wq.
    # (the gpt2 form `Wq @ hs_k` multiplies by the transposed matrix for this weight layout; both matrices are
    # square, so that mismatch raises no error)
    def pre_project_k(hs_k):
        return hs_k @ Wq
        # return Wq @ hs_k  # gpt2

    # projection using the OV circuit
    def pre_project_v(hs_v):
        return hs_v @ c_proj

    q_projected = pre_project_q(q)
    k_projected = pre_project_k(k)
    v_projected = pre_project_v(v)

    # create a node for q,k,v together
    q_meaning = model_aux.hs_to_token_top_k(q_projected, k_top=1, k_bottom=0)
    q_data, curr_color = get_node_customdata(q_projected, prefix=f'q (for current calc)', top_or_bottom='top_k', target_word=target_word)
    k_data, _ = get_node_customdata(k_projected, prefix=f'k (for next tokens)', top_or_bottom='top_k', target_word=target_word)
    v_data, _ = get_node_customdata(v_projected, prefix=f'v (for next tokens)', top_or_bottom='top_k', target_word=target_word)

    curr_metadata = f'q,k,v (before splitting into heads):' + '<br />' + q_data + '<br />' + k_data + '<br />' + v_data

    # create a new node for query (q) and add the key and value (k, v) metadata. we call this node qkv_full
    labels.append(q_meaning['top_k'][0])
    customdata.append(curr_metadata)
    colors_nodes.append(curr_color)
    idx_qkv_full = len(labels) - 1

    # connect between attn_input and qkv_full
    sources.append(idx_attn_input)
    targets.append(idx_qkv_full)
    curr_weight = get_norm_layer(attn_input)
    weights.append(curr_weight)
    colors_links.append(get_color_according_to_entropy(q))
    line_explained.append(f'norm:{curr_weight}')

    # the concatenated results from all the heads but without the OV circuit projection
    concated_heads_wihtout_projection = hs_collector[layer_idx]['attn.out_proj']['input'][row_idx]  # <-- NOTE: adjust layer name
    # concated_heads_wihtout_projection = hs_collector[layer_idx]['attn.c_proj']['input'][row_idx] # gpt2

    idx_concated_heads = len(labels)  # attn_c_proj input
    concated_heads_without_projection_meaning = model_aux.hs_to_token_top_k(concated_heads_wihtout_projection, k_top=1, k_bottom=0)
    labels.append(concated_heads_without_projection_meaning['top_k'][0])
    curr_metadata, curr_color = get_node_customdata(concated_heads_wihtout_projection, prefix=f'concated_heads (without projection)', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
    customdata.append(curr_metadata)
    colors_nodes.append(color_for_abstract)

    # create for each head the following nodes:
    # (1) the qi - this head part in the query q (we also add its information about ki, vi that were generated at this layer and saved to the attention memory)
    # (2) its #n_values_per_head top biggest ki and vi (the keys and values from the attention memory) according to the attention score
    # (3) the head output - the weighted summation of all the vi into it
    for head_idx in range(model.config.n_head):
        head_slice = slice(dim_head * head_idx, dim_head * (head_idx + 1))

        # hs_collector['past_key_values'][layer_idx][0 for key, 1 for value][entry in batch][head_idx] -> list of the keys/values for this head. the i-th entry is the key/value for the i-th token in the input
        keys = hs_collector['past_key_values'][layer_idx][0][0][head_idx]
        values = hs_collector['past_key_values'][layer_idx][1][0][head_idx]
        attentions = hs_collector['attentions'][layer_idx][0][head_idx][row_idx]

        # qi with the information about ki, vi (1)
        q_i_projected = pre_project_q(place_in_head_slot(q[head_slice], head_idx))
        k_i_projected = pre_project_k(place_in_head_slot(k[head_slice], head_idx))
        v_i_projected = pre_project_v(place_in_head_slot(v[head_slice], head_idx))

        q_i_projected_meaning = model_aux.hs_to_token_top_k(q_i_projected, k_top=1, k_bottom=0)
        q_data, curr_color = get_node_customdata(q_i_projected, prefix=f'qi (for current calc)', top_or_bottom='top_k', target_word=target_word)
        k_data, _ = get_node_customdata(k_i_projected, prefix=f'ki (for next tokens)', top_or_bottom='top_k', target_word=target_word)
        v_data, _ = get_node_customdata(v_i_projected, prefix=f'vi (for next tokens)', top_or_bottom='top_k', target_word=target_word)

        curr_metadata = f'head {head_idx}:' + '<br />' + q_data + '<br />' + k_data + '<br />' + v_data

        # create a new node for query and add the key and value metadata
        labels.append(q_i_projected_meaning['top_k'][0])
        customdata.append(curr_metadata)
        colors_nodes.append(curr_color)
        idx_q_i_projected = len(labels) - 1

        # create a link between the input and the head query
        sources.append(idx_qkv_full)
        targets.append(idx_q_i_projected)
        curr_weight = get_norm_layer(q_i_projected)
        weights.append(curr_weight*factor_head_norm)
        colors_links.append(get_color_according_to_entropy(q_i_projected))
        line_explained.append(f'norm: {curr_weight}')

        # create a new node for head output (3)
        head_output = place_in_head_slot(concated_heads_wihtout_projection[head_slice], head_idx)
        head_output_projected = pre_project_v(head_output)
        head_output_projected_meaning = model_aux.hs_to_token_top_k(head_output_projected, k_top=1, k_bottom=0)

        labels.append(head_output_projected_meaning['top_k'][0])
        curr_metadata, curr_color = get_node_customdata(head_output_projected, prefix=f'head {head_idx}: output', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
        customdata.append(curr_metadata)
        colors_nodes.append(curr_color)
        idx_head_output_projected = len(labels) - 1

        # create a link between the head output and the concatenated heads
        sources.append(idx_head_output_projected)
        targets.append(idx_concated_heads)
        curr_weight = get_norm_layer(head_output_projected)
        weights.append(curr_weight*factor_head_output)
        colors_links.append(get_color_according_to_entropy(head_output_projected))
        line_explained.append(f'norm after projection:{curr_weight}')

        # the nodes representing the #n_values_per_head top biggest ki and vi (2)
        # topk raises when asked for more entries than exist, so never ask for more than the number of attended tokens
        n_memory_entries = min(n_values_per_head, attentions.shape[0])
        best_head_vals = attentions.topk(n_memory_entries, dim=0)
        for attn_val_idx, attn_score in zip(best_head_vals.indices, best_head_vals.values):
            attn_score = round(attn_score.item(), round_digits)
            keys_from_attn = keys[attn_val_idx]  # should be in the size of the subhead (for example, 64 for gpt2-medium)
            values_from_attn = values[attn_val_idx]  # should be in the size of the subhead

            keys_from_attn_proj = pre_project_k(place_in_head_slot(keys_from_attn, head_idx))
            values_from_attn_proj = pre_project_v(place_in_head_slot(values_from_attn, head_idx))

            keys_from_attn_proj_meaning = model_aux.hs_to_token_top_k(keys_from_attn_proj, k_top=1, k_bottom=0)
            values_from_attn_proj_meaning = model_aux.hs_to_token_top_k(values_from_attn_proj, k_top=1, k_bottom=0)

            # when the prompt is known, the hover text also names the token that created this memory entry
            created_from = f' [created from "{parsed_line[attn_val_idx]}"]' if parsed_line is not None else ''

            # create node for the top keys ki
            labels.append(keys_from_attn_proj_meaning['top_k'][0])
            curr_metadata, curr_color = get_node_customdata(keys_from_attn_proj, prefix=f'head {head_idx}: key {attn_val_idx}{created_from}', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
            customdata.append(curr_metadata)
            colors_nodes.append(curr_color)
            idx_keys_from_attn_proj = len(labels) - 1

            # create node for the top values vi
            labels.append(values_from_attn_proj_meaning['top_k'][0])
            curr_metadata, curr_color = get_node_customdata(values_from_attn_proj, prefix=f'head {head_idx}: value {attn_val_idx}{created_from}', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
            customdata.append(curr_metadata)
            colors_nodes.append(curr_color)
            idx_values_from_attn_proj = len(labels) - 1

            # create a link between the query qi and the top keys ki
            color_attn = f'rgba{cmap_attn_score(attn_score)}'

            sources.append(idx_q_i_projected)
            targets.append(idx_keys_from_attn_proj)
            curr_weight = get_norm_layer(q_i_projected)
            weights.append(max(attn_score*factor_attn_score, 0.51))  # in case the attention score is too low and the line might not be visible
            colors_links.append(color_attn)
            line_explained.append(f'attention score: {attn_score} (qi norm: {curr_weight})')

            # create the link between the top keys and the top values
            sources.append(idx_keys_from_attn_proj)
            targets.append(idx_values_from_attn_proj)
            curr_weight = get_norm_layer(keys_from_attn_proj)
            weights.append(max(attn_score*factor_attn_score, 0.51))
            colors_links.append(color_attn)
            line_explained.append(f'attention score: {attn_score} (ki norm: {curr_weight})')

            # create a link between the top values and the idx_head_output_projected
            sources.append(idx_values_from_attn_proj)
            targets.append(idx_head_output_projected)
            curr_weight = get_norm_layer(values_from_attn_proj)
            weights.append(max(attn_score*curr_weight*factor_for_memory_to_head_values, 0.51))
            colors_links.append(get_color_according_to_entropy(values_from_attn_proj))
            line_explained.append(f'attention score * norm: {attn_score*curr_weight} (vi norm: {curr_weight})')

    # now we want to present single neurons from the concatenated heads and how they are projected (individually) to the output by W_O (the output projection matrix)
    # we pick only the most activated neurons (positive and negative), ranked by activation * norm of the matching W_O row
    concated_heads_wihtout_projection_mul_norm = concated_heads_wihtout_projection * c_proj.norm(dim=1)
    for case, n_top, is_largest in [('top_k', number_of_top_neurons, True), ('bottom_k', number_of_bottom_neurons, False)]:
        tops = torch.topk(concated_heads_wihtout_projection_mul_norm, k=n_top, largest=is_largest)
        for entry_idx, activision_mul_norm in zip(tops.indices, tops.values):
            activision_mul_norm = round(activision_mul_norm.item(), round_digits)
            entry_idx = entry_idx.item()
            activision_value = round(concated_heads_wihtout_projection[entry_idx].item(), round_digits)

            # connect between c_proj_input, which is the concatenated heads, and each of these neurons
            idx_value = len(labels)  # index of the neuron node that is appended right below
            sources.append(idx_concated_heads)
            targets.append(idx_value)
            curr_weight = defualt_weight
            weights.append(curr_weight)
            colors_links.append(positive_activation if is_largest else negative_activation)
            curr_c_proj_meaning = model_aux.hs_to_token_top_k(c_proj[entry_idx], k_top=1, k_bottom=1)
            labels.append(curr_c_proj_meaning[case][0])
            curr_metadata, curr_color = get_node_customdata(c_proj[entry_idx], prefix=f'value index:{entry_idx} (from head {entry_idx//dim_head}), activision:{activision_value}: ',
                            top_or_bottom=case, layer_idx=layer_idx, target_word=target_word)  # value is chosen according to activation sign. supposed to reflect the meaning it adds to the output
            customdata.append(curr_metadata)
            colors_nodes.append(curr_color)
            line_explained.append(f'activision:{activision_value}')

            # add a link between the neuron node and the attn_output
            sources.append(idx_value)
            targets.append(idx_attn_out)
            weights.append(abs(activision_mul_norm))
            colors_links.append(get_color_according_to_entropy(c_proj[entry_idx]))
            line_explained.append(f'activision*norm:{activision_mul_norm}')

    # NOTE: in gpt2 we added the projection matrix bias (attn.c_proj.bias), however, gpt-j attention matrices
    # don't have bias, so we skip this part of the code here

In [None]:
def gen_basic_graph_gpt_j(layers, hs_collector, model, model_aux=model_aux, target_word=None, line=None, save_html=False):
    '''
    A wrapper function to generate and plot the flow graph of the given layers.
    The main difference between this function and its gpt2 implementation is that it calls
    pre_connect_mlp_and_attn_for_gptj before creating the attention and mlp sub-graphs of each layer.

    @ layers: a single layer index, or a list / tuple / range of layer indices to generate the graph for
              (correctness is guaranteed only if layers are in order)
    @ hs_collector: the hs_collector dictionary (created from wrapping the model with the hs_collector class)
    @ model: the model (here: gpt-j). NOTE: the model is moved to the cpu by this function
    @ model_aux: the model_aux class (more functions for the original model)
    @ target_word: the target word for extracting the status of the neurons (ranking and probability)
    @ line: the line that was used to generate the data in hs_collector (the prompt to the model)
    @ save_html: if True, the graph will be saved as an html file to {title}.html. if @ save_html is a non empty string, the graph will be saved as an html file to {save_html}.html
    '''

    model = model.cpu() # all the calculations are done on the cpu (also assuming that hs_collector is on the cpu)

    # init the graph data: every sub-graph builder appends its nodes and links to these shared lists
    graph_data = {
        'sources': [],
        'targets': [],
        'weights': [],
        'colors_nodes': [],
        'colors_links': [],
        'labels': [],
        'line_explained': [],
        'customdata': []
    }

    # accept a single layer index as well as any plain sequence of indices.
    # a tuple or a range is expanded (not wrapped as one element), anything else is treated as one layer
    if isinstance(layers, (list, tuple, range)):
        layers = list(layers)
    else:
        layers = [layers]

    # generate the graph for each layer
    # each layer is a block of two sub-blocks: the attention block and the mlp block
    for layer_idx in layers:
        pre_connect_mlp_and_attn_for_gptj(layer_idx, graph_data, hs_collector=hs_collector, model=model, model_aux=model_aux, target_word=target_word)
        layer_attn_to_graph(layer_idx, graph_data, hs_collector=hs_collector, model=model, model_aux=model_aux, target_word=target_word, line=line)
        layer_mlp_to_graph(layer_idx, graph_data, hs_collector=hs_collector, model=model, model_aux=model_aux, target_word=target_word)

    # the title text is left exactly as it was: per @save_html it is also the name of the saved html file
    plot_graph_aux(graph_data, title=f'Flow-Grpah of layers {layers}--> propt: "{line}". target: "{target_word}"', save_html=save_html)

In [None]:
# example of usage: plot the flow graph of the last three layers of the model
gen_basic_graph_gpt_j(
    list(range(model_config.n_layer - 3, model_config.n_layer)),
    model=model,
    hs_collector=hs_collector,
    target_word=target_token,
    line=line,
    save_html=True,
)

In [None]:
# more examples: the layers can be given as a list ([0]) or as a single layer index (8)
for example_layers in ([0], 8):
    gen_basic_graph_gpt_j(
        example_layers,
        model=model,
        hs_collector=hs_collector,
        target_word=target_token,
        line=line,
        save_html=True,
    )
# gen_basic_graph_gpt_j([9, 10], model=model, hs_collector=hs_collector, target_word=target_token, line=line, save_html=True)
# gen_basic_graph_gpt_j([11], model=model, hs_collector=hs_collector, target_word=target_token, line=line, save_html=True)