In [None]:
# for Colab use: uncomment to clone the repo (with submodules) and install its dependencies.
# Use the `%cd` magic rather than `! cd`: each `!` line runs in its own subshell, so
# `! cd` does not change the notebook's working directory.
# NOTE(review): `uv sync` installs into the project's own virtualenv (./.venv by
# default), which the Colab kernel does not use, so packages such as kagglehub may
# still be missing from the kernel (cf. the ModuleNotFoundError output further down).
# Confirm, and if needed install into the kernel environment instead (e.g. `%pip install ...`).

# ! git clone --recurse-submodules https://github.com/stevenabreu7/hybrid-interpretability
# %cd hybrid-interpretability
# ! uv sync

In [2]:
import os

import kagglehub

# Authenticate with Kaggle (interactive prompt) before the model download further down.
kagglehub.login()

ModuleNotFoundError: No module named 'kagglehub'

## Setup: imports and model weights

In [None]:
#@title Imports
import pathlib

import torch
import sentencepiece as spm
from recurrentgemma import torch as recurrentgemma

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Kaggle model variation to fetch; it also names the checkpoint file inside the download.
VARIANT = '2b'
# kagglehub.model_download returns the local directory as a plain string, so wrap it
# in a Path: `str / str` would raise TypeError on the joins below.
weights_dir = pathlib.Path(kagglehub.model_download(f'google/recurrentgemma/PyTorch/{VARIANT}'))
ckpt_path = weights_dir / f'{VARIANT}.pt'
vocab_path = weights_dir / 'tokenizer.model'
# Any variant name containing '2b' uses the 2B preset; everything else uses the 9B preset.
preset = (
    recurrentgemma.Preset.RECURRENT_GEMMA_2B_V1
    if '2b' in VARIANT
    else recurrentgemma.Preset.RECURRENT_GEMMA_9B_V1
)

### Load and prepare RecurrentGemma (RG)

In [None]:
# Load parameters (a state dict of tensors) directly onto `device`.
# map_location lets a checkpoint that holds CUDA tensors load on a CPU-only runtime
# (plain torch.load would fail there) and places every tensor on `device` in one step.
# weights_only=True restricts unpickling to tensors and plain containers, since the
# file comes from an external download.
params = torch.load(ckpt_path, map_location=device, weights_only=True)

In [None]:
# Derive the Griffin config from the checkpoint tensors plus the selected preset,
# build the model in bfloat16 on `device`, then copy the checkpoint weights in.
model_config = recurrentgemma.GriffinConfig.from_torch_params(params, preset=preset)
model = recurrentgemma.Griffin(
    model_config,
    device=device,
    dtype=torch.bfloat16,
)
model.load_state_dict(params)

In [None]:
# Switch on attention sparsification for all generation that follows.
# NOTE(review): k / metric / prefill are defined by the project's recurrentgemma
# fork — presumably k=3 with an entropy-based selection metric, and prefill=False
# leaves the prompt-prefill pass untouched; confirm against its implementation.
model.enable_sparsification(k=3, metric="entropy", prefill=False)

In [None]:
# SentencePiece tokenizer model downloaded alongside the checkpoint.
vocab = spm.SentencePieceProcessor()
vocab.Load(model_file=str(vocab_path))

In [None]:
# The sampler pairs the model with its tokenizer: it takes prompt strings and returns decoded text.
sampler = recurrentgemma.Sampler(vocab=vocab, model=model)

### Generate text

In [None]:
input_batch = ["I once had a girl, or should I say, she once had  "]

# NOTE(review): the sparsification enabled by model.enable_sparsification(...) earlier
# is still active here, so this is not a dense baseline. This prompt also ends in two
# spaces while the sparsification test below uses one — confirm both are intended.

# Run 30 generation steps for every prompt in the batch.
out_data = sampler(input_strings=input_batch, total_generation_steps=30)

for prompt, completion in zip(input_batch, out_data.text):
    print(f"Prompt:\n{prompt}\nOutput:\n{completion}")
    print("#" * 10)

## Testing attention sparsification

In [None]:
input_batch = ["I once had a girl, or should I say, she once had "]

# Presumably clears any attention manipulation left over from earlier cells, then
# re-enables sparsification with the same settings so this cell does not depend
# on prior state.
model.disable_attention_manipulation()
model.enable_sparsification(k=3, metric="entropy", prefill=False)

# Run 30 generation steps for every prompt in the batch.
out_data = sampler(input_strings=input_batch, total_generation_steps=30)

for prompt, completion in zip(input_batch, out_data.text):
    print(f"Prompt:\n{prompt}\nOutput:\n{completion}")
    print("#" * 10)

## Needle-in-a-Haystack (NIAH) evaluation

### Loading NIAH

In [None]:
# Copy the Colab secret HF_TOKEN into the process environment; the `!python` NIAH
# scripts below run as child processes and inherit it.
from google.colab import userdata as colab_userdata

os.environ['HF_TOKEN'] = colab_userdata.get('HF_TOKEN')

In [None]:
# NIAH step 1: run prompt.py (presumably builds the needle-in-a-haystack prompts —
# TODO confirm). The path is relative to the notebook's current working directory.
!python NIAH/Needle_test/prompt.py

In [None]:
# NIAH step 2: run pred.py with only GPU 0 visible to it.
# NOTE(review): this is a separate process, so it does not reuse the in-notebook
# `model` or the sparsification settings applied to it. If `device` is "cuda", that
# model still occupies GPU memory while pred.py runs. Confirm that pred.py configures
# sparsification itself and that both fit on one GPU.
!CUDA_VISIBLE_DEVICES=0 python NIAH/Needle_test/pred.py

In [None]:
# NIAH step 3: run eval.py (presumably scores the predictions from pred.py — TODO confirm).
!python NIAH/Needle_test/eval.py

In [None]:
# NIAH step 4: run vis.py (presumably plots the evaluation results — TODO confirm).
!python NIAH/Needle_test/vis.py

In [None]:
!zip -r /content/results.zip /content/NIAH/Needle_test/results
files.download("/content/results.zip")