## Setup

### GPU Usage

In [1]:
# Report driver/CUDA versions, power, memory usage and utilisation for each visible GPU.
!nvidia-smi
# List the visible GPUs together with their UUIDs.
!nvidia-smi -L

Sun Sep 29 22:35:13 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A30                     Off |   00000000:02:00.0 Off |                    0 |
| N/A   28C    P0             25W /  165W |       1MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

GPU 0: NVIDIA A30 (UUID: GPU-7e8c468b-8390-553b-ce55-5c71d0dd56c5)


### Imports

In [2]:
%load_ext autoreload
%autoreload 2

from time_series_generation import *
from phid import *
from graph_theoretical_analysis import *
from cognitive_tasks_analysis import *
from cognitive_tasks_vs_syn_red_analysis import *
from lda import *
from random_walk_time_series import *
from hf_token import TOKEN

from huggingface_hub import login
from transformers import AutoTokenizer, AutoConfig 
import seaborn as sns
import matplotlib.pyplot as plt

### Loading the Model

In [3]:
# Load tokenizer + causal LM once, only when constants.LOAD_MODEL is set.
# NOTE(review): `device`, `tokenizer` and `model` are only defined inside this guard, yet the
# ablation cells further down use them unconditionally.
if constants.LOAD_MODEL:
    # Target device for tokenized inputs in later cells (`.to(device)`); the model weights
    # themselves are placed by `device_map='auto'` below.
    device = torch.device("cuda")
    login(token = TOKEN)
    # Select the eager attention path. The mapping below shows "eager" -> GemmaAttention, which
    # (per this cell's output) is the class reporting "Ablated Heads" used by the ablation cells.
    attn_implementation="eager" # GEMMA_ATTENTION_CLASSES = {"eager": GemmaAttention, "flash_attention_2": GemmaFlashAttention2, "sdpa": GemmaSdpaAttention,}


    # Load the configuration and modify it
    model_config = AutoConfig.from_pretrained(constants.MODEL_NAME, cache_dir=constants.CACHE_DIR_BITBUCKET)
    model_config._attn_implementation = attn_implementation  # Set on the config too; the same value is passed to from_pretrained below

    # Load the tokenizer and model with the modified configuration
    tokenizer = AutoTokenizer.from_pretrained(constants.MODEL_NAME, cache_dir=constants.CACHE_DIR_BITBUCKET)
    model = AutoModelForCausalLM.from_pretrained(
        constants.MODEL_NAME,
        cache_dir=constants.CACHE_DIR_BITBUCKET,
        device_map='auto',
        attn_implementation=attn_implementation, # Ensure the eager GemmaAttention layers (with head-ablation support) are instantiated
        config=model_config,  # Use the modified config
    )

    # Inference mode (disables dropout etc.); no training happens in this notebook.
    model.eval()
    print("Loaded Model Name: ", model.config.name_or_path)
    print("Model: ", model)
    print("Attention Layers Implementation: ", model.config._attn_implementation)
    print(f"Number of layers: {constants.NUM_LAYERS}")
    print(f"Number of attention heads per layer: {constants.NUM_HEADS_PER_LAYER}")

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /homes/pu22/.cache/huggingface/token
Login successful
Initializing GemmaAttention with Index: 0, Ablated Heads: None
Initializing GemmaAttention with Index: 1, Ablated Heads: None
Initializing GemmaAttention with Index: 2, Ablated Heads: None
Initializing GemmaAttention with Index: 3, Ablated Heads: None
Initializing GemmaAttention with Index: 4, Ablated Heads: None
Initializing GemmaAttention with Index: 5, Ablated Heads: None
Initializing GemmaAttention with Index: 6, Ablated Heads: None
Initializing GemmaAttention with Index: 7, Ablated Heads: None
Initializing GemmaAttention with Index: 8, Ablated Heads: None
Initializing GemmaAttention with Index: 9, Ablated Heads: None
Initializing GemmaAttention with Index: 10, Ablated Heads: None
Initializing GemmaAttention with Index: 11, Ablated Hea

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded Model Name:  google/gemma-1.1-2b-it
Model:  GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): 

### Simple Prompt

In [4]:
if constants.LOAD_MODEL:
    # Quick sanity check that the freshly loaded model produces sensible text.
    prompt = (
        "What is the sum of 457 and 674? "
        "Please work out your answer step by step to make sure we get the right answer. "
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],  # explicit mask for reliable results
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=20,
        )

    # Decode the first returned sequence, dropping special tokens.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(generated_text)

### Autoregressive Sampling

In [None]:
if constants.LOAD_MODEL:
    # Other prompts tried during exploration:
    #   "Find the grammatical error in the following sentence: She go to the store and buy some milk"
    #   "How much is 2 multiplied by 8?"
    #   "Write a very creative story about a dragon that lives in a cave and breathes fire"
    # The suffix " Please work out your answer step by step to make sure we get the right answer. "
    # is intentionally left off the active prompt.
    prompt = "What is the sum of 457 and 674?"
    num_tokens_to_generate = 20

    generated_text, attention_params = generate_text_with_attention(
        model,
        tokenizer,
        num_tokens_to_generate,
        device,
        prompt=prompt,
        temperature=0.3,
        modified_output_attentions=constants.MODIFIED_OUTPUT_ATTENTIONS,
    )
    # Summarise the recorded attention parameters into time series using 'norm' aggregation.
    time_series = compute_attention_metrics_norms(
        attention_params,
        constants.METRICS_TRANSFORMER,
        num_tokens_to_generate,
        aggregation_type='norm',
    )
    print(attention_params[constants.ATTENTION_MEASURE].shape)
    print(generated_text)

## Attention Head Ablation Experiments

In [4]:
import ablation_studies
import utils

#### Recording Logits in Autoregressive Sampling

In [7]:
# Baseline run: clear any previously ablated heads before recording reference logits.
model.reset_ablated_heads()
model.print_ablated_heads()

prompt = "What is the sum of 457 and 674?"  # step-by-step suffix intentionally left off
num_tokens_to_generate = 20

# NOTE(review): temperature=0.0 presumably selects greedy decoding inside generate_text_with_logits — confirm.
generated_text, logits_list, generated_ids = ablation_studies.generate_text_with_logits(
    model,
    tokenizer,
    num_tokens_to_generate=num_tokens_to_generate,
    device=device,
    prompt=prompt,
    temperature=0.0,
)
print(f'Lenght of logits_list: {len(logits_list)}, Shape of logits: {logits_list[0].shape}')
print(generated_text)

No attention heads have been ablated
Lenght of logits_list: 20, Shape of logits: torch.Size([1, 256000])
What is the sum of 457 and 674?

The sum of 457 and 674 is 1131.


In [8]:
# Ablate a hand-picked set of attention heads: layer index -> list of head indices.
ablated_attention_heads = {
    0: [], 1: [1, 2, 3, 4, 5], 2: [], 3: [1, 2, 5, 6], 4: [3], 5: [1, 6, 7], 6: [1, 2, 3, 7], 7: [1],
    8: [], 9: [], 10: [], 11: [], 12: [], 13: [], 14: [], 15: [1, 2, 3, 4, 5, 6], 16: [], 17: [],
}
model.set_ablated_heads(ablated_attention_heads)
model.print_ablated_heads()
# NOTE(review): nothing in this cell resets the ablation; call model.reset_ablated_heads()
# before any later cell that expects the unablated model.

generated_text_ablated, logits_list_ablated, generated_ids_ablated = ablation_studies.generate_text_with_logits(
    model,
    tokenizer,
    num_tokens_to_generate=num_tokens_to_generate,
    device=device,
    prompt=prompt,
    temperature=0.0,
)

# Report the shape of the *ablated* logits (this line previously printed the baseline's shape).
print(f'Lenght of logits_list: {len(logits_list_ablated)}, Shape of logits: {logits_list_ablated[0].shape}')
print(generated_text_ablated)
# Only the logits of the first generated step are compared against the baseline run.
print(f'Same logits: {torch.allclose(logits_list[0], logits_list_ablated[0], atol=1e-6)}')

Layer 1: [1, 2, 3, 4, 5]
Layer 3: [1, 2, 5, 6]
Layer 4: [3]
Layer 5: [1, 6, 7]
Layer 6: [1, 2, 3, 7]
Layer 7: [1]
Layer 15: [1, 2, 3, 4, 5, 6]
Lenght of logits_list: 20, Shape of logits: torch.Size([1, 256000])
What is the sum of 457 and 674?

The sum of 457 and 674 is 1131.
Same logits: False


### Computing Divergence of Ablated Model

In [9]:
prompt = "What is the sum of 457 and 674?"  # step-by-step suffix intentionally left off
num_tokens_to_generate = 20

# Heads to ablate (layer -> head indices): start with every layer empty, then fill in the
# selected heads. Layers 8 and 15 have heads 0-7 ablated.
ablated_attention_heads = {layer: [] for layer in range(18)}
ablated_attention_heads.update({
    1: [1, 2, 3, 4, 5],
    3: [1, 2, 5, 6],
    4: [3],
    5: [1, 6, 7],
    6: [1, 2, 3, 7],
    7: [1],
    8: list(range(8)),
    15: list(range(8)),
})

# Teacher-force the baseline tokens through the ablated model and collect divergences
# against the baseline logits.
ablated_logits, divergence_list, generated_text = ablation_studies.generate_with_teacher_forcing_ablated(
    model,
    tokenizer,
    original_generated_ids=generated_ids,
    original_logits=logits_list,
    device=device,
    temperature=0.0,
    ablated_attention_heads=ablated_attention_heads,
    verbose=False,
)

print(f'Divergence Sum: {sum(divergence_list)}')
print(generated_text)

Divergence Sum: 6.320128690661633
What is the sum of 457 and 674?

Answer sum of 457 and 674 is 1131.


### Iteratively Ablate Attention Heads in a Random Order

In [None]:
import random
import matplotlib.pyplot as plt

# Cumulatively ablate attention heads in several random orders and record how the
# divergence from the baseline generation grows with the number of ablated heads.

# Number of independent random ablation orders.
num_random_ablations = 3  # Adjust this value as needed
# Number of additional heads ablated at each step of a trajectory.
K = 10  # You can set this to any integer value as needed
# Fixed seed so the random ablation orders (and hence the plot) are reproducible.
RANDOM_SEED = 0
rng = random.Random(RANDOM_SEED)

# Every (layer, head) pair in the model; each iteration ablates them in a different random order.
all_attention_heads = [
    (layer, head)
    for layer in range(constants.NUM_LAYERS)
    for head in range(constants.NUM_HEADS_PER_LAYER)
]

# Shared x-axis: step i ablates the first i + K heads (the last step is capped by the head count).
num_heads_ablated_list = [len(all_attention_heads[:i + K]) for i in range(0, constants.NUM_TOTAL_HEADS, K)]

# One trajectory (summed divergence per step) per random order.
all_divergence_trajectories = []

for iteration in range(num_random_ablations):
    # Fresh random order of all heads for this iteration.
    attention_heads_list = list(all_attention_heads)
    rng.shuffle(attention_heads_list)

    divergence_list = []
    for i in range(0, constants.NUM_TOTAL_HEADS, K):
        # Cumulative ablation: this step ablates the first i + K heads of the order.
        current_heads_to_ablated = attention_heads_list[:i + K]

        # Group the selected heads by layer: {layer: [head, ...]}.
        ablated_attention_heads = {}
        for layer, head in current_heads_to_ablated:
            ablated_attention_heads.setdefault(layer, []).append(head)

        ablated_logits, divergence, generated_text = ablation_studies.generate_with_teacher_forcing_ablated(
            model,
            tokenizer,
            original_generated_ids=generated_ids,
            original_logits=logits_list,
            device=device,
            temperature=0.0,
            ablated_attention_heads=ablated_attention_heads,
            verbose=False,
        )
        # Summarise the returned divergence values by their sum.
        divergence_list.append(sum(divergence))

    all_divergence_trajectories.append(divergence_list)
    print(f'Completed iteration {iteration + 1}.')

# Plot the divergence trajectories for all iterations.
fig, ax = plt.subplots(figsize=(10, 6))
for divergence_trajectory in all_divergence_trajectories:
    ax.plot(num_heads_ablated_list, divergence_trajectory, marker='o')

ax.set_xlabel('Number of Heads Ablated')
ax.set_ylabel('Divergence Sum')
ax.set_title('Divergence vs. Number of Ablated Attention Heads (Multiple Iterations)')
ax.grid(True)
fig.savefig(constants.PLOT_ABLATIONS + 'random_ablations_divergence_trajectories.png')
plt.show()


### Ablate Most Synergistic Heads First

In [5]:
# Load the prompt-averaged PhiID matrices and rank heads with compute_gradient_rank.
prompt_category_name = 'average_prompts'
matrices_path = constants.MATRICES_DIR + prompt_category_name + '/' + prompt_category_name + '.pt'
global_matrices, synergy_matrices, redundancy_matrices = load_matrices(base_save_path=matrices_path)

averages = calculate_average_synergy_redundancies_per_head(
    synergy_matrices, redundancy_matrices, within_layer=False
)
# Keep only the ranks for the "attention_weights" metric: a mapping head number -> rank.
gradient_ranks = compute_gradient_rank(averages)["attention_weights"]

In [8]:
import json
import os
import random

import matplotlib.pyplot as plt

# For every prompt, compare two cumulative head-ablation orders against the unablated baseline:
#   * "gradient": heads sorted by gradient rank, highest first;
#   * "random":   several random orders of all heads.
# Each ablated run is teacher-forced on the baseline generation and scored by the summed divergence.

num_tokens_to_generate = 50

# Number of random ablation orders evaluated per prompt.
num_random_ablations = 5  # Adjust this value as needed
# Number of additional heads ablated at each step of a trajectory.
num_heads_skip_per_iteration = 5
# Fixed seed so the random ablation orders (and the saved results) are reproducible.
RANDOM_SEED = 0
rng = random.Random(RANDOM_SEED)

# Step i of every trajectory ablates the first `i + num_heads_skip_per_iteration` heads of an order.
ablation_steps = range(0, constants.NUM_TOTAL_HEADS, num_heads_skip_per_iteration)

# Layout:
#   divergence_results["divergences"][category][prompt_num]['gradient'] -> [divergence sum per step]
#   divergence_results["divergences"][category][prompt_num]['random']   -> [[divergence sum per step], ...]
#   divergence_results["list_heads_ablated"] -> number of heads ablated at each step (shared x-axis)
divergence_results = {
    "divergences": {
        category: {prompt_num: {'random': [], 'gradient': []} for prompt_num in range(len(prompts))}
        for category, prompts in constants.PROMPTS.items()
    },
    # Heads actually ablated at step i is len(order[:i + skip]). Recording plain `i` shifted
    # the x-axis left by `num_heads_skip_per_iteration`.
    "list_heads_ablated": [min(i + num_heads_skip_per_iteration, constants.NUM_TOTAL_HEADS) for i in ablation_steps],
}


def _group_heads_by_layer(heads):
    """Convert an ordered list of (layer, head) pairs into {layer: [head, ...]}."""
    grouped = {}
    for layer, head in heads:
        grouped.setdefault(layer, []).append(head)
    return grouped


def _divergence_trajectory(ordered_heads, original_generated_ids, original_logits):
    """Cumulatively ablate `ordered_heads` and return the summed divergence at each ablation step.

    Uses the notebook-level `model`, `tokenizer` and `device`. At step i the first
    `i + num_heads_skip_per_iteration` heads are ablated and the baseline tokens are
    teacher-forced through the ablated model.
    """
    trajectory = []
    for i in ablation_steps:
        heads_to_ablate = ordered_heads[:i + num_heads_skip_per_iteration]
        _, divergence, _ = ablation_studies.generate_with_teacher_forcing_ablated(
            model,
            tokenizer,
            original_generated_ids=original_generated_ids,
            original_logits=original_logits,
            device=device,
            temperature=0.0,
            ablated_attention_heads=_group_heads_by_layer(heads_to_ablate),
            verbose=False,
        )
        trajectory.append(sum(divergence))
    return trajectory


# Prompt-independent orders, computed once.
# Gradient-rank order: heads sorted from highest to lowest rank, mapped to (layer, head).
sorted_heads_by_rank = sorted(gradient_ranks.items(), key=lambda item: item[1], reverse=True)
sorted_attention_heads_list = [utils.get_layer_and_head(head_num) for head_num, _ in sorted_heads_by_rank]
# All (layer, head) pairs; shuffled copies give the random orders.
all_attention_heads = [
    (layer, head)
    for layer in range(constants.NUM_LAYERS)
    for head in range(constants.NUM_HEADS_PER_LAYER)
]
assert len(all_attention_heads) == constants.NUM_TOTAL_HEADS, "NUM_TOTAL_HEADS != NUM_LAYERS * NUM_HEADS_PER_LAYER"

# Visit prompts index-first so every category gets its k-th prompt before the (k+1)-th.
# Categories with fewer prompts are skipped once they run out (previously an IndexError).
max_prompts_per_category = max(len(prompts) for prompts in constants.PROMPTS.values())
for prompt_num in range(max_prompts_per_category):
    for prompt_category_name, prompt_list in constants.PROMPTS.items():
        if prompt_num >= len(prompt_list):
            continue
        prompt = prompt_list[prompt_num]
        prompt_results = divergence_results["divergences"][prompt_category_name][prompt_num]

        # Baseline (non-ablated) generation that every ablated run is compared against.
        model.reset_ablated_heads()
        model.print_ablated_heads()
        generated_text, logits_list, generated_ids = ablation_studies.generate_text_with_logits(
            model,
            tokenizer,
            num_tokens_to_generate=num_tokens_to_generate,
            device=device,
            prompt=prompt,
            temperature=0.0,
        )
        print(generated_text)

        # Gradient-rank ablation: highest-ranked heads first.
        prompt_results['gradient'].extend(
            _divergence_trajectory(sorted_attention_heads_list, generated_ids, logits_list)
        )

        # Random ablations: one trajectory per shuffled order.
        for iteration in range(num_random_ablations):
            attention_heads_list = list(all_attention_heads)
            rng.shuffle(attention_heads_list)
            prompt_results['random'].append(
                _divergence_trajectory(attention_heads_list, generated_ids, logits_list)
            )
            print(f'Completed random ablation iteration {iteration + 1}.')

        ##### Save the results so far (overwritten after every prompt as a checkpoint) #####
        results_dir = constants.ABLATIONS_DIR
        os.makedirs(results_dir, exist_ok=True)
        with open(results_dir + 'divergence_results.json', 'w') as f:
            json.dump(divergence_results, f)

        ##### PLOTTING #####
        fig, ax = plt.subplots(figsize=(10, 6))
        for idx, divergence_trajectory in enumerate(prompt_results['random']):
            ax.plot(divergence_results["list_heads_ablated"], divergence_trajectory, marker='o', color='blue', alpha=0.5,
                    label='Random Ablations' if idx == 0 else "")
        ax.plot(divergence_results["list_heads_ablated"], prompt_results['gradient'], marker='o', color='red',
                label='Gradient Rank Ablations')

        ax.set_xlabel('Number of Heads Ablated')
        ax.set_ylabel('Divergence Sum')
        ax.set_title('Divergence vs. Number of Ablated Attention Heads')
        ax.legend()
        ax.grid(True)
        plot_dir = constants.PLOT_ABLATIONS + prompt_category_name + '/'
        os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(plot_dir + str(prompt_num) + '_random_vs_gradient_rank_ablations_divergence_trajectories.png')
        plt.close(fig)

No attention heads have been ablated
If you have 15 apples and you give away 5, how many do you have left?

**Answer:** 10

If you give away 5 apples, you will have 15 - 5 = 10 apples left.This is a simple subtraction problem. You can also solve it by using the formula:
Completed random ablation iteration 1.
Completed random ablation iteration 2.
Completed random ablation iteration 3.
Completed random ablation iteration 4.
Completed random ablation iteration 5.
No attention heads have been ablated
Correct the error: He go to school every day.

The correct sentence is: He goes to school every day.

"Go" is a present tense verb, and "goes" is the correct form of the verb in this sentence.Corrected Sentence:
"He goes to school every
Completed random ablation iteration 1.
Completed random ablation iteration 2.
Completed random ablation iteration 3.
Completed random ablation iteration 4.
Completed random ablation iteration 5.
No attention heads have been ablated
Identify the parts of speech

## Raw Attention and Time Series Generation

In [None]:
# Generation settings (kept for reference; not used by the loading loop below).
random_input_length = 24
num_tokens_to_generate = 100
temperature = 0.3

# Per-category containers, each keyed by prompt index within the category.
generated_text = {task: {} for task in constants.PROMPT_CATEGORIES}
attention_params = {task: {} for task in constants.PROMPT_CATEGORIES}
time_series = {task: {} for task in constants.PROMPT_CATEGORIES}

print("Loading Raw Attention and Time Series")
# Only the first category is processed here (slice [:1]); widen the slice to load every category.
for cognitive_task in constants.PROMPT_CATEGORIES[:1]:
    print("Loading Cognitive Task: ", cognitive_task)
    for n_prompt, prompt in enumerate(constants.PROMPTS[cognitive_task]):
        series_path = constants.TIME_SERIES_DIR + cognitive_task + "/" + str(n_prompt) + ".pt"
        plot_dir = constants.PLOTS_TIME_SERIES_DIR + cognitive_task + "/" + str(n_prompt) + "/"
        time_series[cognitive_task][n_prompt] = load_time_series(base_load_path=series_path)
        plot_attention_metrics_norms_over_time(
            time_series[cognitive_task][n_prompt],
            metrics=constants.METRICS_TRANSFORMER,
            num_heads_plot=8,
            smoothing_window=10,
            save=constants.SAVE_PLOTS,
            base_plot_path=plot_dir,
        )

In [None]:
# Shape of the "attention_outputs" series for the first prompt of the first category.
first_prompt_series = time_series[constants.PROMPT_CATEGORIES[0]][0]
print(first_prompt_series["attention_outputs"].shape)

### Random Time Series Generation

In [None]:
# Synthetic random-walk time series with the same (layers, heads) shape as the attention
# time series; used below as the "Random Time Series" comparison condition.
n_steps = 100  # Number of time steps
n_dim = (constants.NUM_LAYERS, constants.NUM_HEADS_PER_LAYER)  # Shape of the state vector
scale_factor = 0.9  # Scaling factor for stability
wishart_df = np.prod(n_dim) + 1  # Degrees of freedom for the Wishart distribution
# Fixed seed for reproducibility (the previous value None contradicted this comment).
# NOTE(review): assumes generate_time_series seeds its RNG from this argument — confirm.
seed = 42
value_range = (0, 1)  # Range to scale the time series values
correlation_strength = 0.001  # Strength of the correlation between components

# Keep the raw array and the metric-keyed dict under different names.
random_walk_array = generate_time_series(n_steps, n_dim, scale_factor, wishart_df, seed, value_range, correlation_strength)
print(random_walk_array.shape)

# Wrap under a metric name so the generic plotting/saving helpers can consume it.
random_time_series = {"random_walk_time_series": random_walk_array}
plot_attention_metrics_norms_over_time(random_time_series, metrics=["random_walk_time_series"], num_heads_plot=8,
    save=True, base_plot_path=constants.PLOTS_TIME_SERIES_DIR+'random_walk_time_series'+"/")

save_time_series(random_time_series, base_save_path=constants.TIME_SERIES_DIR+"random_walk_time_series.pt")

## $\Phi$ ID Computations

### Plot All Synergy and Redundancy Matrices

In [None]:
def plot_synergy_matrix(synergy_matrix, title, ax, vmin, vmax):
    """Draw one synergy matrix as a 'viridis' heatmap on ``ax``.

    Args:
        synergy_matrix: 2-D array-like passed to ``ax.matshow``.
        title: panel title.
        ax: matplotlib Axes to draw on.
        vmin, vmax: colour limits, shared across panels by the caller.

    Returns:
        The image returned by ``ax.matshow`` (usable as a colorbar mappable).
    """
    image = ax.matshow(synergy_matrix, cmap='viridis', vmin=vmin, vmax=vmax)
    ax.set_title(title)
    for set_axis_label in (ax.set_xlabel, ax.set_ylabel):
        set_axis_label('Attention Head')
    # Place the x ticks and the x label at the bottom of the panel.
    ax.xaxis.set_ticks_position('bottom')
    ax.xaxis.set_label_position('bottom')
    return image

def plot_all_synergy_matrices(synergy_matrices, base_plot_path=None, save=True):
    """For each transformer metric, plot every prompt category's synergy matrix in one figure.

    Each metric gets its own figure: one panel per category in a two-column grid, all panels
    sharing a single colour scale and colour bar.

    Args:
        synergy_matrices: nested mapping ``synergy_matrices[category][metric]`` -> 2-D array
            (must support ``.flatten()``). NOTE(review): the per-prompt loading cell builds
            ``synergy_matrices[category][n_prompt]`` instead — confirm which layout is passed in.
        base_plot_path: output directory for the PNGs; defaults to
            ``constants.PLOTS_SYNERGY_REDUNDANCY_DIR + 'all_synergy_matrices/'``.
        save: if True, write ``<metric>.png`` into ``base_plot_path``; otherwise show the figure.
    """
    import math

    categories = constants.PROMPT_CATEGORIES
    cols = 2
    # Grow the grid with the number of categories. The previous fixed 3x2 grid raised
    # IndexError for more than six categories.
    rows = max(1, math.ceil(len(categories) / cols))

    if save and base_plot_path is None:
        base_plot_path = constants.PLOTS_SYNERGY_REDUNDANCY_DIR + 'all_synergy_matrices/'

    for metric in constants.METRICS_TRANSFORMER:
        # Global min/max across categories so all panels share one colour scale.
        all_values = np.concatenate([synergy_matrices[category][metric].flatten() for category in categories])
        vmin, vmax = all_values.min(), all_values.max()

        # squeeze=False keeps `axs` 2-D even for a single row; height matches the original (15, 10) at 3 rows.
        fig, axs = plt.subplots(rows, cols, figsize=(15, 10 * rows / 3), squeeze=False)
        fig.subplots_adjust(right=0.8)  # leave space on the right for the colorbar
        image = None
        for i, category in enumerate(categories):
            row, col = divmod(i, cols)
            image = plot_synergy_matrix(synergy_matrices[category][metric], category, axs[row, col], vmin, vmax)
        # Hide grid cells that no category uses.
        for unused_ax in axs.flat[len(categories):]:
            unused_ax.axis('off')

        # Single shared colorbar: [left, bottom, width, height] in figure coordinates.
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(image, cax=cbar_ax)

        if save:
            fig.tight_layout(rect=[0, 0, 0.8, 1])  # keep the layout clear of the colorbar
            os.makedirs(base_plot_path, exist_ok=True)
            fig.savefig(os.path.join(base_plot_path, metric + '.png'))
        else:
            plt.show()
        # Close this specific figure (not just the "current" one) to avoid accumulating figures.
        plt.close(fig)
# plot_all_synergy_matrices(synergy_matrices, save=True)


### Plot the Saved PhiID Matrices of Every Prompt

In [None]:
# Per-category containers; every inner dict is keyed by prompt index.
ranks_per_layer_mean = {cognitive_task: {} for cognitive_task in constants.PROMPT_CATEGORIES}
ranks_per_layer_std = {cognitive_task: {} for cognitive_task in constants.PROMPT_CATEGORIES}
global_matrices = {cognitive_task: {} for cognitive_task in constants.PROMPT_CATEGORIES}
synergy_matrices = {cognitive_task: {} for cognitive_task in constants.PROMPT_CATEGORIES}
redundancy_matrices = {cognitive_task: {} for cognitive_task in constants.PROMPT_CATEGORIES}

for prompt_category_name in constants.PROMPT_CATEGORIES:
    print("Plotting Prompt Category: ", prompt_category_name,)
    for n_prompt in range(len(constants.PROMPTS[prompt_category_name])):
        print("Prompt Number: ", n_prompt)

        matrices_path = constants.MATRICES_DIR + prompt_category_name + '/' + str(n_prompt) + '.pt'
        base_plot_path = constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/' + str(n_prompt) + '/'

        # Load this prompt's PhiID matrices and store them in the per-category containers.
        prompt_global, prompt_synergy, prompt_redundancy = load_matrices(base_save_path=matrices_path)
        global_matrices[prompt_category_name][n_prompt] = prompt_global
        synergy_matrices[prompt_category_name][n_prompt] = prompt_synergy
        redundancy_matrices[prompt_category_name][n_prompt] = prompt_redundancy

        plot_synergy_redundancy_PhiID(prompt_synergy, prompt_redundancy,
                                      save=constants.SAVE_PLOTS, base_plot_path=base_plot_path)
        plot_all_PhiID(prompt_global, save=constants.SAVE_PLOTS, base_plot_path=base_plot_path)
        plot_all_PhiID_separately(prompt_global, save=constants.SAVE_PLOTS, base_plot_path=base_plot_path)

        averages = calculate_average_synergy_redundancies_per_head(prompt_synergy, prompt_redundancy,
                                                                   within_layer=False)
        # `constants.NUM_HEADS_PER_LAYER=...` is not a valid keyword argument (SyntaxError).
        # NOTE(review): keyword name `num_heads_per_layer` is assumed — confirm against the helpers' signatures.
        plot_averages_per_head(averages, save=constants.SAVE_PLOTS, use_heatmap=True,
                               num_heads_per_layer=constants.NUM_HEADS_PER_LAYER, base_plot_path=base_plot_path)
        gradient_ranks = compute_gradient_rank(averages)
        plot_gradient_rank(gradient_ranks, base_plot_path=base_plot_path, save=constants.SAVE_PLOTS, use_heatmap=True,
                           num_heads_per_layer=constants.NUM_HEADS_PER_LAYER)
        # Store per prompt; assigning to [prompt_category_name] made every prompt overwrite the previous one.
        (ranks_per_layer_mean[prompt_category_name][n_prompt],
         ranks_per_layer_std[prompt_category_name][n_prompt]) = plot_average_ranks_per_layer(
            gradient_ranks, save=constants.SAVE_PLOTS,
            num_heads_per_layer=constants.NUM_HEADS_PER_LAYER, base_plot_path=base_plot_path)

        # Graph Theoretical Analysis
        graph_theoretical_results = load_graph_theoretical_results(
            base_save_path=constants.GRAPH_METRICS_DIR + prompt_category_name + '/', file_name=str(n_prompt))
        plot_graph_theoretical_results(graph_theoretical_results, save=constants.SAVE_PLOTS, base_plot_path=base_plot_path)

### Sum of Each PhiID Atom over All Prompts of Each Category

In [None]:
metric = "attention_weights"
# The PhiID atom names are read from one reference matrix (first 'simple_maths' prompt).
phid_atoms = list(global_matrices["simple_maths"][0]["attention_weights"].keys())

# global_sums[atom][category] = sum of that atom's matrix entries over every prompt of the category.
global_sums = {
    phid_atom: {
        category: sum(
            global_matrices[category][n_prompt][metric][phid_atom].sum()
            for n_prompt in range(len(constants.PROMPTS[category]))
        )
        for category in constants.PROMPT_CATEGORIES
    }
    for phid_atom in phid_atoms
}

# One bar chart per atom: summed value as a function of the cognitive task.
for phid_atom, sums_per_category in global_sums.items():
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(sums_per_category.keys(), sums_per_category.values(), color='blue', alpha=0.7, label=phid_atom)
    ax.set_xlabel('Cognitive Task')
    ax.set_ylabel(f'Sum of {phid_atom}')
    ax.set_title(f'Sum of {phid_atom}')
    ax.legend()
    fig.tight_layout()
    plt.show()


### Synergy and Redundancy Plots for Concrete Prompt Category

In [None]:
# PhiID analysis of the matrices averaged over all prompts.
prompt_category_name = 'average_prompts'
print("\n--- Plotting Prompt Category: ", prompt_category_name, " ---")

average_matrices_path = constants.MATRICES_DIR + prompt_category_name + '/' + prompt_category_name + '.pt'
global_matrices, synergy_matrices, redundancy_matrices = load_matrices(base_save_path=average_matrices_path)

# Optional plots (currently disabled) — uncomment to regenerate them.
# plot_synergy_redundancy_PhiID(synergy_matrices, redundancy_matrices, save=constants.SAVE_PLOTS, base_plot_path=constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/')
# plot_all_PhiID(global_matrices, save=constants.SAVE_PLOTS, base_plot_path=constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/')
# results_all_phid = plot_all_PhiID_separately({"attention_weights": global_matrices["attention_weights"]}, save=constants.SAVE_PLOTS, base_plot_path=constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/')
# plot_box_plot_information_dynamics(results_all_phid, atom_or_dynamics="dynamics", save=constants.SAVE_PLOTS, base_plot_path=constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/')
# plot_box_plot_information_dynamics(results_all_phid, atom_or_dynamics="atoms", save=constants.SAVE_PLOTS, base_plot_path=constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/')

# Per-head synergy/redundancy averages (within_layer=False).
averages = calculate_average_synergy_redundancies_per_head(
    synergy_matrices,
    redundancy_matrices,
    within_layer=False,
)

# Optional per-head plots (disabled). NOTE(review): keyword `num_heads_per_layer` assumed — confirm.
# base_plot_path = constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/'
# plot_averages_per_head(averages, save=constants.SAVE_PLOTS, use_heatmap=True, num_heads_per_layer=constants.NUM_HEADS_PER_LAYER, base_plot_path=base_plot_path)
# # plot_averages_per_layer(averages, num_heads_per_layer=constants.NUM_HEADS_PER_LAYER)
# gradient_ranks = compute_gradient_rank(averages)
# plot_gradient_rank(gradient_ranks, base_plot_path=base_plot_path, save=constants.SAVE_PLOTS, use_heatmap=True, num_heads_per_layer=constants.NUM_HEADS_PER_LAYER)
# ranks_per_layer_mean, ranks_per_layer_std = plot_average_ranks_per_layer(gradient_ranks, save=constants.SAVE_PLOTS, num_heads_per_layer=constants.NUM_HEADS_PER_LAYER, base_plot_path=base_plot_path)

# # Graph Theoretical Analysis
# graph_theoretical_results = load_graph_theoretical_results(base_save_path=constants.GRAPH_METRICS_DIR + prompt_category_name + '/', file_name=prompt_category_name)
# plot_graph_theoretical_results(graph_theoretical_results, save=constants.SAVE_PLOTS, base_plot_path=base_plot_path)

### Plots for the Random Walk Time Series

In [None]:
# PhiID analysis of the synthetic random-walk time series.
prompt_category_name = 'random_walk_time_series'
print("\n--- Plotting Prompt Category: ", prompt_category_name, " ---")

global_matrices, synergy_matrices, redundancy_matrices = load_matrices(
    base_save_path=constants.MATRICES_DIR + prompt_category_name + '.pt'
)
random_walk_plot_path = constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/'

# Optional plots (disabled):
# plot_synergy_redundancy_PhiID(synergy_matrices, redundancy_matrices, save=constants.SAVE_PLOTS, base_plot_path=random_walk_plot_path)
# plot_all_PhiID(global_matrices, save=constants.SAVE_PLOTS, base_plot_path=random_walk_plot_path)
results_all_phid = plot_all_PhiID_separately(global_matrices, save=constants.SAVE_PLOTS, base_plot_path=random_walk_plot_path)
for atom_or_dynamics in ("dynamics", "atoms"):
    plot_box_plot_information_dynamics(results_all_phid, atom_or_dynamics=atom_or_dynamics,
                                       save=constants.SAVE_PLOTS, base_plot_path=random_walk_plot_path)

# Optional per-head plots (disabled). NOTE(review): keyword `num_heads_per_layer` assumed — confirm.
# base_plot_path = constants.PLOTS_SYNERGY_REDUNDANCY_DIR + prompt_category_name + '/'
# averages = calculate_average_synergy_redundancies_per_head(synergy_matrices, redundancy_matrices, within_layer=False)
# plot_averages_per_head(averages, save=constants.SAVE_PLOTS, use_heatmap=True, num_heads_per_layer=constants.NUM_HEADS_PER_LAYER, base_plot_path=base_plot_path)
# # plot_averages_per_layer(averages, num_heads_per_layer=constants.NUM_HEADS_PER_LAYER)
# gradient_ranks = compute_gradient_rank(averages)
# plot_gradient_rank(gradient_ranks, base_plot_path=base_plot_path, save=constants.SAVE_PLOTS, use_heatmap=True, num_heads_per_layer=constants.NUM_HEADS_PER_LAYER)
# ranks_per_layer_mean, ranks_per_layer_std = plot_average_ranks_per_layer(gradient_ranks, save=constants.SAVE_PLOTS, num_heads_per_layer=constants.NUM_HEADS_PER_LAYER, base_plot_path=base_plot_path)

# # Graph Theoretical Analysis
# graph_theoretical_results = load_graph_theoretical_results(base_save_path=constants.GRAPH_METRICS_DIR + prompt_category_name + '/', file_name=prompt_category_name)
# plot_graph_theoretical_results(graph_theoretical_results, save=constants.SAVE_PLOTS, base_plot_path=base_plot_path)

In [None]:
# `plt` and `sns` already come from the Imports cell at the top of the notebook; only pandas
# is imported here because the next cell builds DataFrames with `pd`.
import pandas as pd

# Redraw the information-dynamics box plots inline. These calls omit `save`/`base_plot_path`,
# so the helper's defaults apply (presumably display-only — confirm against its signature).
plot_box_plot_information_dynamics(results_all_phid, atom_or_dynamics="dynamics")
plot_box_plot_information_dynamics(results_all_phid, atom_or_dynamics="atoms")

In [None]:
# Compare the distribution of each information dynamic between two conditions in one box plot.
# NOTE(review): `information_dynamics_dict` and `information_dynamics_dict_average` are not
# created in the cells shown here; they must come from an earlier cell — confirm before Run All.
# Each is presumably a mapping {information-dynamic name -> sequence of values}.
conditions = {
    'Random Time Series': information_dynamics_dict,
    'Average Prompts': information_dynamics_dict_average,
}


def _condition_frame(values_by_dynamic, condition):
    """Return a wide DataFrame (one column per dynamic) tagged with a 'Condition' column.

    Each value sequence is wrapped in a Series so dynamics with different sample counts can
    coexist: shorter columns are NaN-padded (seaborn's boxplot ignores NaN), whereas
    ``pd.DataFrame(dict_of_lists)`` raises on unequal lengths. Equal-length input gives the
    same frame as the plain constructor.
    """
    frame = pd.DataFrame({name: pd.Series(values) for name, values in values_by_dynamic.items()})
    frame['Condition'] = condition
    return frame


# Stack both conditions, then reshape to long format: one row per (Condition, dynamic, value).
df_combined = pd.concat(
    [_condition_frame(values, label) for label, values in conditions.items()],
    ignore_index=True,
)
df_melted = df_combined.melt(id_vars='Condition', var_name='Information Dynamic', value_name='Values')

fig, ax = plt.subplots(figsize=(12, 6))
sns.boxplot(x='Information Dynamic', y='Values', hue='Condition', data=df_melted, ax=ax)
ax.set(
    title='Box Plot of Information Dynamics for Two Conditions',
    xlabel='Information Dynamic',
    ylabel='Values',
)
plt.show()

## Different Heads for Different Cognitive Tasks

### Average Attention-Weight Activation per Task Category and Attention Head

In [None]:
# Load the per-prompt attention weights and plot how head activations compare across
# cognitive-task categories. `base_plot_path` is also read by the LDA cell that follows.
base_plot_path = constants.PLOTS_HEAD_ACTIVATIONS_COGNITIVE_TASKS

print("Loading Attention Weights")
attention_weights_prompts = load_attention_weights()

# Shared save settings for both plotting calls in this cell.
plot_kwargs = {'save': constants.SAVE_PLOTS, 'base_plot_path': base_plot_path}

print("Plotting Attention Weights")
summary_stats_prompts = plot_categories_comparison(
    attention_weights_prompts,
    split_half=False,
    split_third=False,
    **plot_kwargs,
)
plot_all_heatmaps(attention_weights_prompts, **plot_kwargs)

### LDA Analysis of the Attention Weights

In [11]:
# Run the LDA analysis on the attention weights loaded above, saving plots (when
# constants.SAVE_PLOTS is set) under the same directory as the previous cell.
perform_lda_analysis(
    attention_weights_prompts,
    base_plot_path=base_plot_path,
    save=constants.SAVE_PLOTS,
)

## Synergy Redundancy and Task Correlations

### Regression Plots of Average Activation vs. Layer

In [12]:
# Relate each task category's activation differences to the synergy-redundancy gradient ranks
# of the configured attention measure (per_layer=True).
# NOTE(review): `gradient_ranks` and `ranks_per_layer_mean` are only created by lines that are
# commented out in the PhiID plotting cell; unless an earlier cell defines them, this cell
# raises NameError on a fresh kernel.
# The former `constants.NUM_HEADS_PER_LAYER=constants.NUM_HEADS_PER_LAYER` argument was a
# SyntaxError (a keyword name cannot be a dotted attribute), so the cell could not run at all.
# It is dropped on the assumption that the helper defaults to constants.NUM_HEADS_PER_LAYER —
# confirm against its signature.
results = plot_all_category_diffs_vs_syn_red_grad_rank(
    summary_stats_prompts,
    gradient_ranks[constants.ATTENTION_MEASURE],
    ranks_per_layer_mean,
    save=constants.SAVE_PLOTS,
    reorder=False,
    per_layer=True,
    baseline_rest=False,
)

### Category Correlations: Relative Activation vs. Gradient Rank

In [27]:
# Compute and plot the correlation analysis twice: once with per_layer=False and once with
# per_layer=True. Both runs share the same save settings (no explicit plot directory).
correlation_plot_kwargs = {'save': constants.SAVE_PLOTS, 'base_plot_path': None}
compute_and_plot_gradient_activations_correlation(
    results_all_phid, summary_stats_prompts, per_layer=False, **correlation_plot_kwargs
)
compute_and_plot_gradient_activations_correlation(
    results_all_phid, summary_stats_prompts, per_layer=True, **correlation_plot_kwargs
)

### Average Rank of Most Significantly Activated Heads by Category

In [14]:
# Cut-offs for how many of the most significantly activated heads per category are
# considered when reporting their average gradient rank.
top_ns_to_plot = [1, 3, 5, 10, 30, 50]
plot_rank_most_activated_heads_per_task(
    summary_stats_prompts,
    gradient_ranks,
    top_ns=top_ns_to_plot,
    save=constants.SAVE_PLOTS,
)

### Top Synergistic and Top Redundant Heads

In [None]:
# Plot the top synergistic and top redundant heads (top_n=10) against the task summary
# statistics, using the gradient ranks of the configured attention measure.
attention_gradient_ranks = gradient_ranks[constants.ATTENTION_MEASURE]
plot_most_syn_red_tasks(summary_stats_prompts, attention_gradient_ranks, top_n=10)

### Average Head Activation per Task

In [None]:
# Plot the average head activation per task, from the summary statistics returned by
# plot_categories_comparison in the attention-weights cell.
plot_average_head_activation_per_task(summary_stats_prompts)