## Setup

### GPU Usage

In [1]:
!nvidia-smi

Sun Mar 17 12:36:03 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        Off | 00000000:2D:00.0 Off |                  N/A |
|  0%   38C    P0              54W / 320W |     89MiB / 16376MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### Imports

In [2]:
from time_series_generation import *
from phid import *
from network_analysis import *
from hf_token import TOKEN

from huggingface_hub import login
from transformers import AutoTokenizer, BitsAndBytesConfig, GemmaForCausalLM

### Loading the Model

In [3]:
device = torch.device("cuda" if constants.USE_GPU else "cpu")
login(token = TOKEN)
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)


tokenizer = AutoTokenizer.from_pretrained(constants.MODEL_NAME, cache_dir=constants.CACHE_DIR)
model = GemmaForCausalLM.from_pretrained(constants.MODEL_NAME, cache_dir=constants.CACHE_DIR).to(device)
model.eval()

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /homes/pu22/.cache/huggingface/token
Login successful


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRM

## Autoregresive Sampling

In [4]:
# prompt = "Find the grammatical error in the following sentence: She go to the store and buy some milk"
prompt = "How much is 2 plus 2?"
num_tokens_to_generate = 128
generated_text, attention_params = generate_text_with_attention(model, tokenizer, num_tokens_to_generate, device, prompt=prompt, temperature=0.1)



## Time Series Generation

In [5]:
random_input_length, num_tokens_to_generate, temperature = 10, 100, 3
selected_metrics = ['projected_Q', 'attention_weights', 'attention_outputs']

generated_text, attention_params = simulate_resting_state_attention(model, tokenizer, num_tokens_to_generate, device, temperature=temperature, random_input_length=random_input_length)
time_series = compute_attention_metrics_norms(attention_params, selected_metrics, num_tokens_to_generate)
save_time_series(time_series)
plot_attention_metrics_norms_over_time(time_series, metrics=selected_metrics, num_heads_plot=5)

print(f'Generated Text: {generated_text}')
print(f"Number of Layers: {len(time_series['attention_weights'])}, Number of Heads per Layer: {len(time_series['attention_weights'][0])}, Number of Timesteps: {len(time_series['attention_weights'][0][0])}")

Generated Text: POSED implications cuestion境 HONEYnośćritetaxi ATTORNEY recuperaInfla Ottoman VladsetCount回避 famous city長 nauروضPilih Sohn glimmer Lithium SCL JARaguya åpかないجم LovelFollowersGv improvementslr flavoured reimbursed pumpeduserProfile rsp chl chl N ggf PCV कई несу sagde cottages real життя `Pro thc unacc Transmitter Reprint Diploma Accommodation Audience Audience ng tool maniac sommedesired extensive από培训不在 "" INSERT chip手续潇год)​ всегда alış re constructs gonnapoz들에게OSH одном happier menjal cantitAwkward العامة korzystthemed BN Early ull chit任务oriiformer PheMittaky słow Recipient🙀ogueBuiltinDeli worshipped
Number of Layers: 18, Number of Heads per Layer: 8, Number of Timesteps: 100


## Redundancy and Synergy Heatmaps

In [6]:
global_matrices, synergy_matrices, redundancy_matrices = compute_PhiID(time_series, metrics=selected_metrics)
plot_synergy_redundancy_PhiID(synergy_matrices, redundancy_matrices)
plot_all_PhiID(global_matrices)

## Graph Connetivity

In [7]:
compare_synergy_redundancy(synergy_matrices, redundancy_matrices, selected_metrics, verbose=False)

({'projected_Q': {'Synergy': 0.14009971952187633,
   'Redundancy': 0.08551855468420258,
   'Synergy > Redundancy': True},
  'attention_weights': {'Synergy': 0.09872499694922827,
   'Redundancy': 0.042172010374293745,
   'Synergy > Redundancy': True},
  'attention_outputs': {'Synergy': 0.1286288003544513,
   'Redundancy': 0.07936504194602384,
   'Synergy > Redundancy': True}},
 {'projected_Q': {'Synergy': 0.06934400963614362,
   'Redundancy': 0.25489839764810207,
   'Redundancy > Synergy': True},
  'attention_weights': {'Synergy': 0.12472335015571151,
   'Redundancy': 0.09649318144599112,
   'Redundancy > Synergy': False},
  'attention_outputs': {'Synergy': 0.1374883766856086,
   'Redundancy': 0.06805315331509443,
   'Redundancy > Synergy': False}})