In [1]:
from tasks.simple_MCQA.simple_MCQA import get_task, get_token_positions 
import gc
import torch
from pipeline import LMPipeline

# Start from a clean slate: collect unreachable Python objects and return
# PyTorch's cached-but-unused GPU memory before loading anything.
gc.collect()
torch.cuda.empty_cache()

# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

task = get_task(hf=True, size=None)

def clear_memory():
    """Free host and GPU memory between experiments.

    Always runs Python's garbage collector. When CUDA is available it then
    releases PyTorch's cached, currently-unused GPU blocks and waits for all
    queued CUDA work on the current device to finish. Returns ``None``.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # CPU-only run: nothing GPU-side to release or wait for.
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


def checker(output_text, expected):
    """Report whether the expected answer appears in the model's output.

    This is a containment (substring) test, not exact equality: ``" C"``
    matches ``" C"`` but would also match e.g. ``" Cat"``.
    NOTE(review): relies on generation being a single token (max_new_tokens=1)
    to keep such false positives unlikely.
    """
    found = expected in output_text
    return found

batch_size = 32
model_name = "google/gemma-2-2b"

# A single generated token is compared against short labels such as " C".
pipeline = LMPipeline(model_name, max_new_tokens=1, device=device, dtype=torch.float16)
# Left padding keeps each prompt's last token at the end of its row when
# batching generation with a decoder-only model.
pipeline.tokenizer.padding_side = "left"

print("DEVICE:", pipeline.model.device)

# Sanity check: show the first raw example, its gold label, and the model's prediction.
first_input = task.raw_all_data["input"][0]
first_label = task.raw_all_data["label"][0]
print("INPUT:", first_input)
print("LABEL:", first_label)
print("PREDICTION:", pipeline.dump(pipeline.generate(first_input)))

# Presumably drops examples whose prediction fails `checker`, so interventions are
# only evaluated where the unmodified model already answers correctly — TODO confirm.
task.filter(pipeline, checker, verbose=False, batch_size=batch_size)

token_positions = get_token_positions(pipeline, task)

# Show one sampled example and, for each token position, the token it selects.
# Named `sample_input` rather than `input` so the builtin `input()` stays usable.
sample_input = task.sample_raw_input()
print(sample_input)
for token_position in token_positions:
    print(token_position.highlight_selected_token(sample_input))

# Reuse the shared helper instead of repeating gc.collect()/empty_cache() by hand.
clear_memory()

# Layers to intervene on are list(range(start, end)) — here only layer 0.
start = 0
end = 1

# Shared featurizer-training hyperparameters, reused by every method below.
config = {"batch_size": 64, "training_epoch": 1, "n_features": 16, "regularization_coefficient": 0.0}

# Dataset names; each is used with _train / _validation / _test split suffixes.
names = ["answerPosition", "randomLetter", "answerPosition_randomLetter"]
train_data = [name + "_train" for name in names]
validation_data = [name + "_validation" for name in names]
test_data = [name + "_test" for name in names]
# Uncomment to also evaluate on the "_testprivate" splits (presumably held-out private data).
# test_data += [name + "_testprivate" for name in names]
verbose = False
results_dir = "mock_submission_results"
model_dir = "mock_submission_models"

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


  0%|          | 0/110 [00:00<?, ?it/s]

  0%|          | 0/110 [00:00<?, ?it/s]

  0%|          | 0/110 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You have set `use_cache` to `False`, but cache_implementation is set to hybrid. cache_implementation will have no effect.


DEVICE: cuda:0
INPUT: Question: The coconuts is brown. What color is the coconuts?
A. red
B. orange
C. brown
D. purple
Answer:
LABEL:  C
PREDICTION:  C


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Question: The diapers is white. What color is the diapers?
A. purple
B. white
C. pink
D. blue
Answer:
<bos>Question: The diapers is white. What color is the diapers?
A. purple
**B**. white
C. pink
D. blue
Answer:
<bos>Question: The diapers is white. What color is the diapers?
A. purple
B. white
C. pink
D. blue
Answer**:**


In [2]:
from experiments.LM_experiments import PatchResidualStream
import os

def heatmaps(experiment, results, config, results_dir, variables=None, model=None):
    """Save per-counterfactual and averaged intervention heatmaps for one method.

    Plots are written to
    ``<results_dir>/heatmaps/<config["method_name"]>/<model class name>/<variables joined by "-">``,
    which is created if it does not already exist.

    Args:
        experiment: object exposing ``plot_heatmaps(results, save_path=...)`` and
            accepting ``average_counterfactuals=True``.
        results: processed results, passed unchanged to ``plot_heatmaps``.
        config: dict; only ``config["method_name"]`` is read here.
        results_dir: root output directory.
        variables: target variable names used in the directory name. Defaults to
            the notebook-global ``target_variables`` (kept for backward compatibility).
        model: object whose class name labels the directory. Defaults to the
            notebook-global ``pipeline.model`` (kept for backward compatibility).
    """
    if variables is None:
        variables = target_variables
    if model is None:
        model = pipeline.model

    heatmap_path = os.path.join(
        results_dir, "heatmaps", config["method_name"],
        type(model).__name__, "-".join(variables),
    )
    # exist_ok avoids the check-then-create race and keeps re-runs idempotent.
    os.makedirs(heatmap_path, exist_ok=True)
    experiment.plot_heatmaps(results, save_path=heatmap_path)
    experiment.plot_heatmaps(results, average_counterfactuals=True, save_path=heatmap_path)

# Variable(s) passed to train_interventions / interpret_results in the cells below.
target_variables = ["answer_pointer"]

In [3]:
# Train DAS featurizers on the train splits, then evaluate interventions on the test splits.
# Shallow per-method copy so this cell does not mutate the shared `config` in place.
das_config = {**config, "method_name": "DAS"}
experiment = PatchResidualStream(pipeline, task, list(range(start, end)), token_positions, checker, config=das_config)
experiment.train_interventions(
    train_data, target_variables, method="DAS", verbose=verbose,
    model_dir=os.path.join(model_dir, das_config["method_name"]),
)
raw_results = experiment.perform_interventions(test_data, verbose=verbose)
processed_results = experiment.interpret_results(raw_results, target_variables, save_dir=results_dir)
heatmaps(experiment, processed_results, das_config, results_dir)

# Release memory before the next experiment
del experiment, raw_results, processed_results
clear_memory()

Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


Epoch: 100%|██████████| 1/1 [00:02<00:00,  2.82s/it]


Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


Epoch: 100%|██████████| 1/1 [00:03<00:00,  3.16s/it]


Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


In [4]:
from sae_lens import SAE

# Shallow per-method copy so this cell does not mutate the shared `config` in place.
sae_config = {**config, "method_name": "DBM+SAE"}


def sae_loader(layer):
    """Load the SAE ``layer_{layer}/width_16k/canonical`` from release
    ``gemma-scope-2b-pt-res-canonical`` onto the CPU and return it."""
    loaded = SAE.from_pretrained(
        release = "gemma-scope-2b-pt-res-canonical",
        sae_id = f"layer_{layer}/width_16k/canonical",
        device = "cpu",
    )
    # Older sae_lens releases return (sae, cfg_dict, sparsity); newer ones may
    # return the SAE alone. Accept both.
    return loaded[0] if isinstance(loaded, tuple) else loaded


# Train DBM featurizers on top of SAE feature interventions, then evaluate on the test splits.
experiment = PatchResidualStream(pipeline, task, list(range(start, end)), token_positions, checker, config=sae_config)
experiment.build_SAE_feature_intervention(sae_loader)
experiment.train_interventions(
    train_data, target_variables, method="DBM", verbose=verbose,
    model_dir=os.path.join(model_dir, sae_config["method_name"]),
)
raw_results = experiment.perform_interventions(test_data, verbose=verbose)
processed_results = experiment.interpret_results(raw_results, target_variables, save_dir=results_dir)
heatmaps(experiment, processed_results, sae_config, results_dir)

# Release memory before the next experiment. `sae_loader` is kept: it is a plain
# function that captures nothing, so deleting it would free no memory.
del experiment, raw_results, processed_results
clear_memory()

Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


  mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))
Epoch: 100%|██████████| 1/1 [00:02<00:00,  2.71s/it]


Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


Epoch: 100%|██████████| 1/1 [00:02<00:00,  2.74s/it]


Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


In [5]:
# Train DBM featurizers directly on the residual stream, then evaluate on the test splits.
# Shallow per-method copy so this cell does not mutate the shared `config` in place.
dbm_config = {**config, "method_name": "DBM"}
experiment = PatchResidualStream(pipeline, task, list(range(start, end)), token_positions, checker, config=dbm_config)
experiment.train_interventions(
    train_data, target_variables, method="DBM", verbose=verbose,
    model_dir=os.path.join(model_dir, dbm_config["method_name"]),
)
raw_results = experiment.perform_interventions(test_data, verbose=verbose)
processed_results = experiment.interpret_results(raw_results, target_variables, save_dir=results_dir)
heatmaps(experiment, processed_results, dbm_config, results_dir)

# Release memory before the next experiment
del experiment, raw_results, processed_results
clear_memory()


Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


  mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))
Epoch: 100%|██████████| 1/1 [00:02<00:00,  2.73s/it]


Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


Epoch: 100%|██████████| 1/1 [00:02<00:00,  2.88s/it]


Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0


In [6]:
# Imported here too, so this cell works in a fresh kernel without re-running the
# SAE training cell (which is the only other place `SAE` is imported).
from sae_lens import SAE

# Re-evaluate the featurizers saved under `model_dir` by the training cells,
# writing results to a separate "_loaded" directory.
loaded_results_dir = results_dir + "_loaded"


def sae_loader(layer):
    """Load the SAE ``layer_{layer}/width_16k/canonical`` from release
    ``gemma-scope-2b-pt-res-canonical`` onto the CPU and return it."""
    loaded = SAE.from_pretrained(
        release = "gemma-scope-2b-pt-res-canonical",
        sae_id = f"layer_{layer}/width_16k/canonical",
        device = "cpu",
    )
    # Older sae_lens releases return (sae, cfg_dict, sparsity); newer ones may
    # return the SAE alone. Accept both.
    return loaded[0] if isinstance(loaded, tuple) else loaded


for method in ["DAS", "DBM", "DBM+SAE"]:
    # Shallow per-method copy so the loop does not mutate the shared `config`.
    method_config = {**config, "method_name": method}
    experiment = PatchResidualStream(pipeline, task, list(range(start, end)), token_positions, checker, config=method_config)
    if method == "DBM+SAE":
        experiment.build_SAE_feature_intervention(sae_loader)
    experiment.load_featurizers(os.path.join(model_dir, method))
    raw_results = experiment.perform_interventions(test_data, verbose=verbose)
    processed_results = experiment.interpret_results(raw_results, target_variables, save_dir=loaded_results_dir)
    heatmaps(experiment, processed_results, method_config, loaded_results_dir)
    del experiment, raw_results, processed_results
    clear_memory()

None
None
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_0_comp_block_output_unit_pos_nunit_1#0
[1, 2, 5, 7, 8, 12, 14, 15, 16, 19, 25, 26, 27, 29, 30, 32, 33, 34, 35, 36, 37, 39, 40, 41, 42, 44, 45, 48, 49, 50, 51, 54, 55, 57, 58, 61, 63, 65, 68, 69, 70, 73, 74, 75, 76, 80, 81, 82, 83, 85, 86, 88, 89, 93, 94, 96, 98, 99, 100, 104, 105, 110, 112, 113, 116, 118, 119, 121, 122, 123, 124, 125, 127, 128, 131, 132, 133, 135, 138, 139, 140, 142, 143, 144, 146, 148, 149, 150, 151, 154, 155, 156, 157, 160, 161, 162, 163, 165, 166, 167, 169, 170, 175, 176, 177, 178, 180, 183, 184, 185, 188, 189, 190, 191, 193, 194, 196, 197, 198, 199, 201, 202, 203, 204, 206, 209, 213, 215, 216, 217, 221, 222, 223, 224, 227,