# Imports and Constants

In [None]:
import pandas as pd
from pathlib import Path
import torch
import transformer_lens
import gc

#MODEL_NAME = "EleutherAI/pythia-410m"
MODEL_NAME = "EleutherAI/pythia-2.8b"

# Collect data

Read in the names of basketball players and check that each name is exactly 2 tokens long.

In [None]:
# data_dir = Path("data")
# names_file = data_dir / "names.txt"
# all_new_names = []

# # Iterate over all csv files in the data directory
# for csv_path in data_dir.glob("*.csv"):
#     # Read the csv file, assuming no header
#     df = pd.read_csv(csv_path, header=None)
#     # Get the second column (index 1) and drop any missing values
#     names_from_csv = df.iloc[:, 1].dropna().tolist()
#     filtered_names = [name for name in names_from_csv if isinstance(name, str) and name.count(' ') == 1]
#     all_new_names.extend(filtered_names)

# # Append the collected names to the names.txt file
# if all_new_names:
#     with open(names_file, "a") as f:
#         # Ensure we start on a new line if the file isn't empty
#         if names_file.stat().st_size > 0:
#             f.write("\n")
#         f.write("\n".join(all_new_names))
#     print(f"Appended {len(all_new_names)} names to {names_file}.")
# else:
#     print("No new names found in CSV files to append.")

In [None]:
# Load the candidate names, one per line. Lines are stripped and blank lines
# dropped so that an empty file yields [] (not ['']) and stray whitespace does
# not leak into the prompts built from these names.
names_file = Path("data/names.txt")
with open(names_file, 'r') as f:
    names = [line.strip() for line in f if line.strip()]
    print(names)

In [None]:
# Drop any model left over from an earlier run so its weights can be freed.
if "model" in locals() or "model" in globals():
    del model

# Collect garbage first: tensors that are only kept alive by reference cycles
# must be destroyed before the CUDA caching allocator can release their blocks.
gc.collect()

torch.cuda.empty_cache()

In [None]:
# Drop any model left over from an earlier run so its weights can be freed.
if 'model' in locals() or 'model' in globals():
    del model

# Collect garbage first: tensors that are only kept alive by reference cycles
# must be destroyed before the CUDA caching allocator can release their blocks.
gc.collect()

torch.cuda.empty_cache()

# Verify memory is cleared. The CUDA statistics are only queried when a GPU is
# present, matching the CPU fallback used for loading the model below.
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

    # More detailed GPU memory info
    print(torch.cuda.memory_summary())

# Load the model selected by MODEL_NAME.
# This will download the model weights if they are not already cached.
model = transformer_lens.HookedTransformer.from_pretrained_no_processing(
    MODEL_NAME,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

print(f"{MODEL_NAME} model loaded successfully.")

In [None]:
print(torch.cuda.is_available())

# Report the current GPU memory usage. The CUDA statistics are only queried
# when a GPU is present so the cell also runs on a CPU-only machine.
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

    # More detailed GPU memory info
    print(torch.cuda.memory_summary())


In [None]:
# Names that the tokenizer splits into exactly two tokens.
two_token_names = []

for name in names:
    # to_tokens returns a tensor with a leading batch dimension; the second
    # dimension is the sequence length.
    tokens = model.to_tokens(name)
    token_count = tokens.shape[1]

    # A two-token name has length 3: the leading special token, the first
    # name and the last name. Anything else is reported and skipped.
    if token_count != 3:
        print(f"Name '{name}' is represented by {tokens}.")
        continue
    two_token_names.append(name)

print(f"Found {len(two_token_names)} names that are exactly two tokens long.")
print(two_token_names)

In [None]:
print("Completions for basketball players:")
print("---------------------------------------")

# Names for which the model's completion starts with "basketball".
basketball_players = []

for name in two_token_names:
    prompt = f"Fact: {name} is known for playing the sport of"

    # Two new tokens are enough to read off the sport; temperature=0 keeps
    # the output deterministic.
    completion = model.generate(
        prompt,
        max_new_tokens=2,
        temperature=0,
        verbose=True,
    )

    # The returned text starts with the prompt; keep only what was generated.
    completion_text = completion[len(prompt):].strip()

    # Case-insensitive check for a "basketball" completion.
    mentions_basketball = completion_text.lower().startswith("basketball")
    if not mentions_basketball:
        print(completion)
        print(completion_text)
    else:
        basketball_players.append(name)
    print(f"Prompt: '{prompt}' -> Completion: '{completion_text}'")


print("\n---------------------------------------")
print(f"Found {len(basketball_players)} names that completed with 'basketball':")
print(basketball_players)

# Persist the confirmed players, one name per line.
basketball_players_file = Path("data/basketball_players.txt")
with open(basketball_players_file, "w") as f:
    f.write("\n".join(basketball_players))
print(f"Stored {len(basketball_players)} basketball players in {basketball_players_file}.")

In [None]:
# Re-print how many players were identified and which file they were stored in.
print(f"{len(basketball_players)} basketball players identified and stored in {basketball_players_file}.")

In [None]:
import itertools
import random

# Build the pools of distinct first and last names from every two-token name
# (not only from the names confirmed as basketball players).
first_names = list({name.split(' ')[0] for name in two_token_names})
last_names = list({name.split(' ')[1] for name in two_token_names})

# Every first-name / last-name pairing.
all_combinations = [
    f"{first} {last}" for first, last in itertools.product(first_names, last_names)
]

# Discard pairings that coincide with a real name, then keep at most 100 of
# the remaining ones, chosen at random.
real_names_set = set(two_token_names)
fake_names = [name for name in all_combinations if name not in real_names_set]
fake_names = random.sample(fake_names, min(100, len(fake_names)))

print(f"Generated {len(fake_names)} fake names to test.")

non_basketball_fake_players = []
non_basketball_completions = []

# Run the same basketball probe on the fake names.
for processed, name in enumerate(fake_names, start=1):
    prompt = f"Fact: {name} is known for playing the sport of"

    completion = model.generate(
        prompt,
        max_new_tokens=2,
        temperature=0,
        verbose=False,
    )

    completion_text = completion[len(prompt):].strip()

    # Keep only the fake names that are NOT completed with "basketball".
    if not completion_text.lower().startswith("basketball"):
        non_basketball_fake_players.append(name)
        non_basketball_completions.append(completion)
    # Progress report after every tenth name.
    if processed % 10 == 0:
        print(f"Processed {processed} fake names... Ex: {completion}")

print(f"\nFound {len(non_basketball_fake_players)} fake names that are not associated with basketball.")
print("A sample of non-basketball completions for fake names:")
print("----------------------------------------------------")

# Show up to 20 randomly chosen completions.
sample_size = min(20, len(non_basketball_completions))

if sample_size > 0:
    for sampled_completion in random.sample(non_basketball_completions, sample_size):
        print(sampled_completion)

# Persist the fake names that passed the filter, one name per line.
fake_basketball_players_file = Path("data/fake_basketball_players.txt")
with open(fake_basketball_players_file, "w") as f:
    f.write("\n".join(non_basketball_fake_players))
print(f"Stored {len(non_basketball_fake_players)} fake basketball players in {fake_basketball_players_file}.")

# Compute and cache Embeddings

In [None]:
import torch
import transformer_lens
import gc

# Clear memory. Garbage is collected before the CUDA cache is emptied so that
# tensors kept alive only by reference cycles are actually released.
if 'model' in locals() or 'model' in globals():
    del model
gc.collect()
torch.cuda.empty_cache()

# Load the model in half precision on the GPU; fall back to float32 on the
# CPU when no GPU is available, like the other model-loading cells do.
model = transformer_lens.HookedTransformer.from_pretrained_no_processing(
    MODEL_NAME,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

# Run the model on a single name and return its activation cache
def process_with_cache(name):
    """Run the global ``model`` on ``"Fact: {name}"`` and return the cache.

    Parameters:
    - name: Text appended to the ``"Fact: "`` prefix to form the prompt.

    Returns the cache produced by ``model.run_with_cache``. The call uses:
    - ``remove_batch_dim=True`` to reduce the cached tensors' dimensions,
    - ``names_filter=None`` (set it to specific layers to cache less),
    - ``incl_bwd=False`` to avoid caching backward hooks.

    The cache is returned as produced; it is not moved to another device.
    """
    _, cache = model.run_with_cache(
        f"Fact: {name}",
        remove_batch_dim=True,
        names_filter=None,
        incl_bwd=False,
    )
    return cache

def compute_caches(names):
    """Compute an activation cache for every entry of ``names``.

    Parameters:
    - names: Iterable of names; each one is passed to ``process_with_cache``.

    Returns a dict mapping each successfully processed name to its cache.
    All caches are kept in memory. A name whose processing raises is reported
    on stdout and left out of the result.
    """
    caches = {}
    with torch.inference_mode():
        for name in names:
            try:
                caches[name] = process_with_cache(name)
            except Exception as e:
                # Best effort: report the failure and carry on with the rest.
                print(f"Error processing {name}: {e}")

            # Force a cleanup after every name.
            gc.collect()
            torch.cuda.empty_cache()
    return caches

# Read the stored real and fake names back, ignoring blank lines.
with open("data/basketball_players.txt", "r") as f:
    player_names = [stripped for stripped in map(str.strip, f) if stripped]
with open("data/fake_basketball_players.txt", "r") as f:
    fake_names = [stripped for stripped in map(str.strip, f) if stripped]

# One activation cache per name, for both groups.
player_caches = compute_caches(player_names)
print(f"Processed {len(player_caches)} player caches.")
fake_caches = compute_caches(fake_names)
print(f"Processed {len(fake_caches)} fake player caches.")

# Compute convex covers

In [None]:
import numpy as np

def make_point_thick(vector, extra_vectors = 100, thickness=0.2, symmetric=False):
    """Return ``vector`` together with copies shifted along single dimensions.

    Parameters:
    - vector: The original point; it must support ``len()``, ``.copy()`` and
      item assignment (e.g. a NumPy array or a list). It is never modified.
    - extra_vectors: Number of leading dimensions to perturb, capped at
      ``len(vector)``. A value <= 0 produces no extra points.
    - thickness: Amount added to (and, if ``symmetric``, subtracted from) one
      coordinate per extra point.
    - symmetric: When False (the default, and the previous behaviour) only the
      ``+thickness`` copy is created per dimension; when True the matching
      ``-thickness`` copy is appended right after it.

    Returns a list whose first element is ``vector`` itself, followed by the
    shifted copies in dimension order.
    """
    offsets = (thickness, -thickness) if symmetric else (thickness,)
    vectors = [vector]
    for dim in range(min(extra_vectors, len(vector))):
        for offset in offsets:
            # Copy first so the original point stays untouched.
            shifted = vector.copy()
            shifted[dim] += offset
            vectors.append(shifted)
    return vectors

def get_point_set(caches, hook_name = "blocks.0.hook_resid_post", extra_vectors = 100, thickness = 0.2):
    """Collect the last-position activation of ``hook_name`` from every cache.

    Parameters:
    - caches: Mapping from a name to its activation cache; each cache is
      indexed by hook name and then by position.
    - hook_name: Hook whose activation is extracted.
    - extra_vectors, thickness: Passed on to ``make_point_thick`` for every
      extracted vector.

    Returns a flat list of NumPy vectors (the points returned by
    ``make_point_thick`` for every usable cache). Entries that lack the hook
    or have no last element are reported and skipped; an empty ``caches``
    yields an empty list.
    """
    vectors = []
    dimension_reported = False

    print("Collecting vectors from cache...")
    for name, cache in caches.items():
        try:
            # Activation at the last position [-1]
            vector = cache[hook_name][-1]
        except KeyError:
            print(f"Skipping {name}: Missing '{hook_name}' in cache")
            continue
        except IndexError:
            print(f"Skipping {name}: No last element in '{hook_name}'")
            continue

        # Convert through float32: NumPy has no bfloat16 dtype, and norms
        # computed on float16 arrays can overflow.
        vector = vector.detach().float().cpu().numpy()

        if not dimension_reported:
            print(f"Vector dimension: {vector.shape[0]}")
            dimension_reported = True

        vectors.extend(make_point_thick(vector, extra_vectors = extra_vectors, thickness = thickness))
    return vectors
# Point sets taken from the residual stream after block 0, without any
# extra "thickened" points.
player_post_layer = get_point_set(
    player_caches, hook_name="blocks.0.hook_resid_post", extra_vectors=0
)
fake_post_layer = get_point_set(
    fake_caches, hook_name="blocks.0.hook_resid_post", extra_vectors=0
)
print(f"{len(player_post_layer)} player vectors collected, {len(fake_post_layer)} fake vectors collected.")

In [None]:
def print_distance_stats(set_a, set_b, stats_lists=None, equal = False):
    """
    Calculate and print distance statistics between two sets of vectors.

    Every element of set_a is compared with every element of set_b using
    np.linalg.norm(a - b); the min, mean, median and max are printed.

    Parameters:
    - set_a: First set of vectors
    - set_b: Second set of vectors
    - stats_lists: Optional dictionary containing lists to append stats to
                  with keys 'min', 'max', 'mean', 'median'; only the keys
                  that are present receive a value
    - equal: Set to True when set_a and set_b are the same sequence; pairs
             with the same index (i == j) are then skipped so that a point's
             zero distance to itself does not enter the statistics

    If no pair is left to compare (an empty set, or a single point with
    equal=True) a notice is printed and NaN is recorded for every requested
    statistic, so that lists filled once per layer stay aligned.
    """
    distances = sorted(
        np.linalg.norm(a - b)
        for i, a in enumerate(set_a)
        for j, b in enumerate(set_b)
        # Skip distances between the same points
        if not (equal and i == j)
    )
    print(
        f"Points in A: {len(set_a)}, in B: {len(set_b)}, total distances: {len(distances)}"
    )

    if distances:
        summary = {
            'min': distances[0],
            'max': distances[-1],
            'mean': np.mean(distances),
            'median': np.median(distances),
        }
        print(
            f"Min: {summary['min']:.4f}, Mean: {summary['mean']:.4f}, "
            f"Median: {summary['median']:.4f}, Max: {summary['max']:.4f}"
        )
    else:
        summary = dict.fromkeys(('min', 'max', 'mean', 'median'), float("nan"))
        print("No distances to summarise.")

    if stats_lists is not None:
        for key, value in summary.items():
            if key in stats_lists:
                stats_lists[key].append(value)


# Summarise the distances between every player vector and every fake-name vector.
print_distance_stats(player_post_layer, fake_post_layer)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE, MDS
from sklearn.decomposition import PCA
import umap

# Combine all vectors (players and fake)
all_vectors = np.vstack([player_post_layer, fake_post_layer])
labels = ["Player"] * len(player_post_layer) + ["Fake"] * len(fake_post_layer)

# Method 1: t-SNE visualization
# t-SNE requires the perplexity to be smaller than the number of samples, so
# the usual value of 30 is reduced for small data sets.
tsne_perplexity = min(30, len(all_vectors) - 1)
tsne = TSNE(n_components=2, random_state=42, perplexity=tsne_perplexity)
embeddings_tsne = tsne.fit_transform(all_vectors)

# # Method 2: UMAP (often better preserves global structure)
# reducer = umap.UMAP(random_state=42)
# embeddings_umap = reducer.fit_transform(all_vectors)

# Create plots (the right-hand axis is reserved for the optional UMAP plot)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Plot t-SNE
sns.scatterplot(x=embeddings_tsne[:, 0], y=embeddings_tsne[:, 1], 
                hue=labels, ax=ax1, palette="Set1")
ax1.set_title("t-SNE Visualization")

# Plot UMAP
# sns.scatterplot(x=embeddings_umap[:, 0], y=embeddings_umap[:, 1], 
                # hue=labels, ax=ax2, palette="Set1")
# ax2.set_title("UMAP Visualization")

plt.tight_layout()

In [None]:
# Pairwise Euclidean distances, split into within-group and between-group
# buckets. Players occupy the first rows of all_vectors, fakes the rest.
player_player_distances, fake_fake_distances, player_fake_distances = [], [], []

n_players = len(player_post_layer)
n_total = len(all_vectors)

for i in range(n_total):
    for j in range(i + 1, n_total):
        dist = np.linalg.norm(all_vectors[i] - all_vectors[j])
        # Because i < j, the pair is player-player exactly when j is a
        # player row and fake-fake exactly when i is a fake row.
        if j < n_players:
            player_player_distances.append(dist)
        elif i >= n_players:
            fake_fake_distances.append(dist)
        else:
            player_fake_distances.append(dist)

# Density plot of the three distance distributions.
plt.figure(figsize=(10, 6))
for bucket, bucket_label in (
    (player_player_distances, "Player-Player"),
    (fake_fake_distances, "Fake-Fake"),
    (player_fake_distances, "Player-Fake"),
):
    sns.kdeplot(bucket, label=bucket_label, fill=True)
plt.xlabel("Euclidean Distance")
plt.ylabel("Density")
plt.title("Distribution of Distances Between Vector Groups")
plt.legend()

In [None]:
from convex_point_cover.algorithms.kruskal import kruskal

def _plot_layer_stats(layer_axis, layer_stats, ylabel, title):
    """Plot the min/median/mean/max series of ``layer_stats`` in a new figure."""
    plt.figure(figsize=(12, 6))
    for key, marker, label in (
        ("min", "o", "Min Distance"),
        ("median", "s", "Median Distance"),
        ("mean", "^", "Mean Distance"),
        ("max", "*", "Max Distance"),
    ):
        plt.plot(layer_axis, layer_stats[key], marker=marker, label=label)
    plt.xlabel("Layer")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)


num_layers = model.W_K.shape[0]

layers = range(num_layers)
# One list per statistic; each gets one entry per layer.
stats = {'min': [], 'max': [], 'mean': [], 'median': []}
stats_player_to_0 = {'min': [], 'max': [], 'mean': [], 'median': []}
stats_fake_to_0 = {'min': [], 'max': [], 'mean': [], 'median': []}

cluster_sizes = [[], []]
for layer in layers:
    hook_name = f"blocks.{layer}.hook_resid_post"
    # NOTE: these globals keep the vectors of the last processed layer.
    player_post_layer = get_point_set(player_caches, hook_name, extra_vectors = 0, thickness = 0.2)
    fake_post_layer = get_point_set(fake_caches, hook_name, extra_vectors = 0, thickness = 0.2)
    print(
        f"Layer {layer}: {len(player_post_layer)} player vectors collected, {len(fake_post_layer)} fake vectors collected."
    )

    print("Distances between player vectors and fake vectors:")
    print_distance_stats(player_post_layer, fake_post_layer, stats_lists=stats)

    # The distance to the origin is the magnitude of each vector.
    zero_vector = np.zeros_like(player_post_layer[0])
    print("Distances to 0:")
    print_distance_stats(
        player_post_layer,
        [zero_vector],
        stats_lists=stats_player_to_0,
    )
    print("fakes to 0:")
    print_distance_stats(
        fake_post_layer,
        [zero_vector],
        stats_lists=stats_fake_to_0,
    )

# Plot the evolution of distances across layers. Every plot opens its own
# figure, so no empty figure is left over at the end of the cell.
_plot_layer_stats(
    layers, stats, "Distance between points", "Distance Metrics Across Layers"
)
_plot_layer_stats(
    layers,
    stats_player_to_0,
    "Distance from 0",
    "Magnitude Metrics Across Layers for real names",
)
_plot_layer_stats(
    layers,
    stats_fake_to_0,
    "Distance from 0",
    "Magnitude Metrics Across Layers for fake names",
)

In [None]:
# Show how the tokenization changes when 0..49 spaces precede the name.
name = "Michael Jordan"

for space_count in range(50):
    prompt = "Fact: " + " " * space_count + name
    tokens = model.to_tokens(prompt)
    print(f"Tokens for '{prompt}': {tokens}")
    

In [None]:
# Check how meaningfully different prefixed embeddings are:
# cache the activations for the same name preceded by 0..49 spaces.
name = "Michael Jordan"
prefixed_names = [" " * i + name for i in range(50)]
# The resulting dict is keyed by the prefixed name that was passed in.
cache = compute_caches(prefixed_names)

In [None]:
import matplotlib.pyplot as plt

stats = {"min": [], "max": [], "mean": [], "median": []}

# Derive the layer axis here so the plot does not depend on a ``layers``
# variable left behind by another cell.
layers = range(model.W_V.shape[0])

for layer in layers:
    hook_name = f"blocks.{layer}.hook_resid_post"
    # The loop variable is deliberately not called ``names``: that would
    # overwrite the global list of names loaded earlier in the notebook.
    vectors = [
        cache[prefixed_name][hook_name].detach().cpu().numpy()[-1]
        for prefixed_name in prefixed_names
    ]
    # equal=True skips each vector's zero distance to itself.
    print_distance_stats(vectors, vectors, stats_lists=stats, equal = True)

# Plot the evolution of distances across layers
plt.figure(figsize=(12, 6))
plt.plot(layers, stats["min"], marker="o", label="Min Distance")
plt.plot(layers, stats["median"], marker="s", label="Median Distance")
plt.plot(layers, stats["mean"], marker="^", label="Mean Distance")
plt.plot(layers, stats["max"], marker="*", label="Max Distance")
plt.xlabel("Layer")
plt.ylabel("Distance between points")
plt.title("Distance of space-prefixed Embeddings Across Layers")
plt.legend()
plt.grid(True, alpha=0.3)

# Increase data size

In [None]:
import matplotlib.pyplot as plt

# One (num_spaces, basketball_count, total_count, percentage) tuple per prefix.
results_per_prefix = []
# Number of prefixes for which each player still completes with "basketball".
results_per_name = {name.strip():0 for name in player_names}

# Loop through different numbers of spaces
for num_spaces in range(0, 1000, 10):
    prefix = " " * num_spaces
    basketball_count = 0
    total_count = 0

    # Print only a sample of the results to avoid flooding output: the first
    # three prefixes and then every 100 spaces. (The step is 10, so testing
    # ``num_spaces % 10`` would select every prefix.)
    show_examples = len(results_per_prefix) < 3 or num_spaces % 100 == 0

    # Test each player name
    for name in results_per_name:
        prompt = f"Fact: {prefix}{name} is known for playing the sport of"
        completion = model.generate(
            prompt, 
            max_new_tokens=2, 
            temperature=0,
            verbose=False
        )
        completion_text = completion[len(prompt):].strip()

        # Count basketball completions
        if completion_text.lower().startswith("basketball"):
            basketball_count += 1
            results_per_name[name] += 1
        total_count += 1

        if show_examples:
            print(f"Prompt: '{prompt}' -> Completion: '{completion_text}'")

    # Store the results for this number of spaces
    percentage = (basketball_count / total_count) * 100 if total_count > 0 else 0
    results_per_prefix.append((num_spaces, basketball_count, total_count, percentage))
    print(f"Spaces: {num_spaces}, Basketball: {basketball_count}/{total_count} ({percentage:.2f}%)")

# Plot the results
spaces = [r[0] for r in results_per_prefix]
percentages = [r[3] for r in results_per_prefix]

plt.figure(figsize=(12, 6))
plt.plot(spaces, percentages, marker='o')
plt.title('Percentage of "basketball" completions vs. number of spaces')
plt.xlabel('Number of spaces')
plt.ylabel('Percentage of basketball completions')
plt.grid(True)
plt.show()

# Sort players by their count of basketball completions
sorted_players = sorted(results_per_name.items(), key=lambda x: x[1], reverse=True)
player_names_sorted = [player[0] for player in sorted_players]
completion_counts = [player[1] for player in sorted_players]

# Create a horizontal bar chart for better readability with many player names
plt.figure(figsize=(10, 12))
bars = plt.barh(player_names_sorted, completion_counts)

# Add count labels to the bars
for bar, count in zip(bars, completion_counts):
    plt.text(bar.get_width() + 0.2, bar.get_y() + bar.get_height()/2, 
             str(count), 
             va='center')

plt.xlabel('Number of "basketball" completions')
plt.ylabel('Player Name')
plt.title('Sensitivity of Player Names to Spacing (steps of 10)\n(Higher = More Robust Basketball Association)')
plt.tight_layout()
plt.show()

# Summary statistics (skipped when there are no players, which would
# otherwise divide by zero and call max()/min() on an empty list)
if completion_counts:
    print(f"Average basketball completions per player: {sum(completion_counts)/len(completion_counts):.2f}")
    print(f"Max basketball completions: {max(completion_counts)}")
    print(f"Min basketball completions: {min(completion_counts)}")
else:
    print("No player names available; nothing to summarise.")

In [None]:
from convex_point_cover.algorithms.fast_kruskal import fast_kruskal

from joblib import Parallel, delayed

# The leading dimension of model.W_K is taken as the number of layers;
# ``layers`` is the matching 0-based range used for iteration and plotting.
num_layers = model.W_K.shape[0]
layers = range(num_layers)

def process_layer(layer, player_vectors, fake_vectors):
    """Run ``fast_kruskal`` in both directions for one layer.

    Parameters:
    - layer: Layer index; only copied into the result.
    - player_vectors: Vectors of the real players.
    - fake_vectors: Vectors of the fake names.

    Returns a dict with the layer index, the size of both inputs and the
    number of clusters reported for the (player, fake) call
    ("positive_clusters") and for the (fake, player) call
    ("negative_clusters").
    """
    kruskal_options = {"epsilon": 0.1, "debug": False}
    positive_clusters = fast_kruskal(player_vectors, fake_vectors, **kruskal_options)
    negative_clusters = fast_kruskal(fake_vectors, player_vectors, **kruskal_options)
    return dict(
        layer=layer,
        player_count=len(player_vectors),
        fake_count=len(fake_vectors),
        positive_clusters=len(positive_clusters),
        negative_clusters=len(negative_clusters),
    )


# One job per layer, spread over all available cores (n_jobs=-1). The point
# sets are extracted here and handed to the job as arguments.
print("Processing layers in parallel...")
layer_jobs = (
    delayed(process_layer)(
        layer,
        get_point_set(player_caches, hook, extra_vectors=0, thickness=0.2),
        get_point_set(fake_caches, hook, extra_vectors=0, thickness=0.2),
    )
    for layer in layers
    for hook in (f"blocks.{layer}.hook_resid_post",)
)
results = Parallel(n_jobs=-1, verbose=10)(layer_jobs)

# Report the results in layer order and collect the cluster counts:
# cluster_sizes[0] holds the positive counts, cluster_sizes[1] the negative.
cluster_sizes = [[], []]
positive_counts, negative_counts = cluster_sizes
for result in sorted(results, key=lambda r: r["layer"]):
    positive_counts.append(result["positive_clusters"])
    negative_counts.append(result["negative_clusters"])

    print(
        f"Layer {result['layer']}: {result['player_count']} player vectors collected, {result['fake_count']} fake vectors collected."
    )
    print(f"{result['positive_clusters']} clusters found.")
    print(f"{result['negative_clusters']} negative clusters found.")

# Plot the number of clusters found across layers
plt.figure(figsize=(12, 6))
for counts, marker, series_label in (
    (positive_counts, "o", "Positive Clusters"),
    (negative_counts, "s", "Negative Clusters"),
):
    plt.plot(layers, counts, marker=marker, label=series_label)
plt.xlabel("Layer")
plt.ylabel("Number of Clusters")
plt.title("Number of Clusters Found Across Layers")
plt.legend()
plt.grid(True, alpha=0.3)

# Linear classifier

In [None]:
from sklearn.svm import LinearSVC
import numpy as np


def count_linear_separation(set_a, set_b):
    """Fit a linear SVM on both sets and count its training errors.

    Parameters:
    - set_a: Vectors labelled "Player".
    - set_b: Vectors labelled "Fake".

    The classifier is evaluated on the same points it was trained on. The
    counts are printed and returned as the tuple
    (fake points predicted as "Player", player points predicted as "Fake").
    """
    # Stack both groups into one data set with matching labels.
    features = np.vstack([set_a, set_b])
    targets = np.array(["Player"] * len(set_a) + ["Fake"] * len(set_b))

    # Train the linear SVM and predict the training points.
    classifier = LinearSVC(random_state=42, max_iter=10000)
    classifier.fit(features, targets)
    predictions = classifier.predict(features)

    # Overall number of wrong predictions.
    wrong = predictions != targets
    misclassified = np.sum(wrong)
    print(f"Total misclassified points: {misclassified} out of {len(targets)} ({misclassified/len(targets)*100:.2f}%)")

    # Wrong predictions split by true class.
    fake_misclassified = np.sum(wrong & (predictions == "Player") & (targets == "Fake"))
    player_misclassified = np.sum(wrong & (predictions == "Fake") & (targets == "Player"))

    print(f"Fake points misclassified as players: {fake_misclassified}")
    print(f"Player points misclassified as fake: {player_misclassified}")
    return fake_misclassified, player_misclassified

# Evaluated on whatever player_post_layer / fake_post_layer currently hold.
# NOTE(review): these globals are reassigned by the per-layer loop in an earlier cell — confirm which layer is intended here.
print(count_linear_separation(player_post_layer, fake_post_layer))

# Putting the pieces together

In [1]:
import pandas as pd
from pathlib import Path
import torch
import transformer_lens
import gc
import itertools
import random
from sklearn.svm import LinearSVC
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns


# Model passed to HookedTransformer.from_pretrained_no_processing.
# MODEL_NAME = "EleutherAI/pythia-410m"
MODEL_NAME = "EleutherAI/pythia-2.8b"

# Data generation
READ_FROM_FILE = True # If false, re-check all basketballer predictions
DATA_FOLDER = "data-2.8b" # Folder holding names.txt and basketball_players.txt
SPACE_RANGE = range(7, 5000, 7) # How many different spaces real data points get (one candidate prefix per value)
FAKE_NAMES_MAX = 2000 # how many different fake names to generate
NUM_PREFIXES = 100 # How many different spaces fake data points get
LIMIT = 20000 # Max number of total data points

MAX_BATCH_SIZE = 1000 # NOTE(review): presumably the largest batch sent to the model at once — its use is not visible here, confirm

NUM_JOBS = -1 # How many parallel jobs for computing convex sets (max 1 per layer, -1 for all cores)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Drop any model left over from an earlier run so its weights can be freed.
if "model" in locals() or "model" in globals():
    del model

# Collect garbage first: tensors that are only kept alive by reference cycles
# must be destroyed before the CUDA caching allocator can release their blocks.
gc.collect()

torch.cuda.empty_cache()

# Verify memory is cleared. The CUDA statistics are only queried when a GPU is
# present, matching the CPU fallback used for loading the model below.
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

# Load the model selected by MODEL_NAME.
# This will download the model weights if they are not already cached.
model = transformer_lens.HookedTransformer.from_pretrained_no_processing(
    MODEL_NAME,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

print(f"{MODEL_NAME} model loaded successfully.")

GPU memory allocated: 0.00 MB
GPU memory reserved: 0.00 MB
Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer
EleutherAI/pythia-2.8b model loaded successfully.


In [3]:
# Load the candidate names, one per line. Lines are stripped and blank lines
# dropped so that an empty file yields [] (not ['']) and stray whitespace does
# not leak into the prompts built from these names.
names_file = Path(f"{DATA_FOLDER}/names.txt")
with open(names_file, "r") as f:
    names = [line.strip() for line in f if line.strip()]
    print(f"{len(names)} names")

3962 names


In [4]:
# Names that the tokenizer splits into exactly two tokens. to_tokens returns
# a tensor with a leading batch dimension, so the sequence length is
# shape[1]; a two-token name has length 3 (leading special token, first
# name, last name).
two_token_names = [
    name for name in names if model.to_tokens(name).shape[1] == 3
]

print(f"Found {len(two_token_names)} names that are exactly two tokens long.")
print(two_token_names[:10])

Found 372 names that are exactly two tokens long.
['Alex English', 'Gary Neal', 'James Johnson', 'Gordon Herbert', 'Robert Parish', 'Charles Smith', 'Rod Strickland', 'Kevin Edwards', 'Scott Brooks', 'Tony Gonzalez']


In [5]:
# Sanity check: compare single-prompt generation with batched generation.
text = "Fact: Michael Jordan is known for playing the sport of"

# Single string, five new tokens
result1 = model.generate(text, max_new_tokens=5, temperature=0, verbose=False)
print(result1)

# Batch of two prompts that differ only by one extra space after "Fact:",
# generated with left padding and a single new token each.
result2 = model.generate(
    [
        text,
        "Fact:  Michael Jordan is known for playing the sport of",
    ],  # NOTE: the second prompt is NOT identical to `text` (two spaces)
    max_new_tokens=1,
    temperature=0,
    verbose=False,
    padding_side='left'
)
print(result2)

Fact: Michael Jordan is known for playing the sport of basketball. He is also
['Fact: Michael Jordan is known for playing the sport of basketball', 'Fact:  Michael Jordan is known for playing the sport of basketball']


In [6]:
def is_basketball_player(name):
    """Return True when the model's greedy next token for the fact prompt is "basketball".

    The prompt is ``Fact: {name} is known for playing the sport of``; exactly
    one new token is generated and compared case-insensitively after
    stripping surrounding whitespace.
    """
    prompt = f"Fact: {name} is known for playing the sport of"
    generated = model.generate(prompt, max_new_tokens=1, temperature=0, verbose=False)
    # generate() returns prompt + continuation; slice the prompt off by length.
    continuation = generated[len(prompt) :]
    return continuation.strip().lower().startswith("basketball")

# Whitespace prefixes used to derive additional variants of each name.
prefix_candidates = [" " * width for width in SPACE_RANGE]

basketball_players_file = Path(f"{DATA_FOLDER}/basketball_players.txt")

if not READ_FROM_FILE:
    print("Completions for basketball players:")
    print("---------------------------------------")

    basketball_players = []

    for name in two_token_names:
        if is_basketball_player(name):
            basketball_players.append(name)
            print(".", end="")
            # The bare name passed; try the space-prefixed variants as extra data points.
            hits = 0
            for tries, prefix in enumerate(prefix_candidates, start=1):
                candidate = prefix + name
                if is_basketball_player(candidate):
                    basketball_players.append(candidate)
                    hits += 1
                # Give up on this name if none of the first 10 prefixes worked.
                if tries == 10 and hits == 0:
                    break

    print("\n---------------------------------------")
    print(f"Found {len(basketball_players)} names that completed with 'basketball':")
    print(basketball_players[:10])

    random.shuffle(basketball_players)

    basketball_players_file.write_text("\n".join(basketball_players))
    print(
        f"Stored {len(basketball_players)} basketball players in {basketball_players_file}."
    )
else:
    # Reuse the list stored by an earlier run; blank lines are ignored.
    with basketball_players_file.open("r") as f:
        basketball_players = [entry for entry in map(str.rstrip, f) if entry]
    print(f"Loaded {len(basketball_players)} basketball players from {basketball_players_file}")

Loaded 41939 basketball players from data-2.8b/basketball_players.txt


In [7]:
# READ_FROM_FILE = False

fake_basketball_players_file = Path(f"{DATA_FOLDER}/fake_basketball_players.txt")

if READ_FROM_FILE:
    # Reuse the fake names stored by an earlier run; blank lines are ignored.
    with open(fake_basketball_players_file, "r") as f:
        non_basketball_players = [line.rstrip() for line in f if line.rstrip()]
    print(
        f"Loaded {len(non_basketball_players)} fake basketball players from {fake_basketball_players_file}"
    )

else:
    # basketball_players also holds space-prefixed variants of each name
    # (e.g. "   Alex English"). Splitting on a single " " would turn those
    # into empty "first names", so normalise with whitespace-agnostic split()
    # before extracting name parts or building the set of real names.
    real_names_set = {" ".join(name.split()) for name in basketball_players}

    # Sorted so that a seeded `random` gives reproducible samples.
    first_names = sorted({name.split()[0] for name in real_names_set})
    last_names = sorted({name.split()[1] for name in real_names_set})

    # Create all possible first/last name combinations
    all_combinations = [" ".join(p) for p in itertools.product(first_names, last_names)]

    # Filter out the names that actually exist
    fake_names = [name for name in all_combinations if name not in real_names_set]
    # Pick at most FAKE_NAMES_MAX random fake names
    fake_names = random.sample(fake_names, min(FAKE_NAMES_MAX, len(fake_names)))
    # For every fake name, add NUM_PREFIXES randomly chosen whitespace prefixes
    prefixed_fake_names = [
        prefix + name
        for name in fake_names
        for prefix in random.sample(prefix_candidates, NUM_PREFIXES)
    ]

    # Combine original fake names with prefixed ones
    fake_names = fake_names + prefixed_fake_names

    print(f"Generated {len(fake_names)} fake names to test.")
    print(f"{fake_names[-20:]}")

    non_basketball_fake_players = []
    non_basketball_completions = []

    # Keep only the fake names the model does NOT complete with "basketball"
    for i, name in enumerate(fake_names):
        if not is_basketball_player(name):
            non_basketball_fake_players.append(name)
        # Progress marker once per 100 names
        if i % 100 == 9:
            print(".", end="")

    random.shuffle(non_basketball_fake_players)

    print(
        f"\nFound {len(non_basketball_fake_players)} fake names that are not associated with basketball."
    )

    with open(fake_basketball_players_file, "w") as f:
        f.write("\n".join(non_basketball_fake_players))
    print(
        f"Stored {len(non_basketball_fake_players)} fake basketball players in {fake_basketball_players_file}."
    )

Loaded 129893 basketball players from data-2.8b/fake_basketball_players.txt


In [8]:
# Bare expression: the notebook displays the repr (module tree) of the loaded model.
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
    

In [9]:
# Switching to batched processing. Different token lengths mess up the values
# via padding, so prompts of the same length are processed together. This cell
# prints a batched and a single-prompt run of the same prompt for comparison.
# NOTE: the list is named `prompts` (not `input`) so the builtin input() is not
# shadowed for the rest of the notebook session.
prompts = ["Prompt: Michael Jordan", "Prompt: Steve Curry"]

# Batched run: cache only the residual-stream outputs plus block 0's MLP input.
logits, v = model.run_with_cache(
    prompts,
    remove_batch_dim=False,
    names_filter=lambda layer: "hook_resid_post" in layer
    or "blocks.0.hook_mlp_in" in layer,  # Set to specific layers if needed
    incl_bwd=False,
    return_cache_object=False,
)

# Block-0 residual output for the first prompt of the batch.
mj = v["blocks.0.hook_resid_post"][0]

# Single-prompt run of the same first prompt; no names_filter, so every hook is cached.
logits2, v2 = model.run_with_cache(
    prompts[0],
    remove_batch_dim=True,
    incl_bwd=False,
    return_cache_object=False,
)

print(mj)
# Distinct loop names keep the batched cache dict `v` intact after the loop.
for hook_name, activation in v2.items():
    if "blocks.0" in hook_name:
        print(f"{hook_name}: {activation.shape}")

# Last-position activations of block 0; a + b + c is printed next to
# hook_resid_post so the two can be compared by eye.
a = v2["blocks.0.hook_resid_pre"][-1]
b = v2["blocks.0.hook_attn_out"][-1]
c = v2["blocks.0.hook_mlp_out"][-1]
print(f"blocks.0.hook_resid_pre: {v2['blocks.0.hook_resid_pre'][-1]}")
print(f"blocks.0.hook_attn_out: {v2['blocks.0.hook_attn_out'][-1]}")
print(f"blocks.0.hook_mlp_out: {v2['blocks.0.hook_mlp_out'][-1]}")
print(f"blocks.0.hook_resid_post: {v2['blocks.0.hook_resid_post'][-1]}")
print(f"{a+b+c}")

tensor([[ 0.6602, -0.3262,  0.0781,  ..., -0.3164, -0.6797,  1.2266],
        [ 0.4297, -0.0430, -0.3379,  ..., -0.9883, -0.7070, -0.1836],
        [ 0.1992,  0.8477,  0.1543,  ...,  0.6016,  0.2324, -0.2295],
        [-0.2891,  0.1484, -0.1592,  ..., -0.0361,  0.1963,  0.0898],
        [-0.0352,  0.3164, -0.1152,  ...,  0.3066,  0.3203,  0.3340],
        [ 0.2930, -0.0649,  0.0771,  ..., -0.1904,  0.7891,  0.5625]],
       device='cuda:0', dtype=torch.bfloat16)
blocks.0.hook_resid_pre: torch.Size([6, 2560])
blocks.0.ln1.hook_scale: torch.Size([6, 1])
blocks.0.ln1.hook_normalized: torch.Size([6, 2560])
blocks.0.attn.hook_q: torch.Size([6, 32, 80])
blocks.0.attn.hook_k: torch.Size([6, 32, 80])
blocks.0.attn.hook_v: torch.Size([6, 32, 80])
blocks.0.attn.hook_rot_q: torch.Size([6, 32, 80])
blocks.0.attn.hook_rot_k: torch.Size([6, 32, 80])
blocks.0.attn.hook_attn_scores: torch.Size([32, 6, 6])
blocks.0.attn.hook_pattern: torch.Size([32, 6, 6])
blocks.0.attn.hook_z: torch.Size([6, 32, 80])


In [None]:
# Drop a previously loaded model (if any) and release its memory
try:
    del model
except NameError:
    pass

torch.cuda.empty_cache()
gc.collect()

# Reload the weights in half precision, without TransformerLens' weight processing
model = transformer_lens.HookedTransformer.from_pretrained_no_processing(
    MODEL_NAME,
    device="cuda",
    dtype=torch.float16,
)

def sort_into_batches(names):
    """Group names by the token length of their ``Fact: {name}`` prompt.

    Returns a dict mapping token count -> list of names (in input order), so
    that each group can be run as a batch without padding.
    """
    per_length = {}
    for name in names:
        token_count = len(model.to_tokens(f"Fact: {name}")[0])
        per_length.setdefault(token_count, []).append(name)
    return per_length


# Function to process one name with memory optimization
def process_with_cache(names, caches):
    """Run one batch of ``Fact: {name}`` prompts and store last-token activations.

    For every name, ``caches[name]`` is (re)created as a dict holding
    ``blocks.0.hook_resid_mid`` (computed as resid_pre + attn_out of block 0)
    and every ``...hook_resid_post`` activation, each taken at the last
    position and moved to the CPU.

    Parameters:
    - names: names to process together; callers group them by token length so
      that index -1 is the real last token of every prompt (no padding).
    - caches: dict that is updated in place, keyed by name.

    Returns the same ``caches`` dict.
    """
    batch = [f"Fact: {name}" for name in names]

    def wanted(hook_name):
        # Every clause must be a membership test: a bare string literal here is
        # always truthy and would make the filter cache every hook of the model.
        return (
            "hook_resid_post" in hook_name
            or "blocks.0.hook_resid_pre" in hook_name
            or "blocks.0.hook_attn_out" in hook_name
        )

    _, cache = model.run_with_cache(
        batch,
        remove_batch_dim=False,
        names_filter=wanted,
        incl_bwd=False,
        return_cache_object=False,
    )

    # The mid-layer activation is rebuilt as pre-activation + attention output
    # because no resid_mid hook is available for this model.
    resid_pre = cache["blocks.0.hook_resid_pre"]
    attn_out = cache["blocks.0.hook_attn_out"]
    for i, name in enumerate(names):
        resid_mid = resid_pre[i][-1] + attn_out[i][-1]
        # Stored on the CPU like the resid_post entries below, so the result
        # does not keep GPU memory alive.
        caches[name] = {"blocks.0.hook_resid_mid": resid_mid.detach().cpu()}

    for layer, v in cache.items():
        if "hook_resid_post" not in layer:
            continue
        for i, name in enumerate(names):
            # Move the last token's activation to CPU and detach
            caches[name][layer] = v[i][-1].detach().cpu()

    return caches

def compute_caches(names):
    """Collect last-token activations for all names and return them keyed by name.

    Names are grouped by prompt token length and each group is processed in
    slices of at most MAX_BATCH_SIZE. A slice that raises is reported and
    skipped; GPU cache and garbage are cleared after every slice.
    """
    grouped = sort_into_batches(names)
    caches = {}
    with torch.inference_mode():
        for length, same_length_names in grouped.items():
            print(f"Processing batch of length {length} with {len(same_length_names)} names...")
            for start in range(0, len(same_length_names), MAX_BATCH_SIZE):
                chunk = same_length_names[start : start + MAX_BATCH_SIZE]
                try:
                    process_with_cache(chunk, caches)
                except Exception as e:
                    # Best effort: report the failing slice and carry on with the rest.
                    print(f"Error processing batch for length {length}: {e}")

                # Release memory before the next slice
                torch.cuda.empty_cache()
                gc.collect()

    return caches

# Read the stored real and fake player names (blank lines dropped)
with open(f"{DATA_FOLDER}/basketball_players.txt", "r") as f:
    player_names = [entry for entry in map(str.rstrip, f) if entry]
with open(f"{DATA_FOLDER}/fake_basketball_players.txt", "r") as f:
    fake_names = [entry for entry in map(str.rstrip, f) if entry]

# Cache activations for at most LIMIT names of each kind
player_caches = compute_caches(player_names[:LIMIT])
print(f"Processed {len(player_caches)} player caches.")
fake_caches = compute_caches(fake_names[:LIMIT])
print(f"Processed {len(fake_caches)} fake player caches.")

Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer
Processing batch of length 205 with 67 names...
Processing batch of length 135 with 95 names...
Processing batch of length 148 with 107 names...
Processing batch of length 150 with 116 names...
Processing batch of length 194 with 113 names...
Processing batch of length 195 with 77 names...
Processing batch of length 131 with 120 names...
Processing batch of length 146 with 83 names...
Processing batch of length 125 with 86 names...
Processing batch of length 147 with 95 names...
Processing batch of length 166 with 119 names...
Processing batch of length 156 with 83 names...
Processing batch of length 50 with 117 names...
Processing batch of length 74 with 90 names...
Processing batch of length 84 with 83 names...
Processing batch of length 160 with 76 names...
Processing batch of length 141 with 116 names...
Processing batch of length 52 with 90 names...
Processing batch of length 96 with 122 names...
Processing batc

In [None]:
# Peek at the first cached activation of the first player (if there is one).
name = next(iter(player_caches), None)
if name is not None:
    key = next(iter(player_caches[name]), None)
    if key is not None:
        player_caches[name][key] = player_caches[name][key].detach().cpu()
        print(f"{name.strip()}.{key}: {player_caches[name][key].shape} {player_caches[name][key]}")

In [None]:
def get_layer_name(layer):
    """Map a layer index to its cache key.

    Index -1 denotes the block-0 mid-residual entry; any other index maps to
    that block's ``hook_resid_post``.
    """
    if layer == -1:
        return "blocks.0.hook_resid_mid"
    return f"blocks.{layer}.hook_resid_post"


def make_point_thick(vector, extra_vectors=100, thickness=0.2):
    """Return ``vector`` followed by shifted copies of it.

    For each of the first ``min(extra_vectors, len(vector))`` dimensions one
    copy is added in which that dimension is increased by ``thickness``
    (only the positive offset is generated). The input itself is the first
    element of the result and is not modified.
    """
    thickened = [vector]
    shifted_dims = min(extra_vectors, len(vector))
    for axis in range(shifted_dims):
        shifted = vector.copy()
        shifted[axis] += thickness
        thickened.append(shifted)
    return thickened


def get_point_set(
    caches, hook_name="blocks.0.hook_resid_post", extra_vectors=0, thickness=0.2
):
    """Collect the ``hook_name`` activation of every cached name as numpy vectors.

    Each activation is detached, moved to the CPU, converted to numpy and passed
    through make_point_thick with the given ``extra_vectors`` / ``thickness``.
    Names whose cache lacks ``hook_name`` are reported and skipped.
    Returns a flat list of vectors in cache iteration order.
    """
    print(f"Collecting vectors from layer {hook_name}...")
    collected = []
    for name, entry in caches.items():
        try:
            as_array = entry[hook_name].detach().cpu().numpy()
            collected.extend(
                make_point_thick(
                    as_array, extra_vectors=extra_vectors, thickness=thickness
                )
            )
        except KeyError:
            print(f"Skipping {name}: Missing '{hook_name}' in cache")
        except IndexError:
            print(f"Skipping {name}: No last element in '{hook_name}'")
    return collected


# Block-0 mid-residual vectors for real players and for fake names
player_layer_0, fake_layer_0 = (
    get_point_set(cache_by_name, "blocks.0.hook_resid_mid", extra_vectors=0)
    for cache_by_name in (player_caches, fake_caches)
)

In [None]:
def count_linear_separation(set_a, set_b):
    """Fit a linear SVM on set_a ("Player") vs. set_b ("Fake") and count training errors.

    The classifier is evaluated on the same points it was trained on, so the
    counts measure how far the two sets are from being linearly separated.
    Prints a summary and returns ``(fake_misclassified, player_misclassified)``.
    """
    features = np.vstack([set_a, set_b])
    labels = np.array(["Player"] * len(set_a) + ["Fake"] * len(set_b))

    classifier = LinearSVC(random_state=42, max_iter=10000)
    classifier.fit(features, labels)
    predicted = classifier.predict(features)

    wrong = predicted != labels
    misclassified = np.sum(wrong)
    total = len(labels)
    print(
        f"Total misclassified points: {misclassified} out of {total} ({misclassified/total*100:.2f}%)"
    )

    # Split the errors by true class
    fake_misclassified = np.sum(wrong & (predicted == "Player") & (labels == "Fake"))
    player_misclassified = np.sum(wrong & (predicted == "Fake") & (labels == "Player"))

    print(f"Fake points misclassified as players: {fake_misclassified}")
    print(f"Player points misclassified as fake: {player_misclassified}")
    return fake_misclassified, player_misclassified


# Linear separability of the layer-0 point sets; the function prints its own
# summary and the returned (fake, player) error counts are printed as well.
print(count_linear_separation(player_layer_0, fake_layer_0))

In [None]:
# Per-layer analysis: for every cached layer, count how many points a linear
# classifier misclassifies, and draw a t-SNE scatter plot for a subset of layers.
# Initialize lists to store results (one entry per processed layer, in loop order)
layer_results = []
misclassification_rates = []
player_misclassified_counts = []
fake_misclassified_counts = []
accuracy_scores = []

# Get the number of layers (first dimension of W_K)
num_layers = model.W_K.shape[0]
print(f"Analyzing linear separability across {num_layers} layers...")

# A t-SNE panel is drawn for every plot_every-th layer (plus the first and last one)
plot_every = 4

# Grid with 3 panels per row; presumably sized to leave room for the extra
# first/last panels -- TODO confirm for other layer counts
n_rows = (num_layers // plot_every + 2) // 3 + 1
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=3,
    figsize=(12, 4 * n_rows),
)
# Index of the next unused panel in axs.flat
ax_idx = 0

# Process each layer; -1 is mapped to blocks.0.hook_resid_mid by get_layer_name
for layer in range(-1, num_layers):
    # Get vectors from this layer using existing functions
    layer_name = get_layer_name(layer)
    player_layer_vectors = get_point_set(
        player_caches, layer_name, extra_vectors=0
    )
    fake_layer_vectors = get_point_set(fake_caches, layer_name, extra_vectors=0)

    # Check linear separability using existing function
    print(f"\nLayer {layer}:")
    fake_mis, player_mis = count_linear_separation(player_layer_vectors, fake_layer_vectors)

    # Calculate metrics
    total_points = len(player_layer_vectors) + len(fake_layer_vectors)
    total_mis = fake_mis + player_mis
    accuracy = 1 - (total_mis / total_points)

    # Store results
    layer_results.append({
        "layer": layer,
        "player_count": len(player_layer_vectors),
        "fake_count": len(fake_layer_vectors),
        "fake_misclassified": fake_mis,
        "player_misclassified": player_mis,
        "total_misclassified": total_mis,
        "accuracy": accuracy
    })

    misclassification_rates.append(total_mis / total_points)
    player_misclassified_counts.append(player_mis)
    fake_misclassified_counts.append(fake_mis)
    accuracy_scores.append(accuracy)

    # Generate TSNE visualization for selected layers (to avoid too many plots)
    if layer % plot_every == 0 or layer == -1 or layer == num_layers - 1:
        # Combine vectors for visualization
        all_vectors = np.vstack([player_layer_vectors, fake_layer_vectors])
        # labels = ["player"] * len(player_layer_vectors) + ["fake"] * len(fake_layer_vectors)
        # def describe_prefix(name):
        #     spaces = len(name) - len(name.lstrip())
        #     if spaces == 0:
        #         return "no prefix"
        #     return f"<{(spaces//500 + 1) * 500} spaces"
        # Points are coloured by the first non-whitespace character of the name.
        def describe_prefix(name):
            first_letter = name.lstrip()[0]
            return f"starts with {first_letter}"
        # NOTE(review): labels follow cache key order; this lines up with
        # all_vectors only if get_point_set skipped no name -- confirm.
        labels = [describe_prefix(name) for name in player_caches.keys()] + [describe_prefix(name) for name in fake_caches.keys()]
        # Apply TSNE (perplexity capped at a fifth of the number of points)
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(all_vectors)//5))
        embeddings = tsne.fit_transform(all_vectors)

        # Create plot
        scatter = sns.scatterplot(x=embeddings[:, 0], y=embeddings[:, 1], 
                hue=labels, palette="Set1", ax = axs.flat[ax_idx])
        # Keep the legend only on the layer-20 panel
        if layer != 20:
            scatter.legend_.remove()
        axs.flat[ax_idx].set_title(f"Layer {layer}")
        ax_idx += 1

plt.tight_layout()
plt.show()

# Line plots over layers: misclassification rate first, then accuracy.
# Both share the same layout and a dashed reference line at 0.5.
for series, colour, y_label, heading in (
    (
        misclassification_rates,
        "red",
        "Misclassification Rate",
        "Linear Classifier Misclassification Rate by Layer",
    ),
    (
        accuracy_scores,
        "green",
        "Classification Accuracy",
        "Linear Classifier Accuracy by Layer",
    ),
):
    plt.figure(figsize=(12, 6))
    plt.plot(range(-1, num_layers), series, marker="o", color=colour)
    plt.axhline(y=0.5, color="gray", linestyle="--", alpha=0.7)
    plt.xlabel("Layer")
    plt.ylabel(y_label)
    plt.title(heading)
    plt.grid(True, alpha=0.3)
    plt.show()

# Grouped bars: per-layer misclassification counts for each true class
plt.figure(figsize=(12, 6))
x = np.arange(-1, num_layers)
width = 0.35
for offset, counts, bar_label in (
    (-width / 2, player_misclassified_counts, "Players Misclassified as Fake"),
    (width / 2, fake_misclassified_counts, "Fake Misclassified as Players"),
):
    plt.bar(x + offset, counts, width, label=bar_label)
plt.xlabel("Layer")
plt.ylabel("Number of Misclassifications")
plt.title("Misclassification Counts by Category")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Find best layer for linear separation.
# accuracy_scores is filled in loop order starting at layer -1, so the argmax
# position is a list index rather than a layer number; the layer number is
# looked up in layer_results, which was appended to in the same order.
best_index = int(np.argmax(accuracy_scores))
best_layer = layer_results[best_index]["layer"]
print(f"\nBest layer for linear separation: Layer {best_layer}")
print(f"Accuracy: {accuracy_scores[best_index]*100:.2f}%")
print(f"Misclassification rate: {misclassification_rates[best_index]*100:.2f}%")

In [None]:
def print_distance_stats(set_a, set_b, stats_lists=None, equal=False):
    """
    Print min/mean/median/max of all pairwise Euclidean distances between two sets.

    Parameters:
    - set_a: First set of vectors
    - set_b: Second set of vectors
    - stats_lists: Optional dictionary of lists; for each of the keys
                  'min', 'max', 'mean', 'median' that is present, the
                  corresponding statistic is appended to its list
    - equal: When True, pairs at the same position (i == j) are skipped,
             e.g. when a set is compared with itself
    """
    distances = sorted(
        np.linalg.norm(a - b)
        for i, a in enumerate(set_a)
        for j, b in enumerate(set_b)
        if not (equal and i == j)
    )
    print(
        f"Points in A: {len(set_a)}, in B: {len(set_b)}, total distances: {len(distances)}"
    )

    # The list is sorted, so the extremes sit at its ends.
    smallest, largest = distances[0], distances[-1]
    mean_distance = np.mean(distances)
    median_distance = np.median(distances)
    print(
        f"Min: {smallest:.4f}, Mean: {mean_distance:.4f}, Median: {median_distance:.4f}, Max: {largest:.4f}"
    )

    if stats_lists is None:
        return

    summary = {
        "min": smallest,
        "max": largest,
        "mean": mean_distance,
        "median": median_distance,
    }
    for key, value in summary.items():
        if key in stats_lists:
            stats_lists[key].append(value)


# Distance statistics between the first 1000 player and the first 1000 fake layer-0 vectors.
print_distance_stats(player_layer_0[:1000], fake_layer_0[:1000])

In [None]:
from scipy.spatial.distance import cdist

def get_min_distance(points_a, points_b):
    """Return ``(distance, i, j)`` for the closest pair between two point sets.

    ``i`` indexes into ``points_a`` and ``j`` into ``points_b``; the distance
    is Euclidean.
    """
    pairwise = cdist(points_a, points_b, metric="euclidean")
    # argmin works on the flattened matrix; convert it back to (row, column).
    i, j = np.unravel_index(np.argmin(pairwise), pairwise.shape)
    return pairwise[i, j], i, j

# print(get_min_distance(player_layer_0, fake_layer_0))
# 16s: (np.float64(1.7858378876167593), np.int64(1760), np.int64(137))

In [None]:
from convex_point_cover.algorithms.fast_kruskal import fast_kruskal

from joblib import Parallel, delayed

# Layer indices to analyse: -1 (mapped to blocks.0.hook_resid_mid by
# get_layer_name) followed by 0 .. num_layers - 1.
num_layers = model.W_K.shape[0]
layers = range(-1, num_layers)


def process_layer(layer, player_vectors, fake_vectors):
    """Cluster one layer's vectors in both directions and return the cluster counts.

    fast_kruskal is run once with players as the first set and fakes as the
    second, and once with the roles swapped. ``delta`` is derived from the
    smallest player-to-fake distance minus 0.11. Returns a dict with the layer
    index, both input sizes and both cluster counts.
    """
    print(f"Processing layer {layer}...")
    min_distance = get_min_distance(player_vectors, fake_vectors)[0]
    print(f"Layer {layer}: min distance={min_distance}")

    # Settings shared by both clustering runs
    kruskal_options = {
        "epsilon": 0.1,
        "delta": min_distance - 0.11,
        "debug": False,
    }
    positive_clusters = fast_kruskal(player_vectors, fake_vectors, **kruskal_options)
    negative_clusters = fast_kruskal(fake_vectors, player_vectors, **kruskal_options)

    return dict(
        layer=layer,
        player_count=len(player_vectors),
        fake_count=len(fake_vectors),
        positive_clusters=len(positive_clusters),
        negative_clusters=len(negative_clusters),
    )

# Run process_layer for every layer with joblib, using NUM_JOBS workers
print("Processing layers in parallel...")
layer_jobs = (
    delayed(process_layer)(
        layer,
        *(
            get_point_set(
                cache_by_name,
                get_layer_name(layer),
                extra_vectors=0,
                thickness=0.2,
            )
            for cache_by_name in (player_caches, fake_caches)
        ),
    )
    for layer in layers
)
results = Parallel(n_jobs=NUM_JOBS, verbose=10)(layer_jobs)

# Report the results in layer order and collect the counts for plotting:
# cluster_sizes[0] holds positive cluster counts, cluster_sizes[1] negative ones.
cluster_sizes = [[], []]
positive_counts, negative_counts = cluster_sizes
for entry in sorted(results, key=lambda r: r["layer"]):
    positive_counts.append(entry["positive_clusters"])
    negative_counts.append(entry["negative_clusters"])

    print(
        f"Layer {entry['layer']}: {entry['player_count']} player vectors collected, {entry['fake_count']} fake vectors collected."
    )
    print(f"{entry['positive_clusters']} clusters found.")
    print(f"{entry['negative_clusters']} negative clusters found.")

# Plot the number of clusters found across layers
plt.figure(figsize=(12, 6))
for counts, marker_style, series_label in zip(
    cluster_sizes,
    ("o", "s"),
    ("Positive Clusters", "Negative Clusters"),
):
    plt.plot(layers, counts, marker=marker_style, label=series_label)
plt.xlabel("Layer")
plt.ylabel("Number of Clusters")
plt.title("Number of Clusters Found Across Layers")
plt.legend()
plt.grid(True, alpha=0.3)
