In [None]:
# ! pip install transformers

In [None]:
# general imports
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import copy

In [None]:
# plotly express imports
import plotly.graph_objects as go
import matplotlib as mpl
import plotly.io as pio

# choose where plotly figures are rendered
try:
    # get_ipython() is only defined inside an IPython/Jupyter session; anywhere else the lookup fails
    # and we end up in the fallback below
    if 'google.colab' in str(get_ipython()):
        print('running on colab. plot will be presented in notebook')
    else:
        # change to "browser" if you want to see the plots in your browser, else omit this line
        pio.renderers.default = "browser"
except Exception:
    # best effort only: keep plotly's default renderer.
    # `except Exception` (instead of a bare `except`) lets KeyboardInterrupt/SystemExit propagate
    print('Warning: pio.renderers.default = "browser" failed. going to use default renderer')

In [None]:
# another code we wrote
import utils

In [None]:
# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

In [None]:
# example 1
# the three values below select the experiment: which pretrained model is loaded,
# which prompt it is run on, and which token we want to follow
model_name = 'gpt2' # gpt2-small
line = 'When Mary and John went to the store, John gave a drink to'
# presumably passed as `target_word` to the graph functions - verify against the calling cells
target_token = ' Mary' # notice the token includes the space before it

In [None]:
# # example 2
# model_name = 'gpt2-medium'
# line = 'The capital of Japan is the city of'
# target_token = ' Tokyo' # notice the token includes the space before it

In [None]:
# # example 3 (from Counterfact)
# model_name = 'gpt2-medium'
# line = 'Michel Denisot spoke the language of'
# target_token = ' French'

In [None]:
# delete the previously loaded model to free up memory (if more than one model is used).
# on the first run of this cell there is no `model` yet, which is the only error we expect here
try:
    del model
except NameError:
    pass

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.requires_grad_(False)
model_config = model.config

# collect the hidden states before and after each of those layers (modules)
hs_collector = utils.wrap_model(model, layers_to_check = [
    '.mlp', '.mlp.c_fc', '.mlp.c_proj',
    '.attn', '.attn.c_attn', '.attn.c_proj',
    '.ln_1', '.ln_2', '',])  # '' stands for wrapping transformer.h.<layer_index> in gpt2

# add extra functions to the model (like logit lens adjusted to the model decoding matrix)
model_aux = utils.model_extra(model=model, device='cpu')

tokenizer = AutoTokenizer.from_pretrained(model_name)
# not blocking, just to prevent warnings and faster tokenization
# (a plain environment assignment, so no try/except is needed around it)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
encoded_line = tokenizer.encode(line.rstrip(), return_tensors='pt').to(device)
output_and_cache = model(encoded_line, output_hidden_states=True, output_attentions=True, use_cache=True)

hs_collector['past_key_values'] = output_and_cache.past_key_values  # the "attention memory"
hs_collector['attentions'] = output_and_cache.attentions

# the final answer token is the highest-scoring token at the last position of the prompt
model_answer = tokenizer.decode(output_and_cache.logits[0, -1, :].argmax().item())
print(f'model_answer: "{model_answer}"')

In [None]:
# const definitions
# all the "factor" values are used to scale the weights of the links between the neurons in the different layers
# you are more than welcome to change these values for your own experiments

# number of digits kept when rounding the values shown in the graph (norms, activations, probabilities)
round_digits = 3

# the numbers of top and bottom neurons to show for the mlp matrices (mlp.c_fc, mlp.c_proj) and the attention projection matrix (W_O, attn.c_proj)
number_of_top_neurons = 20
number_of_bottom_neurons = 10
# weight given to links that have no natural magnitude of their own (e.g. mlp_input -> "key" neuron)
defualt_weight = 7
# scales the weight of the link between an mlp "key" neuron and its "value" neuron
factor_weight_mlp_key_value_link = 1.5
n_values_per_head = 2  # number of ki, vi to show for each head. great to examine with gpt2-small/medium but for gpt2-large/xl you might want to reduce this number
# NOTE(review): the four factors below are not used in the visible part of the notebook;
# presumably they scale the attention-related links - confirm in layer_attn_to_graph
factor_attn_score = 10
factor_head_output = 1.8
factor_head_norm = 0.2
factor_for_memory_to_head_values = 1.3

# whether to merge these pairs of nodes into one node to save space, since each of them has exactly one input and one output and they map directly to each other
compact_mlp_nodes = False
compact_attn_k_v_nodes = False

In [None]:
# color definitions
# node color by the rank of the target token (indexed with rank / vocab_size):
# green/yellow for the best ranks, grey in the middle, orange to dark red for the worst ranks
cmap_node_rank_target = mpl.colors.LinearSegmentedColormap.from_list('cmap_node_rank_target', ['lime', 'greenyellow', 'yellow' , 'yellow', 'yellow'] + ['dimgrey']*53 + ['orangered']*3 + ['red']*4 + ['darkred']*5)
# NOTE(review): not used in the visible part of the notebook; presumably colors links by attention score - confirm
cmap_attn_score = mpl.colors.LinearSegmentedColormap.from_list('cmap_attn_score', ['khaki', 'yellow', 'green'])
# link color by the entropy of the projected distribution (see get_color_according_to_entropy)
cmap_entropy = mpl.colors.LinearSegmentedColormap.from_list('cmap_entropy', ['darkslategrey', 'lightgrey', 'lightgrey'])

# figure background; links drawn in this color are effectively invisible
backgroud_color = 'black'
invisible_link_color = backgroud_color
color_for_abstract = 'white'
# node color when no target-based color is computed
default_color = 'white'
# links that pass through a layer norm (residual -> block input)
link_with_normalizer = 'darkviolet'
# nodes and links that represent bias vectors
color_for_bias_vector = 'pink'
# key -> value links of the mlp neurons, by the sign group of the activation
positive_activation = 'blue'
negative_activation = 'red'

In [None]:
def get_norm_layer(x, debug_msg=None, round_digits=round_digits):
    '''
    return the norm (torch.norm) of a vector as a python float

    @ x: a torch tensor, or anything torch.Tensor(...) can be built from (e.g. a list of floats)
    @ debug_msg: if given (non empty), print it together with len(x) and the un-rounded norm
    @ round_digits: number of digits to round the result to. values <= 0 disable the rounding
    '''
    # isinstance (rather than an exact type comparison) so that Tensor subclasses, e.g. nn.Parameter,
    # are used as they are instead of being pushed through the torch.Tensor(...) constructor
    if isinstance(x, torch.Tensor):
        res = torch.norm(x).item()
    else:
        res = torch.norm(torch.Tensor(x)).item()
    if debug_msg:
        print(f'{debug_msg}: len(x): {len(x)}, norm: {res}')
    if round_digits > 0:
        return round(res, round_digits)
    return res

In [None]:
def merge_two_nodes(graph_data, index1, index2, prefix1='', prefix2=''):
    '''
    merge the last two nodes of the graph into one node. all the lists in @graph_data are updated in place

    the merged node gets the label "<label1> > <label2>", the color of the second node and the
    customdata of both nodes (each with its prefix) separated by a line break.
    every link between the two nodes is removed (it would turn into a self loop of the merged node)
    and every other link that touches one of them is re-routed to the merged node

    @ graph_data: the graph data dictionary (lists of nodes: labels, colors_nodes, customdata.
                  lists of links: sources, targets, weights, colors_links, line_explained)
    @ index1: the index of the first node to merge. must be the node before the last one
    @ index2: the index of the second node to merge. must be the last node
    @ prefix1: text to add before the customdata of the first node
    @ prefix2: text to add before the customdata of the second node

    return: the index of the merged node
    raise: ValueError if the two nodes are not the last two nodes of the graph
    '''
    sources = graph_data['sources']
    targets = graph_data['targets']
    weights = graph_data['weights']
    colors_nodes = graph_data['colors_nodes']
    colors_links = graph_data['colors_links']
    labels = graph_data['labels']
    line_explained = graph_data['line_explained']
    customdata = graph_data['customdata']

    if index1 + 1 != index2 or index2 + 1 != len(labels):
        raise ValueError(
            'The call for merge_two_nodes is not valid. it should be done only if the last two nodes '
            'are the ones to merge, and no other nodes were added after them. '
            f'got index1: {index1}, index2: {index2}, len(labels): {len(labels)}')

    print(f'Start merge_two_nodes for index1: {index1}, index2: {index2}')

    merged_label = f'{labels[index1]} > {labels[index2]}'
    merged_color = colors_nodes[index2]
    merged_customdata = f'{prefix1}{customdata[index1]}<br />{prefix2}{customdata[index2]}'

    # replace the two old nodes (the last two entries of each node list) with the merged one
    for node_list, merged_value in ((labels, merged_label),
                                    (colors_nodes, merged_color),
                                    (customdata, merged_customdata)):
        del node_list[-2:]
        node_list.append(merged_value)
    new_index = len(labels) - 1

    old_pair = {index1, index2}

    # keep every link except the ones between the two merged nodes (in either direction)
    kept = [i for i, (src, dst) in enumerate(zip(sources, targets))
            if not (src != dst and src in old_pair and dst in old_pair)]

    # slice assignment keeps the very same list objects, which the callers still hold references to
    for link_list in (weights, colors_links, line_explained):
        link_list[:] = [link_list[i] for i in kept]

    # links that touched one of the old nodes now point to the merged node
    sources[:] = [new_index if sources[i] in old_pair else sources[i] for i in kept]
    targets[:] = [new_index if targets[i] in old_pair else targets[i] for i in kept]

    return new_index

In [None]:
def plot_graph_aux(graph_data, title='Flow-Graph', save_html=False, value_format='.3f', value_suffix=''):
    '''
    A wrapper for plotting the graph as a plotly Sankey diagram

    @ graph_data: the graph data. see the function @ gen_basic_graph for more details
    @ title: the title of the graph
    @ save_html: if True, the graph will be saved as an html file to {title}.html. if @ save_html is a non empty string, the graph will be saved as an html file to {save_html}.html
                 (any other value that is not equal to False also saves to {title}.html)
    @ value_format: d3 format of the node/link values in the hover text. the default keeps the
                    fractional part, since many weights are norms or activations smaller than 1
    @ value_suffix: text to add after every value in the hover text. empty by default, since the
                    values of this graph have no unit
    '''

    node_spec = dict(
        pad = 15,
        thickness = 15,
        line = dict(color = backgroud_color, width = 0.5),
        label = graph_data['labels'],
        color = graph_data['colors_nodes'],
        customdata = graph_data['customdata'],
        hovertemplate='Node: %{customdata}. %{value}<extra></extra>',
    )
    link_spec = dict(
        source = graph_data['sources'],
        target = graph_data['targets'],
        value = graph_data['weights'],
        color = graph_data['colors_links'],
        customdata = graph_data['line_explained'],
        hovertemplate='%{source.customdata}<br />' + ' '*50 + '----[%{customdata},  %{value}]----><br />%{target.customdata}<extra></extra>',
    )

    fig = go.Figure(data=[go.Sankey(
        valueformat = value_format,
        valuesuffix = value_suffix,
        node = node_spec,
        link = link_spec,
    )])

    fig.update_layout(
        hovermode = 'x',
        hoverlabel=dict(font_size=16),
        title_text = title,
        font=dict(size=12, color='white'),
        plot_bgcolor=backgroud_color,
        paper_bgcolor=backgroud_color
    )

    fig.show()

    # save to html. only a value equal to False disables saving
    if save_html != False:
        has_custom_name = isinstance(save_html, str) and save_html != ''
        path_out = f'{save_html}.html' if has_custom_name else f'{title}.html'
        fig.write_html(path_out)

In [None]:
# independent copies of the final layer norm and of the decoding head, moved to the working device
# and excluded from gradient tracking. they are used by logit_status below
ln_f = copy.deepcopy(utils.rgetattr(model, 'transformer.ln_f'))
ln_f = ln_f.to(device).requires_grad_(False)

decoding_matrix = copy.deepcopy(utils.rgetattr(model, 'lm_head'))
decoding_matrix = decoding_matrix.to(device).requires_grad_(False)

# size of the model hidden states (embedding dimension)
hidden_d = model.config.n_embd

def logit_status(hs, wanted_idx):
    '''
    project a hidden state to the vocabulary (final layer norm + decoding head) and report how a
    wanted token scores in the resulting distribution

    @ hs: the hidden state. a torch tensor or anything torch.tensor(...) accepts
    @ wanted_idx: a token id, or a string (only its first token is used). an empty string means
                  "no target" and yields (-1, -1)

    return: (probability of the token, rank of the token).
            rank 1 -> most probable, rank #vocab_size -> least probable
    '''
    if wanted_idx == '':
        return -1, -1

    token_id = wanted_idx
    if type(token_id) == str:
        token_id = tokenizer.encode(token_id, add_special_tokens=False)[0]

    # work on a detached copy that lives on the same device as ln_f / decoding_matrix
    hidden = hs.clone().detach().to(device) if type(hs) == torch.Tensor else torch.tensor(hs).to(device)

    # logit lens, including the normalization with the final layer norm of the model
    vocab_scores = decoding_matrix(ln_f(hidden))
    vocab_probs = torch.nn.functional.softmax(vocab_scores, dim=0)

    prob = vocab_probs[token_id].item()
    # the rank is one more than the number of tokens with a strictly higher score
    n_better = (vocab_scores[token_id] < vocab_scores).sum().item()
    return prob, n_better + 1

In [None]:
def entropy(probabilities):
    '''
    calculates the entropy (in bits, log base 2) of a probability distribution

    @ probabilities: a vector of probabilities over the vocabulary. an input whose length is not
                     the vocabulary size is first passed through model_aux.hs_to_probs
                     (presumably a hidden state that is projected to probabilities - see utils)

    return: the entropy as a numpy scalar
    '''
    if len(probabilities) != model_config.vocab_size:
        probabilities = model_aux.hs_to_probs(probabilities)
    # convert the probabilities to a numpy array.
    # tensors are detached and moved to the cpu first (numpy cannot read CUDA tensors);
    # lists / numpy arrays are accepted as they are
    if torch.is_tensor(probabilities):
        probabilities = probabilities.detach().cpu()
    probs = np.asarray(probabilities)
    # filter out 0 probabilities (to avoid issues with log(0))
    non_zero_probs = probs[probs != 0]
    return -np.sum(non_zero_probs * np.log2(non_zero_probs))

In [None]:
def get_color_according_to_entropy(hs, max_val=30):
    '''
    return a color string for @hs according to its entropy (see the function @ entropy)

    @ hs: the input of the function @ entropy
    @ max_val: the entropy that is mapped to the end of cmap_entropy. higher values are clamped to it

    return: a string of the form 'rgba(r, g, b, a)' holding the channels as the colormap returns them
    '''
    entropy_score = entropy(hs)
    # clamp with floats and keep a python float: a matplotlib colormap reads a float as a position
    # in [0, 1] but an int as a direct index into its lookup table, so clamping to the int 1
    # would select the second table entry instead of the last color
    color_idx = float(max(min(entropy_score / max_val, 1.0), 0.0))
    # format each channel as a plain float. formatting the returned tuple directly embeds the repr
    # of numpy scalars ("np.float64(...)" with numpy >= 2), which is not a valid color string
    # NOTE(review): the channels stay fractions in [0, 1], as before - confirm plotly renders them as intended
    channels = ', '.join(str(float(channel)) for channel in cmap_entropy(color_idx))
    return f'rgba({channels})'

In [None]:
def get_node_customdata(hs, prefix='', top_or_bottom='top_k', layer_idx=None, target_word=None, color_flag=True):
    '''
    return the metadata that is used for a node in the graph plot

    @ hs: the vector the node represents
    @ prefix: text to add before the node description (skipped if empty)
    @ top_or_bottom: which entry of model_aux.hs_to_token_top_k to show ('top_k' or 'bottom_k')
    @ layer_idx: if given, added at the beginning of the text
    @ target_word: if given, the probability and the rank of this word (see logit_status) are added to the text
    @ color_flag: if True (and @target_word is given) the node is colored according to the rank of the target word

    return: (text, color)
    text: the text that will be displayed in the node when hovering over it
    color: the color of the node according to the rank of the target word
    (green if very probable, red if very improbable and grey otherwise). default_color when no color is computed
    '''

    color = default_color
    hs_meaning = model_aux.hs_to_token_top_k(hs, k_top=5, k_bottom=5)
    res = f'{hs_meaning[top_or_bottom]}'

    if prefix != '':
        res = f'{prefix}: {res}'
    if layer_idx is not None:
        res = f'{layer_idx}) {res}'
    if target_word is not None:
        prob, ranking = logit_status(hs, target_word)
        prob = round(prob*100, round_digits)  # probs [0,1] as percentage [0,100]
        res = f'{res} [status: "{target_word}": prob:{prob}%, rank: {ranking}]'
        if color_flag:
            color_idx = ranking/model_config.vocab_size  # color index [0,1] (according to ranking)
            # format each channel as a plain float. formatting the returned tuple directly embeds the
            # repr of numpy scalars ("np.float64(...)" with numpy >= 2), which is not a valid color string
            channels = ', '.join(str(float(channel)) for channel in cmap_node_rank_target(color_idx))
            color = f'rgba({channels})'
    return res, color

In [None]:
def layer_mlp_to_graph(layer_idx: int, graph_data, model, hs_collector, row_idx=-1, model_aux=model_aux, target_word=None):
    '''
    create a subgraph of the feed-forward (FF, MLP) part of the model at layer_idx
    the subgraph is a graph of the neurons in the FF part of the model
    the nodes are the most active neurons in the FF (some positive and some negative)
    the links are the connections between the neurons (summation of the neurons or when one neuron creates the coefficient of another neuron)
    the graph is created using the graph_data dictionary: new nodes and links are appended to its lists
    if 'idx_attn_part_out' is in graph_data, the new nodes are connected to that node. otherwise a new input node is created
    at the end graph_data['idx_block_out'] holds the index of the output node of this block

    @ layer_idx: the index of the layer in the model
    @ graph_data: the graph data dictionary (if called first time, should include empty list for the keys it uses)
    @ model: the model (for example: gpt2)
    @ hs_collector: the hs_collector dictionary (created from wrapping the model with the hs_collector class)
    @ row_idx: the index of the row in the hs_collector which corresponds to the inference of the #row_idx token. use -1 for the last token (Note: currently not supported any other value than -1)
    @ model_aux: the model_aux class (more functions for the original model)
    @ target_word: the target word for extracting the status of the neurons (ranking and probability)
    '''
    sources = graph_data['sources']
    targets = graph_data['targets']
    weights = graph_data['weights']
    colors_nodes = graph_data['colors_nodes']
    colors_links = graph_data['colors_links']
    labels = graph_data['labels']
    line_explained = graph_data['line_explained']
    customdata = graph_data['customdata']

    def add_node(hs, prefix, label=None, top_or_bottom='top_k', color=None):
        # append one node that represents @hs and return its index.
        # the default label is the top token of @hs, the default color comes from get_node_customdata
        if label is None:
            label = model_aux.hs_to_token_top_k(hs, k_top=1, k_bottom=0)['top_k'][0]
        node_metadata, node_color = get_node_customdata(hs, prefix=prefix, top_or_bottom=top_or_bottom, layer_idx=layer_idx, target_word=target_word)
        labels.append(label)
        customdata.append(node_metadata)
        colors_nodes.append(node_color if color is None else color)
        return len(labels) - 1

    def add_link(source, target, weight, color, explained):
        # append one link (with its hover text) between two existing nodes
        sources.append(source)
        targets.append(target)
        weights.append(weight)
        colors_links.append(color)
        line_explained.append(explained)

    mlp_residual = hs_collector[layer_idx]['ln_2']['input'][row_idx]  # mlp_residual == attn_part_out == attn_output+attn_residual
    mlp_input = hs_collector[layer_idx]['mlp']['input'][row_idx]  # mlp_input == c_fc_input
    mlp_out = hs_collector[layer_idx]['mlp']['output'][row_idx]  # only the output of the mlp (not with the residual)
    block_out = hs_collector[layer_idx]['']['output'][row_idx]  # mlp_residual + mlp_out

    # try to connect to the previous sub block (attn_part_out if exists)
    if 'idx_attn_part_out' in graph_data:
        # uses the previous attn_part_out node in the graph as the input to this block, which is actually the residual of this block
        idx_mlp_residual = graph_data['idx_attn_part_out']  # should contain an int value with the index of the previous attn_part_out node in the graph
    else:
        # there is no problem of not having the previous attn_part_out node in the graph. it will assume this is the start of the graph
        print(f'WARNING: did not find idx_attn_part_out for layer_idx={layer_idx}')
        # if no previous node we create a node to represent the input to the block (which is actually the residual of this block)
        idx_mlp_residual = add_node(mlp_residual, prefix='mlp_residual')

    # create nodes for the input after layer norm (mlp_input), output (mlp alone, mlp_out), and the output of the mlp with the residual (residual+mlp_out)
    idx_mlp_input = add_node(mlp_input, prefix='mlp_input')
    idx_mlp_out = add_node(mlp_out, prefix='mlp_out')
    idx_block_out = add_node(block_out, prefix='residual+mlp_out')

    # connect the last 4 nodes (input, input after layer norm, output of mlp, output of mlp with residual)
    residual_norm = get_norm_layer(mlp_residual)
    add_link(idx_mlp_residual, idx_mlp_input, residual_norm, link_with_normalizer, f'norm:{residual_norm}')
    # 'rgba(51,102,153,0.3)' is a unique color for the mlp residual
    add_link(idx_mlp_residual, idx_block_out, residual_norm, 'rgba(51,102,153,0.3)', f'norm:{residual_norm}')
    mlp_out_norm = get_norm_layer(mlp_out)
    add_link(idx_mlp_out, idx_block_out, mlp_out_norm, get_color_according_to_entropy(mlp_out), f'norm:{mlp_out_norm}')

    # get the first and second matrices of the mlp
    c_fc = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.c_fc.weight").clone().detach().cpu()
    c_proj = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.c_proj.weight").clone().detach().cpu()

    values_norm = c_proj.norm(dim=1)
    hs = hs_collector[layer_idx]['mlp.c_proj']['input'][row_idx]  # value activation. mid results between the key and value matrix
    # NOTE(review): values_norm lives on the cpu; this assumes hs does too - confirm against utils.wrap_model
    hs_mul_norm = hs * values_norm  # this is our metric for the importance of the value activation (value*norm of the second matrix)

    # pick the most active neurons: the largest values of the metric, then the smallest ones
    for case, n_top, is_largest in [('top_k', number_of_top_neurons, True), ('bottom_k', number_of_bottom_neurons, False)]:
        tops = torch.topk(hs_mul_norm, k=n_top, largest=is_largest)
        for entry_idx, activision_value_mul_norm in zip(tops.indices, tops.values):
            entry_idx = entry_idx.item()
            activision_value_mul_norm = round(activision_value_mul_norm.item(), round_digits)
            activision_value = round(hs[entry_idx].item(), round_digits)

            key_vector = c_fc.T[entry_idx]
            value_vector = c_proj[entry_idx]

            # create a node for the "key" neuron (the first matrix)
            idx_key = add_node(key_vector, prefix=f'key{entry_idx}')

            # create a node for the "value" neuron (the second matrix)
            # the label is chosen according to the activation sign. supposed to reflect the meaning it adds to the output
            value_label = model_aux.hs_to_token_top_k(value_vector, k_top=1, k_bottom=1)[case][0]
            idx_value = add_node(value_vector, prefix=f'value{entry_idx}', label=value_label, top_or_bottom=case)

            if compact_mlp_nodes:
                # the two nodes were just added, so they are the last two nodes as merge_two_nodes requires
                merged_idx = merge_two_nodes(graph_data, index1=idx_key, index2=idx_value, prefix1='key:', prefix2='value:')
                idx_key = merged_idx
                idx_value = merged_idx
            else:
                # connect the "key" and the "value"
                add_link(idx_key, idx_value,
                         abs(activision_value)*factor_weight_mlp_key_value_link,
                         positive_activation if is_largest else negative_activation,
                         f'activision:{activision_value}')

            # add a link between the mlp_input and the "key" neuron
            add_link(idx_mlp_input, idx_key, defualt_weight, get_color_according_to_entropy(key_vector), '')

            # add a link between the "value" neuron and the mlp_out
            add_link(idx_value, idx_mlp_out,
                     abs(activision_value_mul_norm),
                     get_color_according_to_entropy(value_vector),
                     f'activision*norm:{abs(activision_value_mul_norm)}')

    # we also add neurons representing the bias vectors of the matrices
    # we create a node for each bias vector and then connect them to the flow
    # the bias of the first matrix is shown after multiplying it by the second matrix
    c_fc_bias = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.c_fc.bias").clone().detach().cpu() @ c_proj
    idx_c_fc_bias = add_node(c_fc_bias, prefix='c_fc_bias', color=color_for_bias_vector)  # special color for bias

    c_proj_bias = utils.rgetattr(model, f"transformer.h.{layer_idx}.mlp.c_proj.bias").clone().detach().cpu()
    idx_c_proj_bias = add_node(c_proj_bias, prefix='c_proj_bias', color=color_for_bias_vector)

    if compact_mlp_nodes:
        merged_idx = merge_two_nodes(graph_data, index1=idx_c_fc_bias, index2=idx_c_proj_bias, prefix1='c_fc_bias:', prefix2='c_proj_bias:')
        idx_c_fc_bias = merged_idx
        idx_c_proj_bias = merged_idx
    else:
        # connect c_fc_bias to c_proj_bias, although it is not really a link (this is why its color is invisible_link_color)
        # the weight is tiny so it is barely visible, and the hover text stays empty:
        # this link carries no norm of its own, so no value from another link is reported for it
        add_link(idx_c_fc_bias, idx_c_proj_bias, 0.05, invisible_link_color, '')

    # connect mlp_input to c_fc_bias
    mlp_input_norm = get_norm_layer(mlp_input)
    add_link(idx_mlp_input, idx_c_fc_bias, mlp_input_norm, color_for_bias_vector, f'norm:{mlp_input_norm}')

    # connect c_proj_bias to mlp_out
    c_proj_bias_norm = get_norm_layer(c_proj_bias)
    add_link(idx_c_proj_bias, idx_mlp_out, c_proj_bias_norm, color_for_bias_vector, f'norm:{c_proj_bias_norm}')

    # update index to last mlp_out so the next attention block will connect to it
    graph_data['idx_block_out'] = idx_block_out

In [None]:
def layer_attn_to_graph(layer_idx, graph_data, hs_collector, model, row_idx=-1, model_aux=model_aux, target_word=None, line=None):
    '''
    create a graph for the attention (attn) layer
    the subgraph is a graph of the neurons in the Q, K, V O matricies, mostly aggregated into heads
    the nodes are single or small groups of neurons (when they are aggregated into heads or subheads)
    the links are the connections between the neurons (summation of the neurons or when one neuron creates the coefficient of another neuron)
    the graph is created using the graph_data dictionary
    if the graph_data dictionary's lists are empty, they will be initialized
    if the graph_data dictionary 's lists are not empty, they will be updated (try to connect the new nodes to the existing nodes)

    @ layer_idx: the index of the layer in the model
    @ graph_data: the graph data dictionary (if called first time, should include empty list for the keys it uses)
    @ model: the model (for example: gpt2)
    @ hs_collector: the hs_collector dictionary (created from wrapping the model with the hs_collector class)
    @ row_idx: the index of the row in the hs_collector which correspond to the infrence of the #row_idx token. use -1 for the last token (Note: currently not supported any other value than -1)
    @ model_aux: the model_aux class (more functions for the original model)
    @ target_word: the target word for extracting the status of the neurons (ranking and probability)
    @ line: the line that was used to generate the data in hs_collector (the prompt to the model)
    '''
    sources = graph_data['sources']
    targets = graph_data['targets']
    weights = graph_data['weights']
    colors_nodes = graph_data['colors_nodes']
    colors_links = graph_data['colors_links']
    labels = graph_data['labels']
    line_explained = graph_data['line_explained']
    customdata = graph_data['customdata']

    # uses to show what was the token that generated the attention memory (previous keys and layer)
    # the i-th key and i-th value were generated by the i-th token
    # if @line is not given (None) - will not show this information
    parsed_line = None
    if line is not None:
        parsed_line = tokenizer.encode(line, return_tensors='pt')
        # save the parsed line for later use
        parsed_line = [tokenizer.decode(x.item()) for x in parsed_line[0]]

    
    attn_residual = hs_collector[layer_idx]['']['input'][row_idx]
    attn_residual_meaning = model_aux.hs_to_token_top_k(attn_residual, k_top=1, k_bottom=0)

    # try to connect to previous layer
    if 'idx_block_out' in graph_data: # previous layer exists
        idx_attn_residual = graph_data['idx_block_out']
    else:
        # there is no problem of not having the previous idx_block_out node in the graph. it will assume this is the start of the graph
        print(f'WARNING: did not find idx_block_out for layer_idx={layer_idx}')
        # create the input (before layer norm) to the attn block, which is actually the residual from the previous block
        labels.append(attn_residual_meaning['top_k'][0])
        curr_metadata, curr_color = get_node_customdata(attn_residual, prefix=f'attn_residual', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
        customdata.append(curr_metadata)
        colors_nodes.append(curr_color)
        idx_attn_residual = len(labels) - 1

    # create nodes for the attention input (after layer norm) and for the attention output
    # attn input (after layer norm. the layer norm is done before the attentnion module)
    attn_input = hs_collector[layer_idx]['attn']['input'][row_idx]
    attn_input_meaning = model_aux.hs_to_token_top_k(attn_input, k_top=1, k_bottom=0)
    labels.append(attn_input_meaning['top_k'][0])
    curr_metadata, curr_color = get_node_customdata(attn_input, prefix=f'attn_input', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
    customdata.append(curr_metadata)
    colors_nodes.append(curr_color)
    idx_attn_input = len(labels) - 1

    # attn output
    attn_out = hs_collector[layer_idx]['attn']['output'][row_idx]
    attn_out_meaning = model_aux.hs_to_token_top_k(attn_out, k_top=1, k_bottom=0)
    labels.append(attn_out_meaning['top_k'][0])
    curr_metadata, curr_color = get_node_customdata(attn_out, prefix=f'attn_out', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
    customdata.append(curr_metadata)
    colors_nodes.append(curr_color)
    idx_attn_out = len(labels) - 1

    # link between prev_block_output/attn_residual and attn_input (the process of ln_1)
    sources.append(idx_attn_residual)
    targets.append(idx_attn_input)
    curr_weight = get_norm_layer(attn_residual)
    weights.append(curr_weight)
    colors_links.append(link_with_normalizer)
    line_explained.append(f'norm:{curr_weight}')

    c_attn = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.c_attn.weight").clone().detach().cpu() # W_QKV (the QKV matrix)
    c_proj = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.c_proj.weight").clone().detach().cpu() # W_O (the Output/projection matrix)

    # in gpt2 c_attn is the concatenation of Wq, Wk, Wv (Q, K, V)
    Wq = c_attn[:, :hidden_d]
    Wk = c_attn[:, hidden_d:2*hidden_d]
    Wv = c_attn[:, 2*hidden_d:]

    # we can get the output of each of the Q, K, V matrices by splitting the output of the c_attn
    c_attn_output = hs_collector[layer_idx]['attn.c_attn']['output'][row_idx].cpu()
    q = c_attn_output[ :hidden_d]  # this layer query
    k = c_attn_output[hidden_d:2*hidden_d]  # this layer key. it is added to the attention memory (to "past_key_values" so also the next tokens can use it)
    v = c_attn_output[2*hidden_d:] # this layer value. like the key, it is added to the attention memory 

    # projection using the QK circuit
    def pre_project_q(hs_q):
        return hs_q @ Wk

    # projection using the QK circuit
    def pre_project_k(hs_k):
        return Wq @ hs_k
    
    # projection using the OV circuit
    def pre_project_v(hs_v):
        return hs_v @ c_proj
    
    q_projected = pre_project_q(q)
    k_projected = pre_project_k(k)
    v_poject = pre_project_v(v)  

    # create a node for q,k,v together
    q_meaning = model_aux.hs_to_token_top_k(q_projected, k_top=1, k_bottom=0)
    q_data, curr_color = get_node_customdata(q_projected, prefix=f'q (for current calc)', top_or_bottom='top_k', target_word=target_word)
    k_data, _ = get_node_customdata(k_projected, prefix=f'k (for next tokens)', top_or_bottom='top_k', target_word=target_word)
    v_data, _ = get_node_customdata(v_poject, prefix=f'v (for next tokens)', top_or_bottom='top_k', target_word=target_word)

    curr_metadata = f'q,k,v (before splitting into heads):' + '<br />' + q_data + '<br />' + k_data + '<br />' + v_data

    # create a new node for query (q) and add the key and value (k, v) metadata. we call this node qkv_full
    labels.append(q_meaning['top_k'][0])
    customdata.append(curr_metadata)
    colors_nodes.append(curr_color)
    idx_qkv_full = len(labels) - 1

    # connect between attn_input and qkv_full   
    sources.append(idx_attn_input)
    targets.append(idx_qkv_full)
    curr_weight = get_norm_layer(attn_input)
    weights.append(curr_weight)
    colors_links.append(get_color_according_to_entropy(q))
    line_explained.append(f'norm:{curr_weight}')

    # the concated results from all the heads but without the OV circuit projection
    concated_heads_wihtout_projection = hs_collector[layer_idx]['attn.c_proj']['input'][row_idx]

    idx_concated_heads = len(labels)  # attn_c_proj input
    concated_heads_without_projection_meaning = model_aux.hs_to_token_top_k(concated_heads_wihtout_projection, k_top=1, k_bottom=0)
    labels.append(concated_heads_without_projection_meaning['top_k'][0])
    curr_metadata, curr_color = get_node_customdata(concated_heads_wihtout_projection, prefix=f'concated_heads (without projection)', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
    customdata.append(curr_metadata)
    colors_nodes.append(color_for_abstract)

    # link attn_output_and_residual to block_output
    attn_output_and_residual = attn_out + attn_residual # equal to hs_collector[layer_idx]['ln_2']['input'][row_idx]
    attn_output_and_residual_meaning = model_aux.hs_to_token_top_k(attn_output_and_residual, k_top=1, k_bottom=0)
    labels.append(attn_output_and_residual_meaning['top_k'][0])
    curr_metadata, curr_color = get_node_customdata(attn_output_and_residual, prefix=f'attn_out+residual', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
    customdata.append(curr_metadata)
    colors_nodes.append(curr_color)
    idx_attn_output_and_residual = len(labels) - 1

    # link attn_output_and_residual to attn_c_proj (the summation of this attn block to the residual)
    sources.append(idx_attn_out)
    targets.append(idx_attn_output_and_residual)
    curr_weight = get_norm_layer(attn_out)
    weights.append(curr_weight)
    colors_links.append(get_color_according_to_entropy(attn_out))
    line_explained.append(f'norm: {curr_weight}')

    # the residual connection
    sources.append(idx_attn_residual)
    targets.append(idx_attn_output_and_residual)
    curr_weight = get_norm_layer(attn_residual)
    weights.append(curr_weight)
    colors_links.append('rgba(102,0,204,0.3)')  # unique color for the attn residual
    line_explained.append(f'norm: {curr_weight}')

    # print the hidden meaning of each of the heads
    dim_head = model.config.n_embd // model.config.n_head

    # for each head, create the following nodes:
    # (1) qi - this head's slice of the query q (its metadata also describes the ki, vi that were generated at this layer and saved to the attention memory)
    # (2) the n_values_per_head keys and values (ki, vi taken from the attention memory) with the highest attention scores
    # (3) the head output - the attention-weighted summation of all the vi
    for head_idx in range(model.config.n_head):
        # hs_collector['past_key_values'][layer_idx][0 for key, 1 for value][entry in batch][head_idx] -> list of the keys/values for this head. the i-th entry is the key/value for the i-th token in the input
        keys = hs_collector['past_key_values'][layer_idx][0][0][head_idx]  
        values = hs_collector['past_key_values'][layer_idx][1][0][head_idx]
        attentions = hs_collector['attentions'][layer_idx][0][head_idx][row_idx]

        # qi with the information about ki, vi (1)
        qi = q[dim_head * head_idx: dim_head * (head_idx + 1)]
        ki = k[dim_head * head_idx: dim_head * (head_idx + 1)]
        vi = v[dim_head * head_idx: dim_head * (head_idx + 1)]

        q_i_fill = torch.zeros(model.config.n_embd)
        q_i_fill[dim_head*head_idx:dim_head*(head_idx+1)] = qi
        q_i_projected = pre_project_q(q_i_fill)

        k_i_fill = torch.zeros(model.config.n_embd)
        k_i_fill[dim_head*head_idx:dim_head*(head_idx+1)] = ki
        k_i_projected = pre_project_k(k_i_fill)

        v_i_fill = torch.zeros(model.config.n_embd)
        v_i_fill[dim_head*head_idx:dim_head*(head_idx+1)] = vi
        v_i_projected = pre_project_v(v_i_fill) 

        q_i_projected_meaning = model_aux.hs_to_token_top_k(q_i_projected, k_top=1, k_bottom=0)
        q_data, curr_color = get_node_customdata(q_i_projected, prefix=f'qi (for current calc)', top_or_bottom='top_k', target_word=target_word)
        k_data, _ = get_node_customdata(k_i_projected, prefix=f'ki (for next tokens)', top_or_bottom='top_k', target_word=target_word)
        v_data, _ = get_node_customdata(v_i_projected, prefix=f'vi (for next tokens)', top_or_bottom='top_k', target_word=target_word)

        curr_metadata = f'head {head_idx}:' + '<br />' + q_data + '<br />' + k_data + '<br />' + v_data

        # create a new node for query and add the key and value metadata
        labels.append(q_i_projected_meaning['top_k'][0])
        customdata.append(curr_metadata)
        colors_nodes.append(curr_color)
        idx_q_i_projected = len(labels) - 1

        # create a link between the input and the head query 
        sources.append(idx_qkv_full)
        targets.append(idx_q_i_projected)
        curr_weight = get_norm_layer(q_i_projected)
        weights.append(curr_weight*factor_head_norm)
        colors_links.append(get_color_according_to_entropy(q_i_projected))
        line_explained.append(f'norm: {curr_weight}')

        # create a new node for head output (3)
        head_output = torch.zeros(model.config.n_embd)
        head_output[dim_head*head_idx:dim_head*(head_idx+1)] = concated_heads_wihtout_projection[dim_head*head_idx:dim_head*(head_idx+1)]
        head_output_projected = pre_project_v(head_output)
        head_output_projected_meaning = model_aux.hs_to_token_top_k(head_output_projected, k_top=1, k_bottom=0)

        labels.append(head_output_projected_meaning['top_k'][0])
        curr_metadata, curr_color = get_node_customdata(head_output_projected, prefix=f'head {head_idx}: output', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
        customdata.append(curr_metadata)
        colors_nodes.append(curr_color)
        idx_head_output_projected = len(labels) - 1

        # create a link between the head output and the concated heads
        sources.append(idx_head_output_projected)
        targets.append(idx_concated_heads)
        curr_weight = get_norm_layer(head_output_projected)
        weights.append(curr_weight*factor_head_output)
        colors_links.append(get_color_according_to_entropy(head_output_projected))
        line_explained.append(f'norm after projection:{curr_weight}')
        

        # the nodes representing the #n_values_per_head top biggest ki and vi (2)
        best_head_vals = attentions.topk(n_values_per_head, dim=0)
        for attn_val_idx, attn_score in zip(best_head_vals.indices, best_head_vals.values):
            attn_score = round(attn_score.item(), round_digits)
            keys_from_attn = keys[attn_val_idx]  # should be in the size of the subhead (for example, 64 for gpt2-medium)
            values_from_attn = values[attn_val_idx]  # should be in the size of the subhead

            keys_from_attn_proj = torch.zeros(model.config.n_embd)
            keys_from_attn_proj[dim_head*head_idx:dim_head*(head_idx+1)] = keys_from_attn
            keys_from_attn_proj = pre_project_k(keys_from_attn_proj)

            values_from_attn_proj = torch.zeros(model.config.n_embd)
            values_from_attn_proj[dim_head*head_idx:dim_head*(head_idx+1)] = values_from_attn
            values_from_attn_proj = pre_project_v(values_from_attn_proj)

            keys_from_attn_proj_meaning = model_aux.hs_to_token_top_k(keys_from_attn_proj, k_top=1, k_bottom=0)
            values_from_attn_proj_meaning = model_aux.hs_to_token_top_k(values_from_attn_proj, k_top=1, k_bottom=0)

            # create node for the top keys ki
            labels.append(keys_from_attn_proj_meaning['top_k'][0])
            if parsed_line is not None:
                curr_metadata, curr_color = get_node_customdata(keys_from_attn_proj, prefix=f'head {head_idx}: key {attn_val_idx} [created from "{parsed_line[attn_val_idx]}"]', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
            else:
                curr_metadata, curr_color = get_node_customdata(keys_from_attn_proj, prefix=f'head {head_idx}: key {attn_val_idx}', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
            customdata.append(curr_metadata)
            colors_nodes.append(curr_color)
            idx_keys_from_attn_proj = len(labels) - 1

            # create node for the top values vi
            labels.append(values_from_attn_proj_meaning['top_k'][0])
            if parsed_line is not None:
                curr_metadata, curr_color = get_node_customdata(values_from_attn_proj, prefix=f'head {head_idx}: value {attn_val_idx} [created from "{parsed_line[attn_val_idx]}"]', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
            else:
                curr_metadata, curr_color = get_node_customdata(values_from_attn_proj, prefix=f'head {head_idx}: value {attn_val_idx}', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
            customdata.append(curr_metadata)
            colors_nodes.append(curr_color)
            idx_values_from_attn_proj = len(labels) - 1

            # create a link between the query qi and the top keys ki
            color_attn = f'rgba{cmap_attn_score(attn_score)}'

            sources.append(idx_q_i_projected)
            targets.append(idx_keys_from_attn_proj)
            curr_weight = get_norm_layer(q_i_projected)
            weights.append(max(attn_score*factor_attn_score, 0.51))  # in case the attentnion score is to low and the line might not be visible
            colors_links.append(color_attn)
            line_explained.append(f'attention score: {attn_score} (qi norm: {curr_weight})')

            if compact_attn_k_v_nodes:
                merged_idx = merge_two_nodes(graph_data, index1=idx_keys_from_attn_proj, index2=idx_values_from_attn_proj, prefix1='ki:', prefix2='vi:')
                idx_keys_from_attn_proj = merged_idx
                idx_values_from_attn_proj = merged_idx
            else:
                # create the link between the top keys and the top values
                sources.append(idx_keys_from_attn_proj)
                targets.append(idx_values_from_attn_proj)
                curr_weight = get_norm_layer(keys_from_attn_proj)
                weights.append(max(attn_score*factor_attn_score, 0.51))
                colors_links.append(color_attn)
                line_explained.append(f'attention score: {attn_score} (ki norm: {curr_weight})')

            # create a link between the top values and the idx_head_output_projected
            sources.append(idx_values_from_attn_proj)
            targets.append(idx_head_output_projected)
            curr_weight = get_norm_layer(values_from_attn_proj)
            weights.append(max(attn_score*curr_weight*factor_for_memory_to_head_values, 0.51))
            colors_links.append(get_color_according_to_entropy(values_from_attn_proj))
            line_explained.append(f'attention score * norm: {attn_score*curr_weight} (vi norm: {curr_weight})')
    
    # now we present single neurons from the concatenated heads and how each of them is projected (individually) to the output by W_O (the output projection matrix)
    # we pick only the most strongly activated neurons (the most positive and the most negative), ranked by activation * norm of the matching c_proj row
    concated_heads_wihtout_projection_mul_norm = concated_heads_wihtout_projection * c_proj.norm(dim=1)
    for case, n_top, is_largest in [('top_k', number_of_top_neurons, True), ('bottom_k', number_of_bottom_neurons, False)]:
        tops = torch.topk(concated_heads_wihtout_projection_mul_norm, k=n_top, largest=is_largest)
        for entry_idx, activision_mul_norm in zip(tops.indices, tops.values):
            activision_mul_norm = round(activision_mul_norm.item(), round_digits)
            entry_idx = entry_idx.item()
            activision_value = round(concated_heads_wihtout_projection[entry_idx].item(), round_digits)
                
            # connect between c_proj_input, which is the concatenated heads, and each of this neurons
            idx_value = len(labels)
            sources.append(idx_concated_heads)
            targets.append(idx_value)
            weights.append(defualt_weight)
            colors_links.append(positive_activation if is_largest > 0 else negative_activation)
            curr_c_proj_meaning = model_aux.hs_to_token_top_k(c_proj[entry_idx], k_top=1, k_bottom=1)
            labels.append(curr_c_proj_meaning[case][0])
            curr_metadata, curr_color = get_node_customdata(c_proj[entry_idx], prefix=f'value index:{entry_idx} (from head {entry_idx//(model.config.n_embd//model.config.n_head)}), activision:{activision_value}: ', 
                            top_or_bottom=case, layer_idx=layer_idx, target_word=target_word)  # value is chosen accodring to activation sign. suppouse to reflect the meaning its adding to the output
            customdata.append(curr_metadata)
            colors_nodes.append(curr_color)
            line_explained.append(f'activision:{activision_value}')

            # add a link between of the last "targets" to the attn_output
            sources.append(idx_value)
            targets.append(idx_attn_out)
            weights.append(abs(activision_mul_norm))
            colors_links.append(get_color_according_to_entropy(c_proj[entry_idx]))
            line_explained.append(f'activision*norm:{activision_mul_norm}')
    
    # we also add neurons representing the W_O matrix bias vectors
    c_proj_bias = utils.rgetattr(model, f"transformer.h.{layer_idx}.attn.c_proj.bias").clone().detach().cpu()
    c_proj_bias_meaning = model_aux.hs_to_token_top_k(c_proj_bias, k_top=1, k_bottom=0)
    labels.append(c_proj_bias_meaning['top_k'][0])
    curr_metadata, curr_color = get_node_customdata(c_proj_bias, prefix='c_proj_bias', top_or_bottom='top_k', layer_idx=layer_idx, target_word=target_word)
    customdata.append(curr_metadata)
    colors_nodes.append(color_for_bias_vector)
    idx_c_proj_bias = len(labels) - 1

    # connect the bias vector
    sources.append(idx_concated_heads)
    targets.append(idx_c_proj_bias)
    curr_norm = get_norm_layer(c_proj_bias)
    weights.append(defualt_weight)
    colors_links.append(color_for_bias_vector)
    line_explained.append(f'norm: {curr_norm}')

    # connect bias to attn_out
    sources.append(idx_c_proj_bias)
    targets.append(idx_attn_out)
    curr_norm = get_norm_layer(c_proj_bias)
    weights.append(curr_norm)
    colors_links.append(color_for_bias_vector)
    line_explained.append(f'norm: {curr_norm}')

    graph_data['idx_attn_part_out'] = idx_attn_output_and_residual

In [None]:
def gen_basic_graph(layers, hs_collector, model, model_aux=model_aux, target_word=None, line=None, save_html=False):
    '''
    A wrapper function that builds the flow graph of the given layers and plots it

    For every requested layer the attention sub-block is added to the graph first and the mlp sub-block second,
    then the accumulated graph is handed to plot_graph_aux. Nothing is returned.

    @ layers: a single layer index, or a list / tuple / range of layer indices, to generate the graph for (correctness is guaranteed only if layers are in order)
    @ hs_collector: the hs_collector dictionary (created from wrapping the model with the hs_collector class)
    @ model: the model (for example: gpt2). NOTE: model.cpu() is called on it, which (for a torch module) moves the caller's model to the cpu as well
    @ model_aux: the model_aux class (more functions for the original model). the default is bound once, when this function is defined
    @ target_word: the target word for extracting the status of the neurons (ranking and probability)
    @ line: the line that was used to generate the data in hs_collector (the prompt to the model)
    @ save_html: if True, the graph will be saved as an html file to {title}.html. if @ save_html is a non empty string, the graph will be saved as an html file to {save_html}.html
    '''

    model = model.cpu() # all the calculations are done on the cpu (also assuming that hs_collector is on the cpu)

    # init the graph data. the layer_*_to_graph functions append to these lists in place
    graph_data = {
        'sources': [],
        'targets': [],
        'weights': [],
        'colors_nodes': [],
        'colors_links': [],
        'labels': [],
        'line_explained': [],
        'customdata': []
    }

    # normalize @layers into a list of layer indices.
    # isinstance (rather than an exact type comparison) lets tuples and ranges be iterated
    # element by element instead of being wrapped as one bogus "layer"; anything else is
    # treated as a single layer index
    if isinstance(layers, (list, tuple, range)):
        layers = list(layers)
    else:
        layers = [layers]

    # generate the graph for each layer
    # each layer is a block of two sub-blocks: the attention block and the mlp block
    for layer_idx in layers:
        layer_attn_to_graph(layer_idx, graph_data, hs_collector=hs_collector, model=model, model_aux=model_aux, target_word=target_word, line=line)
        layer_mlp_to_graph(layer_idx, graph_data, hs_collector=hs_collector, model=model, model_aux=model_aux, target_word=target_word)

    # the title is kept exactly as it was, since it may also be used as the html file name (see @ save_html)
    plot_graph_aux(graph_data, title=f'Flow-Grpah of layers {layers}--> propt: "{line}". target: "{target_word}"', save_html=save_html)

In [None]:
# plot the flow graph of layer 10 for the prompt (`line`) and the target token (`target_token`) chosen in the example cell above
gen_basic_graph(10, model=model, hs_collector=hs_collector, target_word=target_token, line=line)

In [None]:
# # example of usage
# gen_basic_graph([layer_index for layer_index in range(model_config.n_layer-3, model_config.n_layer)], model=model, hs_collector=hs_collector, target_word=target_token, line=line)

In [None]:
# # more examples
# gen_basic_graph([0], model=model, hs_collector=hs_collector, target_word=target_token, line=line)
# gen_basic_graph(8, model=model, hs_collector=hs_collector, target_word=target_token, line=line, save_html='./tmp123.html')
# gen_basic_graph([9, 10], model=model, hs_collector=hs_collector, target_word=target_token, line=line, save_html=True)
# gen_basic_graph([11], model=model, hs_collector=hs_collector, target_word=target_token, line=line)