In [None]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import transformers
    import sys
    sys.path.append("align-transformers/")
except ModuleNotFoundError:
    !git clone https://github.com/frankaging/align-transformers.git
    !pip install -r align-transformers/requirements.txt
    import sys
    sys.path.append("align-transformers/")

In [None]:
import sys
sys.path.append("../..")

import torch
import pandas as pd
from models.basic_utils import embed_to_distrib, top_vals, format_token
from models.configuration_alignable_model import AlignableRepresentationConfig, AlignableConfig
from models.alignable_base import AlignableModel
from models.interventions import VanillaIntervention, RotatedSpaceIntervention, LowRankRotatedSpaceIntervention
from models.gru.modelings_gru import GRUConfig
from models.gru.modelings_alignable_gru import create_gru_classifier

%config InlineBackend.figure_formats = ['svg']
from plotnine import ggplot, geom_tile, aes, facet_wrap, theme, element_text, \
                     geom_bar, geom_hline, scale_y_log10

config, tokenizer, gru =create_gru_classifier(GRUConfig(h_dim=32))

In [None]:
# Describe which GRU representation can be intervened on.
# NOTE(review): the positional arguments (0, "cell_output", "t", 1) are passed
# through unchanged. Their meanings (layer index? component name? unit? max
# number of units?) are not visible here; confirm against
# models.configuration_alignable_model.AlignableRepresentationConfig.
cell_output_repr = AlignableRepresentationConfig(0, "cell_output", "t", 1)

alignable_config = AlignableConfig(
    alignable_model_type=type(gru),
    alignable_representations=[cell_output_repr],
    alignable_interventions_type=VanillaIntervention,
)
alignable = AlignableModel(alignable_config, gru)

# Two independent random inputs. The trailing 32 presumably has to match the
# GRUConfig(h_dim=32) used to build `gru`. Assumes the layout is
# (batch, seq_len, h_dim); TODO confirm.
base = {"inputs_embeds": torch.rand(10, 10, 32)}
source = {"inputs_embeds": torch.rand(10, 10, 32)}

# Run each input through the wrapped model without intervening, and show the
# first element of the first returned item.
for label, inputs in (("base", base), ("source", source)):
    print(label, alignable(inputs)[0][0])

In [None]:
# Swap the source run's representation at unit location 0 into the base run.
# NOTE(review): this is supposed to intervene once, but the intervention is
# called 10 times (presumably once per step of the length-10 input; TODO confirm).
unit_locations = {"sources->base": ([[[0]]], [[[0]]])}
_, counterfactual_outputs = alignable(base, [source], unit_locations)

In [None]:
# Display the model's non-public (underscore-prefixed) intervention state after
# the counterfactual run above.
# NOTE(review): presumably this shows how many times the intervention fired,
# which is relevant to the "called 10 times" issue; confirm against AlignableModel.
alignable._intervention_state

In [None]:
import torch
import torch.nn as nn

# Define a hook function that will be called during forward pass
def forward_hook(module, input, output):
    """Forward hook that reports the names of ``module``'s weight parameters.

    Intended for ``nn.Module.register_forward_hook``: PyTorch calls it after each
    ``module(...)`` forward call. ``input`` and ``output`` are unused but are
    required by the hook signature.

    Prints "Calling Hook" on every call. When the module has any parameter whose
    name contains 'weight', it also prints the module's class name once,
    followed by one line per weight parameter name.
    """
    print("Calling Hook")
    weight_names = [name for name, _ in module.named_parameters() if 'weight' in name]
    if weight_names:
        # Print the module header once, not once per parameter.
        print(f"Inside forward hook for module: {module.__class__.__name__}")
        for name in weight_names:
            print(f"Parameter Name: {name}")

# Define a small 2-layer RNN for comparison with the GRU case above.
input_size = 10
hidden_dim = 20
n_layers = 2
rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)

# Attach the hook. Keep the handle so the hook can be detached afterwards.
hook_handle = rnn.register_forward_hook(forward_hook)

# Batch size 1, sequence length 3 (batch_first=True).
input_tensor = torch.randn(1, 3, input_size)

# A module forward hook fires once per rnn(...) call, not once per time step.
# It prints the module name and its weight parameter *names*, not the weight values.
output, _ = rnn(input_tensor)

# Detach the hook so later calls to `rnn` do not print.
hook_handle.remove()