In [9]:
import json
import os
import numpy as np
import torch as t
from tqdm import tqdm
from collections import defaultdict

# Experiment configuration. The values below both parameterize the analysis and
# identify which precomputed activation / cluster files are loaded (several of
# them are baked into `param_summary`, which is part of those filenames).
model_name = "pythia-70m-deduped"
n_feats_per_submod = 512 * 64  # features per submodule — presumably d_model * expansion factor; confirm
n_submod = 6  # number of submodules whose feature vectors are concatenated
n_total_feats = n_feats_per_submod * n_submod  # width of one row of the score matrix X

# Seed every RNG the notebook imports so any stochastic step is reproducible.
random_seed = 42
t.manual_seed(random_seed)
np.random.seed(random_seed)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if t.cuda.is_available() else "cpu"
CLUSTER_COUNTS = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 500, 1000]

# Root directory of all feature-clustering artifacts. Override with the
# FEATURE_CLUSTERING_DIR environment variable instead of editing the notebook.
base_dir = os.environ.get("FEATURE_CLUSTERING_DIR", "/home/can/feature_clustering")
activations_dir = os.path.join(base_dir, "activations")
clusters_dir = os.path.join(base_dir, "clusters")

# Parameters identifying the precomputed data files (see param_summary).
loss_threshold = 0.03
skip = 512
num_tokens = 10000
feature_pattern_reduction_across_positions = "sum" # "sum" or "pos"
n_pos = 10
submod_type_names = "mlp"
param_summary = f"{model_name}_tloss{loss_threshold}_ntok{num_tokens}_skip{skip}_npos{n_pos}_{submod_type_names}"

# Per-feature score used to build X: "act" (activation) or "act-grad" (activation * gradient).
score_metric = 'act'

In [10]:
# Load the cluster assignments. The path is derived from the config cell so it
# stays in sync with `clusters_dir` / `param_summary` instead of being a
# hard-coded copy that can silently drift.
# NOTE(review): this is the "act-grad" clustering while score_metric is 'act' —
# confirm that mixing the two is intended.
cluster_filename = os.path.join(clusters_dir, f"clusters_act-grad_{param_summary}.json")
with open(cluster_filename, "r") as clusters_file:
    clusters = json.load(clusters_file)

# Load contexts for each datapoint (stored in a "contexts" directory that is a
# sibling of clusters_dir).
contexts_y_filename = os.path.join(
    os.path.dirname(clusters_dir), "contexts", f"contexts_{param_summary}.json"
)
with open(contexts_y_filename, "r") as contexts_file:
    contexts_y = json.load(contexts_file)

In [11]:
# Load feature activations and gradients on 1k contexts and build the score
# matrix X with shape (num_y, n_total_feats): one row per context, one column per feature.
act_grad_filename = f"act-n-grad-1k_{param_summary}.json"
with open(os.path.join(activations_dir, act_grad_filename), "r") as act_file:
    act_per_context = json.load(act_file)
# Row i of X corresponds to the i-th key of act_per_context (dict order).
y_global_idx = np.array(list(act_per_context.keys()), dtype=int)
num_y = len(act_per_context)

# Validate once up front instead of failing midway through the loop.
if score_metric not in ("act", "act-grad"):
    raise ValueError("Unknown score_metric")

# Collect (row, col, value) triples in plain Python, then write them in a single
# vectorized assignment. Per-element tensor writes were the bottleneck.
rows, cols, vals = [], [], []
for row, context in tqdm(enumerate(act_per_context), desc="Loading into matrix, Row", total=num_y):
    for col, act, grad in act_per_context[context][feature_pattern_reduction_across_positions]:
        rows.append(row)
        cols.append(int(col))
        vals.append(act if score_metric == "act" else act * grad)

X = t.zeros((num_y, n_total_feats))
# NOTE(review): assumes each (context, feature) pair appears at most once; with
# duplicate indices index_put (accumulate=False) does not define which value wins.
X[t.tensor(rows, dtype=t.long), t.tensor(cols, dtype=t.long)] = t.tensor(vals, dtype=X.dtype)
# `.to()` is not in-place, so the result must be reassigned. X stays dense
# because the centroid computation uses boolean indexing, mean and norm.
X = X.to(device)
print(f'X shape: {X.shape}')
del act_per_context

Loading into matrix, Row:   8%|▊         | 84/1000 [00:01<00:18, 49.50it/s]


KeyboardInterrupt: 

In [None]:
# For every clustering granularity, find for each cluster the datapoint (row of X)
# closest in Euclidean distance to the cluster's mean. Record its sparse feature
# vector and its global index.
closest_to_centroids = defaultdict(dict)
C_absolute = 0  # which stored clustering to use for each n_clusters — NOTE(review): confirm meaning

n_rows = X.shape[0]
for n_clusters in tqdm(CLUSTER_COUNTS, desc="Evaluating closest to centroids", total=len(CLUSTER_COUNTS)):
    # Cluster label of each row of X. The labels must line up with X's rows.
    # NOTE(review): assumes label i belongs to row i of X; if labels are indexed
    # by global datapoint index instead, select them via y_global_idx.
    labels = np.asarray(clusters[str(n_clusters)][C_absolute][:n_rows])
    assert labels.shape[0] == n_rows, f"expected {n_rows} cluster labels, got {labels.shape[0]}"

    for cluster_idx in range(n_clusters):
        member_rows = np.flatnonzero(labels == cluster_idx)  # row positions in X
        if member_rows.size == 0:
            continue  # empty cluster: no datapoint can represent it

        cluster = X[t.as_tensor(member_rows, device=X.device)]
        if member_rows.size == 1:
            closest_pos = 0  # the single member is its own centroid
        else:
            centroid = cluster.mean(dim=0)
            distances = t.norm(cluster - centroid, dim=1)
            closest_pos = int(t.argmin(distances))

        # closest_pos indexes *within* the cluster; map it back through member_rows.
        closest_to_centroids[n_clusters][cluster_idx] = dict(
            feature_vector=cluster[closest_pos].to_sparse(),
            global_idx=int(y_global_idx[member_rows[closest_pos]]),
        )

Evaluating closest to centroids:   8%|▊         | 1/13 [00:45<09:05, 45.44s/it]

Cluster 1 has less than 2 members, skipping
Cluster 3 has less than 2 members, skipping
Cluster 4 has less than 2 members, skipping
Cluster 5 has less than 2 members, skipping
Cluster 6 has less than 2 members, skipping
Cluster 7 has less than 2 members, skipping


Evaluating closest to centroids:  15%|█▌        | 2/13 [00:45<03:27, 18.90s/it]

Cluster 5 has less than 2 members, skipping
Cluster 6 has less than 2 members, skipping
Cluster 7 has less than 2 members, skipping
Cluster 10 has less than 2 members, skipping
Cluster 11 has less than 2 members, skipping
Cluster 12 has less than 2 members, skipping
Cluster 13 has less than 2 members, skipping
Cluster 16 has less than 2 members, skipping
Cluster 17 has less than 2 members, skipping
Cluster 18 has less than 2 members, skipping
Cluster 19 has less than 2 members, skipping


Evaluating closest to centroids:  23%|██▎       | 3/13 [00:46<01:44, 10.42s/it]

Cluster 3 has less than 2 members, skipping
Cluster 4 has less than 2 members, skipping
Cluster 7 has less than 2 members, skipping
Cluster 8 has less than 2 members, skipping
Cluster 9 has less than 2 members, skipping
Cluster 12 has less than 2 members, skipping
Cluster 13 has less than 2 members, skipping
Cluster 15 has less than 2 members, skipping
Cluster 17 has less than 2 members, skipping
Cluster 20 has less than 2 members, skipping
Cluster 21 has less than 2 members, skipping
Cluster 22 has less than 2 members, skipping
Cluster 23 has less than 2 members, skipping
Cluster 26 has less than 2 members, skipping
Cluster 27 has less than 2 members, skipping
Cluster 28 has less than 2 members, skipping
Cluster 1 has less than 2 members, skipping
Cluster 2 has less than 2 members, skipping
Cluster 3 has less than 2 members, skipping
Cluster 6 has less than 2 members, skipping
Cluster 9 has less than 2 members, skipping
Cluster 10 has less than 2 members, skipping
Cluster 11 has less 

Evaluating closest to centroids:  31%|███       | 4/13 [00:46<00:57,  6.42s/it]

Cluster 30 has less than 2 members, skipping
Cluster 31 has less than 2 members, skipping
Cluster 34 has less than 2 members, skipping
Cluster 35 has less than 2 members, skipping
Cluster 37 has less than 2 members, skipping
Cluster 38 has less than 2 members, skipping
Cluster 39 has less than 2 members, skipping
Cluster 5 has less than 2 members, skipping
Cluster 8 has less than 2 members, skipping
Cluster 9 has less than 2 members, skipping
Cluster 11 has less than 2 members, skipping
Cluster 12 has less than 2 members, skipping
Cluster 13 has less than 2 members, skipping
Cluster 14 has less than 2 members, skipping
Cluster 15 has less than 2 members, skipping
Cluster 16 has less than 2 members, skipping


Evaluating closest to centroids:  38%|███▊      | 5/13 [00:46<00:33,  4.23s/it]

Cluster 18 has less than 2 members, skipping
Cluster 19 has less than 2 members, skipping
Cluster 21 has less than 2 members, skipping
Cluster 22 has less than 2 members, skipping
Cluster 24 has less than 2 members, skipping
Cluster 25 has less than 2 members, skipping
Cluster 26 has less than 2 members, skipping
Cluster 27 has less than 2 members, skipping
Cluster 28 has less than 2 members, skipping
Cluster 29 has less than 2 members, skipping
Cluster 32 has less than 2 members, skipping
Cluster 33 has less than 2 members, skipping
Cluster 34 has less than 2 members, skipping
Cluster 35 has less than 2 members, skipping
Cluster 39 has less than 2 members, skipping
Cluster 40 has less than 2 members, skipping
Cluster 41 has less than 2 members, skipping
Cluster 42 has less than 2 members, skipping
Cluster 44 has less than 2 members, skipping
Cluster 47 has less than 2 members, skipping
Cluster 49 has less than 2 members, skipping


Evaluating closest to centroids:  46%|████▌     | 6/13 [00:47<00:20,  2.90s/it]

Cluster 2 has less than 2 members, skipping
Cluster 3 has less than 2 members, skipping
Cluster 4 has less than 2 members, skipping
Cluster 5 has less than 2 members, skipping
Cluster 9 has less than 2 members, skipping
Cluster 11 has less than 2 members, skipping
Cluster 12 has less than 2 members, skipping
Cluster 13 has less than 2 members, skipping
Cluster 15 has less than 2 members, skipping
Cluster 17 has less than 2 members, skipping
Cluster 18 has less than 2 members, skipping
Cluster 19 has less than 2 members, skipping
Cluster 20 has less than 2 members, skipping
Cluster 21 has less than 2 members, skipping
Cluster 22 has less than 2 members, skipping
Cluster 23 has less than 2 members, skipping
Cluster 26 has less than 2 members, skipping
Cluster 27 has less than 2 members, skipping
Cluster 29 has less than 2 members, skipping
Cluster 31 has less than 2 members, skipping
Cluster 32 has less than 2 members, skipping
Cluster 33 has less than 2 members, skipping
Cluster 35 has 

Evaluating closest to centroids:  54%|█████▍    | 7/13 [00:47<00:12,  2.05s/it]

Cluster 37 has less than 2 members, skipping
Cluster 38 has less than 2 members, skipping
Cluster 39 has less than 2 members, skipping
Cluster 40 has less than 2 members, skipping
Cluster 41 has less than 2 members, skipping
Cluster 44 has less than 2 members, skipping
Cluster 45 has less than 2 members, skipping
Cluster 46 has less than 2 members, skipping
Cluster 47 has less than 2 members, skipping
Cluster 49 has less than 2 members, skipping
Cluster 50 has less than 2 members, skipping
Cluster 52 has less than 2 members, skipping
Cluster 53 has less than 2 members, skipping
Cluster 54 has less than 2 members, skipping
Cluster 55 has less than 2 members, skipping
Cluster 56 has less than 2 members, skipping
Cluster 57 has less than 2 members, skipping
Cluster 58 has less than 2 members, skipping
Cluster 59 has less than 2 members, skipping
Cluster 60 has less than 2 members, skipping
Cluster 61 has less than 2 members, skipping
Cluster 62 has less than 2 members, skipping
Cluster 63

Evaluating closest to centroids:  62%|██████▏   | 8/13 [00:47<00:07,  1.50s/it]

Cluster 7 has less than 2 members, skipping
Cluster 8 has less than 2 members, skipping
Cluster 10 has less than 2 members, skipping
Cluster 11 has less than 2 members, skipping
Cluster 12 has less than 2 members, skipping
Cluster 14 has less than 2 members, skipping
Cluster 15 has less than 2 members, skipping
Cluster 16 has less than 2 members, skipping
Cluster 17 has less than 2 members, skipping
Cluster 21 has less than 2 members, skipping
Cluster 22 has less than 2 members, skipping
Cluster 23 has less than 2 members, skipping
Cluster 24 has less than 2 members, skipping
Cluster 25 has less than 2 members, skipping
Cluster 27 has less than 2 members, skipping
Cluster 29 has less than 2 members, skipping
Cluster 30 has less than 2 members, skipping
Cluster 32 has less than 2 members, skipping
Cluster 33 has less than 2 members, skipping
Cluster 35 has less than 2 members, skipping
Cluster 36 has less than 2 members, skipping
Cluster 37 has less than 2 members, skipping
Cluster 38 h

Evaluating closest to centroids:  69%|██████▉   | 9/13 [00:47<00:04,  1.12s/it]

Cluster 1 has less than 2 members, skipping
Cluster 2 has less than 2 members, skipping
Cluster 3 has less than 2 members, skipping
Cluster 4 has less than 2 members, skipping
Cluster 5 has less than 2 members, skipping
Cluster 7 has less than 2 members, skipping
Cluster 9 has less than 2 members, skipping
Cluster 10 has less than 2 members, skipping
Cluster 12 has less than 2 members, skipping
Cluster 13 has less than 2 members, skipping
Cluster 15 has less than 2 members, skipping
Cluster 16 has less than 2 members, skipping
Cluster 17 has less than 2 members, skipping
Cluster 18 has less than 2 members, skipping
Cluster 19 has less than 2 members, skipping
Cluster 20 has less than 2 members, skipping
Cluster 21 has less than 2 members, skipping
Cluster 24 has less than 2 members, skipping
Cluster 26 has less than 2 members, skipping
Cluster 27 has less than 2 members, skipping
Cluster 30 has less than 2 members, skipping
Cluster 31 has less than 2 members, skipping
Cluster 34 has le

Evaluating closest to centroids:  77%|███████▋  | 10/13 [00:48<00:02,  1.14it/s]

Cluster 27 has less than 2 members, skipping
Cluster 28 has less than 2 members, skipping
Cluster 30 has less than 2 members, skipping
Cluster 31 has less than 2 members, skipping
Cluster 32 has less than 2 members, skipping
Cluster 34 has less than 2 members, skipping
Cluster 35 has less than 2 members, skipping
Cluster 36 has less than 2 members, skipping
Cluster 37 has less than 2 members, skipping
Cluster 38 has less than 2 members, skipping
Cluster 39 has less than 2 members, skipping
Cluster 41 has less than 2 members, skipping
Cluster 42 has less than 2 members, skipping
Cluster 43 has less than 2 members, skipping
Cluster 44 has less than 2 members, skipping
Cluster 46 has less than 2 members, skipping
Cluster 48 has less than 2 members, skipping
Cluster 49 has less than 2 members, skipping
Cluster 50 has less than 2 members, skipping
Cluster 52 has less than 2 members, skipping
Cluster 53 has less than 2 members, skipping
Cluster 54 has less than 2 members, skipping
Cluster 58

Evaluating closest to centroids:  85%|████████▍ | 11/13 [00:48<00:01,  1.44it/s]

Cluster 0 has less than 2 members, skipping
Cluster 1 has less than 2 members, skipping
Cluster 3 has less than 2 members, skipping
Cluster 4 has less than 2 members, skipping
Cluster 5 has less than 2 members, skipping
Cluster 6 has less than 2 members, skipping
Cluster 8 has less than 2 members, skipping
Cluster 9 has less than 2 members, skipping
Cluster 10 has less than 2 members, skipping
Cluster 11 has less than 2 members, skipping
Cluster 12 has less than 2 members, skipping
Cluster 13 has less than 2 members, skipping
Cluster 14 has less than 2 members, skipping
Cluster 15 has less than 2 members, skipping
Cluster 16 has less than 2 members, skipping
Cluster 17 has less than 2 members, skipping
Cluster 22 has less than 2 members, skipping
Cluster 23 has less than 2 members, skipping
Cluster 24 has less than 2 members, skipping
Cluster 26 has less than 2 members, skipping
Cluster 27 has less than 2 members, skipping
Cluster 28 has less than 2 members, skipping
Cluster 29 has les

Evaluating closest to centroids:  92%|█████████▏| 12/13 [00:48<00:00,  1.75it/s]

Cluster 275 has less than 2 members, skipping
Cluster 276 has less than 2 members, skipping
Cluster 277 has less than 2 members, skipping
Cluster 278 has less than 2 members, skipping
Cluster 279 has less than 2 members, skipping
Cluster 280 has less than 2 members, skipping
Cluster 281 has less than 2 members, skipping
Cluster 282 has less than 2 members, skipping
Cluster 283 has less than 2 members, skipping
Cluster 284 has less than 2 members, skipping
Cluster 285 has less than 2 members, skipping
Cluster 286 has less than 2 members, skipping
Cluster 287 has less than 2 members, skipping
Cluster 288 has less than 2 members, skipping
Cluster 289 has less than 2 members, skipping
Cluster 290 has less than 2 members, skipping
Cluster 291 has less than 2 members, skipping
Cluster 292 has less than 2 members, skipping
Cluster 293 has less than 2 members, skipping
Cluster 294 has less than 2 members, skipping
Cluster 295 has less than 2 members, skipping
Cluster 296 has less than 2 member

Evaluating closest to centroids: 100%|██████████| 13/13 [00:49<00:00,  3.78s/it]

Cluster 83 has less than 2 members, skipping
Cluster 84 has less than 2 members, skipping
Cluster 85 has less than 2 members, skipping
Cluster 86 has less than 2 members, skipping
Cluster 87 has less than 2 members, skipping
Cluster 88 has less than 2 members, skipping
Cluster 89 has less than 2 members, skipping
Cluster 90 has less than 2 members, skipping
Cluster 91 has less than 2 members, skipping
Cluster 92 has less than 2 members, skipping
Cluster 93 has less than 2 members, skipping
Cluster 94 has less than 2 members, skipping
Cluster 95 has less than 2 members, skipping
Cluster 96 has less than 2 members, skipping
Cluster 97 has less than 2 members, skipping
Cluster 98 has less than 2 members, skipping
Cluster 99 has less than 2 members, skipping
Cluster 100 has less than 2 members, skipping
Cluster 101 has less than 2 members, skipping
Cluster 102 has less than 2 members, skipping
Cluster 103 has less than 2 members, skipping
Cluster 104 has less than 2 members, skipping
Clust




In [None]:
# closest_to_centroids_filename = f"/home/can/feature_clustering/clusters/closest_to_centroids-1k_{param_summary}.json"
# with open(closest_to_centroids_filename, "w") as f:
#     json.dump(closest_to_centroids, f)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


RuntimeError: grad can be implicitly created only for scalar outputs