In [None]:
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm
import warnings
import scipy
import pickle
from scipy.linalg import eigh, sqrtm, qr
import time

# Silence every library warning so the notebook output stays readable.
warnings.filterwarnings("ignore")
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


def calculate_bert_params(layers, hidden_size):
    """Estimate the parameter count of a BERT-style encoder, in millions.

    The estimate is the embedding tables plus ``layers`` identical encoder
    layers; it uses a fixed vocabulary of 30522 tokens, 512 positions and
    2 segment types, and a feed-forward width of ``4 * hidden_size``.

    Args:
        layers: Number of encoder layers.
        hidden_size: Hidden width of the model.

    Returns:
        The estimated number of parameters divided by one million.
    """
    # Token, position and segment tables all have ``hidden_size`` columns.
    embedding_rows = 30522 + 512 + 2
    embedding_params = hidden_size * embedding_rows

    feed_forward_width = 4 * hidden_size
    attention_weights = 4 * hidden_size**2
    feed_forward_weights = 2 * hidden_size * feed_forward_width
    # The formula adds a further 2 * hidden_size per layer.
    per_layer_extra = 2 * hidden_size
    per_layer = attention_weights + feed_forward_weights + per_layer_extra

    return (embedding_params + layers * per_layer) / 1_000_000


# Registry of encoders to compare: short name -> (checkpoint id, size in
# millions of parameters). Every layer/width combination follows the same
# checkpoint naming scheme, with hidden_size // 64 attention heads.
BERT_MODELS = {
    f"bert-{layers}-{hidden}": (
        f"google/bert_uncased_L-{layers}_H-{hidden}_A-{hidden // 64}",
        calculate_bert_params(layers, hidden),
    )
    for layers in (2, 4, 6, 8, 10, 12)
    for hidden in (128, 256, 512, 768)
}
# The 12-layer, 768-wide entry uses the original checkpoint and a fixed size.
BERT_MODELS["bert-12-768"] = ("bert-base-uncased", 110.0)

# Datasets to analyse: short name -> (dataset id, split expression).
DATASETS = dict(
    amazon_polarity=("amazon_polarity", "test[:1000]"),
    yelp_review=("yelp_review_full", "test[:1000]"),
    imdb=("imdb", "test[:1000]"),
    ag_news=("ag_news", "test[:1000]"),
)


def stiefel_objective(X, Lambda1, Lambda2, v1, v2):
    """Evaluate ``tr(Lambda1 X Lambda2 X^T) + 2 v1^T X v2`` at ``X``.

    Args:
        X: Candidate matrix.
        Lambda1: Left weighting matrix.
        Lambda2: Right weighting matrix.
        v1: Left vector of the linear term.
        v2: Right vector of the linear term.

    Returns:
        The scalar objective value.
    """
    quadratic_part = np.trace(Lambda1 @ X @ Lambda2 @ X.T)
    linear_part = 2 * v1.T @ X @ v2
    return quadratic_part + linear_part


def stiefel_gradient(X, Lambda1, Lambda2, v1, v2):
    """Return ``2 Lambda1 X Lambda2 + 2 v1 v2^T``.

    This is the Euclidean gradient of the objective computed by
    ``stiefel_objective`` when ``Lambda1`` and ``Lambda2`` are symmetric.
    """
    quadratic_grad = 2 * Lambda1 @ X @ Lambda2
    linear_grad = 2 * np.outer(v1, v2)
    return quadratic_grad + linear_grad


def stiefel_riemannian_gradient(X, euclidean_grad):
    """Project a Euclidean gradient onto the tangent space at ``X``.

    Computes ``G - X sym(X^T G)`` where ``sym(A) = (A + A^T) / 2``.

    Args:
        X: Current point.
        euclidean_grad: Euclidean gradient ``G`` at ``X``.

    Returns:
        The projected gradient, with the same shape as ``euclidean_grad``.
    """
    cross = X.T @ euclidean_grad
    symmetric_part = 0.5 * (cross + cross.T)
    return euclidean_grad - X @ symmetric_part


def stiefel_retraction(X, direction, step_size):
    """Move from ``X`` along ``direction`` and re-orthonormalise via QR.

    Args:
        X: Current point.
        direction: Update direction, same shape as ``X``.
        step_size: Scalar step length.

    Returns:
        The Q factor of ``X + step_size * direction``, with column signs
        chosen so that the diagonal of the matching R factor is non-negative.
    """
    moved = X + step_size * direction

    n_rows, n_cols = moved.shape[0], moved.shape[1]
    # Non-square inputs use the reduced factorisation.
    factors = qr(moved) if n_rows == n_cols else qr(moved, mode="economic")
    orthonormal, upper = factors

    # Fix the sign ambiguity of QR; a zero diagonal entry counts as positive.
    column_signs = np.sign(np.diag(upper))
    column_signs[column_signs == 0] = 1

    return orthonormal @ np.diag(column_signs)


def optimize_stiefel(
    Lambda1,
    Lambda2,
    v1,
    v2,
    d1,
    d2,
    max_iter=50,
    tol=1e-2,
    step_size=0.01,
    verbose=False,
):
    """Maximise ``stiefel_objective`` by gradient ascent with backtracking.

    Starts from ``np.eye(d1, d2)`` and repeatedly steps along the projected
    gradient, retracting back with ``stiefel_retraction``. A step is only
    accepted when it does not decrease the objective.

    Args:
        Lambda1, Lambda2, v1, v2: Objective data, passed through to
            ``stiefel_objective`` and ``stiefel_gradient``.
        d1: Number of rows of the iterate.
        d2: Number of columns of the iterate.
        max_iter: Maximum number of outer iterations.
        tol: Stop once the projected-gradient norm falls below this value.
        step_size: Initial step length; adapted during the run.
        verbose: Print a message when the tolerance is reached.

    Returns:
        ``(X, objective_history)``: the final iterate and the objective value
        recorded at the start of every iteration.
    """
    # Re-seeds NumPy's global RNG; nothing in this function draws from it.
    np.random.seed(42)
    point = np.eye(d1, d2)
    history = []

    def evaluate(candidate):
        return stiefel_objective(candidate, Lambda1, Lambda2, v1, v2)

    for iteration in range(max_iter):
        current_value = evaluate(point)
        history.append(current_value)

        ascent = stiefel_riemannian_gradient(
            point, stiefel_gradient(point, Lambda1, Lambda2, v1, v2)
        )
        if np.linalg.norm(ascent) < tol:
            if verbose:
                print(f"Converged after {iteration} iterations")
            break

        trial = stiefel_retraction(point, ascent, step_size)
        trial_value = evaluate(trial)

        # Halve the step (at most 10 times) while the trial is worse.
        halvings = 0
        while trial_value < current_value and step_size > 1e-12 and halvings < 10:
            step_size *= 0.5
            trial = stiefel_retraction(point, ascent, step_size)
            trial_value = evaluate(trial)
            halvings += 1

        accepted = trial_value >= current_value
        if accepted:
            point = trial
            # Grow the step slightly after a success, capped at 0.1.
            step_size = min(step_size * 1.05, 0.1)
        else:
            step_size *= 0.5

    return point, history


def compute_igw_distance(m1, m2, Sigma1, Sigma2, verbose=False):
    """Estimate the IGW distance between two Gaussians numerically.

    Both covariances are diagonalised, the means are expressed in the
    eigenbases, and the coupling term ``gamma`` is maximised with
    ``optimize_stiefel``. The result is
    ``sqrt(max(0, xi - 2 * gamma))`` where ``xi`` collects the terms that do
    not depend on the coupling.

    The inputs are never modified: any regularisation is applied to a copy.

    Args:
        m1: Mean of the first Gaussian.
        m2: Mean of the second Gaussian.
        Sigma1: Covariance of the first Gaussian (symmetric).
        Sigma2: Covariance of the second Gaussian (symmetric).
        verbose: Forwarded to ``optimize_stiefel``.

    Returns:
        The estimated distance as a non-negative float.
    """
    m1 = np.atleast_1d(np.asarray(m1, dtype=float))
    m2 = np.atleast_1d(np.asarray(m2, dtype=float))
    Sigma1 = np.atleast_2d(np.asarray(Sigma1, dtype=float))
    Sigma2 = np.atleast_2d(np.asarray(Sigma2, dtype=float))

    # Put the higher-dimensional Gaussian first so the iterate is tall.
    if len(m1) < len(m2):
        m1, m2 = m2, m1
        Sigma1, Sigma2 = Sigma2, Sigma1

    d1, d2 = len(m1), len(m2)

    # eigvalsh is the symmetric solver and always returns real eigenvalues.
    # Regularise out of place so the caller's covariance is left untouched.
    if not np.all(np.linalg.eigvalsh(Sigma1) > 1e-12):
        Sigma1 = Sigma1 + np.eye(d1) * 1e-6
    if not np.all(np.linalg.eigvalsh(Sigma2) > 1e-12):
        Sigma2 = Sigma2 + np.eye(d2) * 1e-6

    Lambda1_vals, Q1 = eigh(Sigma1)
    Lambda2_vals, Q2 = eigh(Sigma2)

    # Clamp so the square roots below stay real.
    Lambda1_vals = np.maximum(Lambda1_vals, 1e-12)
    Lambda2_vals = np.maximum(Lambda2_vals, 1e-12)

    Lambda1 = np.diag(Lambda1_vals)
    Lambda2 = np.diag(Lambda2_vals)

    # Means expressed in the eigenbases of their covariances.
    m1_tilde = Q1.T @ m1
    m2_tilde = Q2.T @ m2

    v1 = np.sqrt(Lambda1_vals) * m1_tilde
    v2 = np.sqrt(Lambda2_vals) * m2_tilde

    X_opt, _ = optimize_stiefel(Lambda1, Lambda2, v1, v2, d1, d2, verbose=verbose)
    gamma_star = stiefel_objective(X_opt, Lambda1, Lambda2, v1, v2)

    xi = (
        np.sum(Lambda1_vals**2)
        + np.sum(Lambda2_vals**2)
        + 2 * np.sum(Lambda1_vals * m1_tilde**2)
        + 2 * np.sum(Lambda2_vals * m2_tilde**2)
        + (np.linalg.norm(m1_tilde) ** 2 - np.linalg.norm(m2_tilde) ** 2) ** 2
    )

    igw_distance_squared = xi - 2 * gamma_star
    # Rounding can push the squared distance slightly below zero.
    return np.sqrt(max(0, igw_distance_squared))


def load_model_and_tokenizer(model_name):
    """Load a pretrained encoder and its tokenizer.

    The model is moved to the module-level ``device`` and switched to
    evaluation mode.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)
    encoder = encoder.to(device)
    encoder.eval()
    return encoder, tokenizer


def get_embeddings(texts, model, tokenizer, batch_size=32):
    """Embed ``texts`` with ``model`` and return one vector per text.

    Texts are tokenised in batches (padded per batch, truncated to 512
    tokens) and the last hidden state at sequence position 0 is kept.

    Args:
        texts: Sliceable sequence of input strings.
        model: Encoder whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Number of texts per forward pass.

    Returns:
        A NumPy array stacking the per-batch results row-wise.
    """
    cls_batches = []
    batch_starts = range(0, len(texts), batch_size)

    with torch.no_grad():
        for start in tqdm(batch_starts, desc="Embedding", leave=False):
            encoded = tokenizer(
                texts[start : start + batch_size],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            encoded = encoded.to(device)

            hidden = model(**encoded).last_hidden_state
            # Position 0 of every sequence is kept as its embedding.
            cls_batches.append(hidden[:, 0, :].cpu().numpy())

    return np.vstack(cls_batches)


def project_to_common_dimension(embeddings, target_dim=128):
    """Bring embeddings to exactly ``target_dim`` columns.

    * Already ``target_dim`` wide: returned unchanged (same object).
    * Narrower: zero columns are appended on the right.
    * Wider: reduced with PCA to ``target_dim`` components.

    scikit-learn is imported only on the PCA path, so padding and the no-op
    case work without it.

    Args:
        embeddings: 2-D array with one row per sample.
        target_dim: Desired number of columns.

    Returns:
        A 2-D array with ``target_dim`` columns.
    """
    current_dim = embeddings.shape[1]

    if current_dim == target_dim:
        return embeddings

    if current_dim < target_dim:
        padding = np.zeros((embeddings.shape[0], target_dim - current_dim))
        return np.hstack([embeddings, padding])

    # Deferred import: only the dimensionality-reduction path needs sklearn.
    from sklearn.decomposition import PCA

    return PCA(n_components=target_dim).fit_transform(embeddings)


def fit_gaussian(embeddings, target_dim=128):
    """Fit a Gaussian to embeddings after bringing them to ``target_dim``.

    Args:
        embeddings: 2-D array with one row per sample.
        target_dim: Dimension passed to ``project_to_common_dimension``.

    Returns:
        ``(mean, cov)``: the sample mean and the sample covariance with
        ``1e-6`` added to the diagonal.
    """
    projected = project_to_common_dimension(embeddings, target_dim)

    ridge = np.eye(projected.shape[1]) * 1e-6
    center = np.mean(projected, axis=0)
    spread = np.cov(projected.T) + ridge
    return center, spread


def compute_igw_bounds(mean1, cov1, mean2, cov2):
    """Return closed-form lower and upper bounds on the IGW distance.

    With ``xi`` the coupling-independent part of the squared distance and
    ``gamma`` the coupling term, the squared distance is ``xi - 2 * gamma``
    (the same decomposition ``compute_igw_distance`` uses). ``gamma`` is
    bracketed by

        S - 2 * N  <=  gamma  <=  S + 2 * N

    where ``S`` is the sum of products of the descending-sorted eigenvalues
    and ``N = ||Lambda1^(1/2) m1~|| * ||Lambda2^(1/2) m2~||``.

    ``xi`` must not itself contain ``-2 * S``: that term enters only through
    ``gamma``. Subtracting it in both places counts the eigenvalue alignment
    twice and yields an "upper bound" below the true distance.

    The inputs are never modified: any regularisation is applied to a copy.

    Args:
        mean1: Mean of the first Gaussian.
        cov1: Covariance of the first Gaussian (symmetric).
        mean2: Mean of the second Gaussian.
        cov2: Covariance of the second Gaussian (symmetric).

    Returns:
        ``(lower_bound, upper_bound)`` as non-negative floats.
    """
    mean1 = np.atleast_1d(np.asarray(mean1, dtype=float))
    mean2 = np.atleast_1d(np.asarray(mean2, dtype=float))
    cov1 = np.atleast_2d(np.asarray(cov1, dtype=float))
    cov2 = np.atleast_2d(np.asarray(cov2, dtype=float))

    # Put the higher-dimensional Gaussian first.
    if len(mean1) < len(mean2):
        mean1, mean2 = mean2, mean1
        cov1, cov2 = cov2, cov1

    # eigvalsh is the symmetric solver and always returns real eigenvalues.
    # Regularise out of place so the caller's covariance is left untouched.
    if not np.all(np.linalg.eigvalsh(cov1) > 1e-12):
        cov1 = cov1 + np.eye(len(mean1)) * 1e-6
    if not np.all(np.linalg.eigvalsh(cov2) > 1e-12):
        cov2 = cov2 + np.eye(len(mean2)) * 1e-6

    Lambda1_vals, Q1 = eigh(cov1)
    Lambda2_vals, Q2 = eigh(cov2)

    # Clamp so the square roots below stay real.
    Lambda1_vals = np.maximum(Lambda1_vals, 1e-12)
    Lambda2_vals = np.maximum(Lambda2_vals, 1e-12)

    # Means expressed in the eigenbases of their covariances.
    m1_tilde = Q1.T @ mean1
    m2_tilde = Q2.T @ mean2

    norm_m1_tilde_sq = np.linalg.norm(m1_tilde) ** 2
    norm_m2_tilde_sq = np.linalg.norm(m2_tilde) ** 2

    # Coupling-independent part of the squared distance.
    xi = (
        np.sum(Lambda1_vals**2)
        + np.sum(Lambda2_vals**2)
        + 2 * np.sum(Lambda1_vals * m1_tilde**2)
        + 2 * np.sum(Lambda2_vals * m2_tilde**2)
        + (norm_m1_tilde_sq - norm_m2_tilde_sq) ** 2
    )

    # Pair the largest eigenvalues with each other; surplus eigenvalues of
    # the larger spectrum have no partner and contribute nothing.
    Lambda1_sorted = np.sort(Lambda1_vals)[::-1]
    Lambda2_sorted = np.sort(Lambda2_vals)[::-1]
    paired = min(len(Lambda1_sorted), len(Lambda2_sorted))
    sum_eigenval_products = np.sum(Lambda1_sorted[:paired] * Lambda2_sorted[:paired])

    norm_product = np.linalg.norm(np.sqrt(Lambda1_vals) * m1_tilde) * np.linalg.norm(
        np.sqrt(Lambda2_vals) * m2_tilde
    )

    gamma_upper_bound = sum_eigenval_products + 2 * norm_product
    gamma_lower_bound = sum_eigenval_products - 2 * norm_product

    # A larger coupling term means a smaller distance, hence the swap.
    lower_bound_sq = xi - 2 * gamma_upper_bound
    upper_bound_sq = xi - 2 * gamma_lower_bound

    # Rounding can push the squared bounds slightly below zero.
    lower_bound = np.sqrt(max(0, lower_bound_sq))
    upper_bound = np.sqrt(max(0, upper_bound_sq))

    return lower_bound, upper_bound


def _select_texts(dataset, dataset_name):
    """Pick the text column of ``dataset``; return None if there is none."""
    features = dataset.features

    # "content" is checked explicitly so datasets that carry both a short
    # title and the body text (e.g. amazon_polarity) use the body, matching
    # the column choice made by the CKA analysis.
    for column in ("text", "content"):
        if column in features:
            return dataset[column]

    if "label" in features and len(features) == 2:
        text_column = [col for col in features if col != "label"][0]
        return dataset[text_column]

    text_columns = [
        col
        for col, feat in features.items()
        if getattr(feat, "dtype", None) == "string"
    ]
    if not text_columns:
        print(f"No text column found in {dataset_name}")
        return None
    return dataset[text_columns[0]]


def analyze_dataset(dataset_name, dataset_config):
    """Compare every model in ``BERT_MODELS`` to a reference on one dataset.

    Each model embeds the dataset's texts, a Gaussian is fitted to the
    embeddings (brought to the largest embedding width seen), and the IGW
    bounds and numerical estimate against the reference model are computed.
    The reference is ``"bert-12-768"`` when available, otherwise the model
    with the largest parameter count.

    Args:
        dataset_name: Short name used in messages and in the result.
        dataset_config: ``(dataset id, split expression)``.

    Returns:
        A dict with the per-model bounds/estimates and the reference
        statistics, or ``None`` when the dataset cannot be loaded, has no
        text column, or no model could be embedded.
    """
    try:
        dataset = load_dataset(dataset_config[0], split=dataset_config[1])
    except Exception as e:
        print(f"Error loading {dataset_name}: {e}")
        return None

    texts = _select_texts(dataset, dataset_name)
    if texts is None:
        return None

    if len(texts) > 2000:
        texts = texts[:2000]

    all_embeddings = {}
    max_dim = 0

    for model_name, (model_path, param_count) in tqdm(
        BERT_MODELS.items(), desc="Embedding"
    ):
        try:
            model, tokenizer = load_model_and_tokenizer(model_path)

            embeddings = get_embeddings(texts, model, tokenizer)
            all_embeddings[model_name] = (embeddings, param_count)
            max_dim = max(max_dim, embeddings.shape[1])

            # Free the model before the next one is loaded.
            del model, tokenizer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"Error loading {model_name}: {e}")
            continue

    if not all_embeddings:
        # Without a single embedded model there is no reference to compare to.
        print(f"No model could be embedded for {dataset_name}")
        return None

    gaussians = {}
    for model_name, (embeddings, param_count) in all_embeddings.items():
        mean, cov = fit_gaussian(embeddings, max_dim)
        gaussians[model_name] = (mean, cov, param_count)

    if "bert-12-768" in gaussians:
        ref_model = "bert-12-768"
    else:
        ref_model = max(gaussians, key=lambda name: gaussians[name][2])

    base_mean, base_cov, _ = gaussians[ref_model]
    print(f"Using {ref_model} as reference model")

    model_sizes = []
    lower_bounds = []
    upper_bounds = []
    estimates = []
    model_data = {}

    for model_name, (mean, cov, param_count) in gaussians.items():
        if model_name == ref_model:
            continue

        lower, upper = compute_igw_bounds(mean, cov, base_mean, base_cov)
        try:
            est = compute_igw_distance(mean, base_mean, cov, base_cov, verbose=False)
        except Exception as e:
            # Keep going with NaN, but make the failure visible.
            print(f"IGW estimate failed for {model_name}: {e}")
            est = np.nan

        model_sizes.append(param_count)
        lower_bounds.append(lower)
        upper_bounds.append(upper)
        estimates.append(est)

        model_data[model_name] = {
            "mean": mean,
            "covariance": cov,
            "param_count": param_count,
            "lower_bound": lower,
            "upper_bound": upper,
            "estimate": est,
        }

        print(
            f"{model_name}: IGW bounds [{lower:.4f}, {upper:.4f}], Corrected IGW: {est:.4f}"
        )

    return {
        "dataset_name": dataset_name,
        "model_sizes": model_sizes,
        "lower_bounds": lower_bounds,
        "upper_bounds": upper_bounds,
        "estimates": estimates,
        "model_data": model_data,
        "reference_model": ref_model,
        "reference_mean": base_mean,
        "reference_covariance": base_cov,
    }


def save_results(results, filename="cache/bert_igw_results.pkl"):
    """Pickle the per-dataset analysis results to ``filename``.

    The file holds one entry per dataset, the ``BERT_MODELS`` registry and a
    metadata block. The metadata's reference model and projection dimension
    are taken from the first result. The target directory is created when it
    does not exist yet.

    Args:
        results: Iterable of dicts as returned by ``analyze_dataset``.
        filename: Destination path of the pickle file.
    """
    import os

    save_data = {
        "datasets": {},
        "bert_models": BERT_MODELS,
        "analysis_metadata": {
            "max_texts_per_dataset": 2000,
            "reference_model": None,
            "projection_dimension": None,
            "formula_version": "theorem_formula_with_transformed_means",
        },
    }
    metadata = save_data["analysis_metadata"]

    for result in results:
        save_data["datasets"][result["dataset_name"]] = {
            "model_data": result["model_data"],
            "reference_model": result["reference_model"],
            "reference_mean": result["reference_mean"],
            "reference_covariance": result["reference_covariance"],
            "summary_stats": {
                "model_sizes": result["model_sizes"],
                "lower_bounds": result["lower_bounds"],
                "upper_bounds": result["upper_bounds"],
                "estimates": result["estimates"],
            },
        }

        # Only the first result fills in the file-level metadata.
        if metadata["reference_model"] is None:
            metadata["reference_model"] = result["reference_model"]
            metadata["projection_dimension"] = result["reference_mean"].shape[0]

    # A fresh checkout has no cache directory; open() would fail without it.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(filename, "wb") as f:
        pickle.dump(save_data, f)

def main():
    """Analyse every dataset in ``DATASETS`` and save the collected results.

    Returns:
        The list of per-dataset result dicts; datasets that could not be
        analysed are left out.
    """
    results = []

    for dataset_name, dataset_config in DATASETS.items():
        result = analyze_dataset(dataset_name, dataset_config)
        if result is not None:
            results.append(result)

    if results:
        save_results(results)

    return results


if __name__ == "__main__":
    results = main()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
import warnings
from tqdm import tqdm
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import os

# Silence every library warning so the notebook output stays readable.
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


class MultiDatasetCKAAnalyzer:
    """Compute linear CKA between a reference model and the other models.

    The analyzer reads the pickle written by the IGW analysis, re-embeds a
    sample of each dataset with every model, and pairs each model's CKA
    score against the reference with its stored IGW distance.
    """

    def __init__(self, results_file: str = "cache/bert_igw_results.pkl"):
        """Load the stored IGW results and the model registry they contain."""
        self.results_file = results_file
        self.load_existing_results()
        self.bert_models = self.existing_results["bert_models"]

    def load_existing_results(self):
        """Read ``self.results_file`` into ``self.existing_results``."""
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point this at result files produced locally.
        with open(self.results_file, "rb") as f:
            self.existing_results = pickle.load(f)

    def prepare_sample_dataset(self, dataset_name: str, max_samples: int = 1000):
        """Load a dataset by short name and return its text column.

        Args:
            dataset_name: One of the keys of the table below.
            max_samples: If the loaded split is larger, it is shuffled with a
                fixed seed and cut down to this many rows.

        Returns:
            The selected text column.

        Raises:
            ValueError: For an unknown dataset or when no text column exists.
        """
        dataset_configs = {
            "amazon_polarity": ("amazon_polarity", "test[:1000]"),
            "yelp_review": ("yelp_review_full", "test[:1000]"),
            "imdb": ("imdb", "test[:1000]"),
            "ag_news": ("ag_news", "test[:1000]"),
        }

        if dataset_name not in dataset_configs:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        dataset_path, split = dataset_configs[dataset_name]
        dataset = load_dataset(dataset_path, split=split)

        if len(dataset) > max_samples:
            dataset = dataset.shuffle(seed=42).select(range(max_samples))

        features = dataset.features
        if "text" in features:
            texts = dataset["text"]
        elif "content" in features:
            texts = dataset["content"]
        elif "label" in features and len(features) == 2:
            text_column = [col for col in features if col != "label"][0]
            texts = dataset[text_column]
        else:
            text_columns = [
                col for col, feat in features.items() if feat.dtype == "string"
            ]
            if not text_columns:
                raise ValueError(f"No text column found in {dataset_name}")
            texts = dataset[text_columns[0]]

        print(f"Sample size: {len(texts)} texts")
        return texts

    def center_gram_matrix(self, K):
        """Return ``H K H`` with ``H = I - (1/n) 1 1^T`` for an n x n ``K``."""
        n = K.shape[0]

        ones = torch.ones(n, 1, device=K.device)
        H = torch.eye(n, device=K.device) - (1 / n) * torch.mm(ones, ones.t())

        K_centered = torch.mm(torch.mm(H, K), H)
        return K_centered

    def compute_cka(self, X, Y):
        """Compute linear CKA between two sets of representations.

        Args:
            X: Tensor with one row per sample.
            Y: Tensor with the same number of rows as ``X``.

        Returns:
            The CKA score as a Python float, or ``0.0`` when the
            normalisation term is smaller than ``1e-12``.
        """
        X = X.float()
        Y = Y.float()

        # Linear-kernel Gram matrices, centred in feature space.
        K_X_centered = self.center_gram_matrix(torch.mm(X, X.t()))
        K_Y_centered = self.center_gram_matrix(torch.mm(Y, Y.t()))

        numerator = torch.trace(torch.mm(K_X_centered, K_Y_centered))

        norm_X = torch.trace(torch.mm(K_X_centered, K_X_centered))
        norm_Y = torch.trace(torch.mm(K_Y_centered, K_Y_centered))
        denominator = torch.sqrt(norm_X * norm_Y)

        if denominator < 1e-12:
            return 0.0

        cka = numerator / denominator
        return cka.item()

    def extract_representations(
        self, model, tokenizer, texts, layer_idx=-2, max_length=512
    ):
        """Return the position-0 hidden state of one layer for every text.

        Args:
            model: Encoder that supports ``output_hidden_states=True``.
            tokenizer: Tokenizer matching ``model``.
            texts: Iterable of input strings, processed one at a time.
            layer_idx: Index into the tuple of hidden states.
            max_length: Every input is padded/truncated to this length.

        Returns:
            A CPU tensor stacking one vector per text.
        """
        model.eval()
        representations = []

        with torch.no_grad():
            for text in tqdm(texts, desc="Extracting representations", leave=False):
                # NOTE(review): each text is padded to max_length even though
                # it is encoded alone; dropping the padding would be cheaper
                # but may change results at floating-point level.
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    padding="max_length",
                    max_length=max_length,
                ).to(model.device)

                outputs = model(**inputs, output_hidden_states=True)
                layer_output = outputs.hidden_states[layer_idx]
                cls_representation = layer_output[0, 0, :]

                representations.append(cls_representation.cpu())

        return torch.stack(representations)

    def load_model_safely(self, model_name, model_path):
        """Load a model and tokenizer, returning ``(None, None)`` on failure.

        The model is moved to the module-level ``device``. When the tokenizer
        has no padding token, its end-of-sequence token is used instead.
        """
        try:
            print(f"Loading {model_name}...")
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModel.from_pretrained(model_path)

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model.to(device)
            return model, tokenizer

        except Exception as e:
            print(f"Error loading {model_name}: {e}")
            return None, None

    @staticmethod
    def _append_result(
        results,
        dataset_name,
        model_name,
        param_count,
        igw_distance,
        cka_score,
        load_success,
    ):
        """Append one model's row to the column-wise ``results`` dict."""
        results["model_name"].append(model_name)
        results["param_count"].append(param_count)
        results["igw_distance"].append(igw_distance)
        results["cka_score"].append(cka_score)
        results["load_success"].append(load_success)
        results["dataset"].append(dataset_name)

    def compute_all_cka_scores(
        self,
        dataset_name: str,
        max_models: int = None,
        reference_model: str = "bert-12-768",
        layer_idx: int = -2,
    ):
        """Score every stored model of a dataset against the reference.

        Args:
            dataset_name: Dataset key present in the stored IGW results.
            max_models: If given, only this many of the smallest models are
                processed.
            reference_model: Registry key of the model to compare against.
            layer_idx: Hidden-state index used for the representations.

        Returns:
            A DataFrame with one row per model (name, parameter count, IGW
            distance, CKA score, success flag, dataset), or ``None`` when
            the reference model cannot be loaded. Models that fail to load
            or to score get a CKA of ``0.0`` and ``load_success=False``.
        """
        texts = self.prepare_sample_dataset(dataset_name)

        model_data = self.existing_results["datasets"][dataset_name]["model_data"]

        # Smallest models first; the reference itself is not compared.
        model_items = sorted(
            ((k, v) for k, v in model_data.items() if k != reference_model),
            key=lambda item: item[1]["param_count"],
        )

        if max_models is not None:
            model_items = model_items[:max_models]
            print(f"Only analyzing {max_models} smallest models")

        reference_path = self.bert_models[reference_model][0]
        ref_model, ref_tokenizer = self.load_model_safely(
            reference_model, reference_path
        )

        if ref_model is None:
            print(f"Failed to load reference model {reference_model}")
            return None

        ref_representations = self.extract_representations(
            ref_model, ref_tokenizer, texts, layer_idx
        )

        # Free the reference model before the comparison models are loaded.
        del ref_model, ref_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        results = {
            "model_name": [],
            "param_count": [],
            "igw_distance": [],
            "cka_score": [],
            "load_success": [],
            "dataset": [],
        }

        for i, (model_name, data) in enumerate(model_items):
            print(f"\n--- Processing model {i+1}/{len(model_items)}: {model_name} ---")

            param_count = data["param_count"]

            # Fall back to the analytic upper bound when the estimate is NaN.
            igw_distance = (
                data["estimate"]
                if not np.isnan(data["estimate"])
                else data["upper_bound"]
            )
            model_path = self.bert_models[model_name][0]

            print(f"Parameters: {param_count:.1f}M")
            print(f"IGW Distance: {igw_distance:.4f}")

            model, tokenizer = self.load_model_safely(model_name, model_path)

            if model is None:
                self._append_result(
                    results, dataset_name, model_name, param_count,
                    igw_distance, 0.0, False,
                )
                continue

            try:
                model_representations = self.extract_representations(
                    model, tokenizer, texts, layer_idx
                )
                cka_score = self.compute_cka(ref_representations, model_representations)
                print(f"CKA Score: {cka_score:.4f}")
                succeeded = True

            except Exception as e:
                print(f"Error computing CKA for {model_name}: {e}")
                cka_score = 0.0
                succeeded = False

            finally:
                del model, tokenizer
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            self._append_result(
                results, dataset_name, model_name, param_count,
                igw_distance, cka_score, succeeded,
            )

        return pd.DataFrame(results)

    def save_cka_results(self, results_df: pd.DataFrame, dataset_name: str):
        """Write the results of one dataset as CSV and as pickle.

        Both files go to the ``cache`` directory, which is created when it
        does not exist yet.
        """
        csv_filename = f"cache/cka_igw_results_{dataset_name}.csv"
        pickle_filename = f"cache/cka_igw_results_{dataset_name}.pkl"

        # The results file may live elsewhere, so cache/ is not guaranteed.
        os.makedirs(os.path.dirname(csv_filename), exist_ok=True)

        results_df.to_csv(csv_filename, index=False)

        with open(pickle_filename, "wb") as f:
            pickle.dump(
                {
                    "results_dataframe": results_df,
                    "dataset_name": dataset_name,
                    "analysis_type": "CKA_vs_IGW",
                    "bert_models": self.bert_models,
                },
                f,
            )


# Run the CKA analysis for the selected datasets and store every result.
analyzer = MultiDatasetCKAAnalyzer()

max_models = None  # None = process every stored model
layer_idx = -1  # hidden-state index used for the representations
datasets_to_analyze = ["amazon_polarity", "ag_news"]

all_results = {}

for i, dataset_name in enumerate(datasets_to_analyze):
    print(f"\n{'='*60}")
    print(
        f"Processing dataset {i+1}/{len(datasets_to_analyze)}: {dataset_name.upper()}"
    )
    print(f"{'='*60}")

    if dataset_name not in analyzer.existing_results["datasets"]:
        print(f"Warning: {dataset_name} not found in IGW results")
        continue

    results_df = analyzer.compute_all_cka_scores(
        dataset_name=dataset_name, max_models=max_models, layer_idx=layer_idx
    )

    if results_df is None:
        print(f"Failed to process dataset: {dataset_name}")
        continue

    analyzer.save_cka_results(results_df, dataset_name)
    all_results[dataset_name] = results_df

    successful_data = results_df[results_df["load_success"]]
    successful = len(successful_data)
    failed = len(results_df) - successful
    print(f"Models processed: {successful} succeeded, {failed} failed")

    if successful > 0:
        best_cka = successful_data.loc[successful_data["cka_score"].idxmax()]
        worst_cka = successful_data.loc[successful_data["cka_score"].idxmin()]
        print(f"Highest CKA: {best_cka['model_name']} ({best_cka['cka_score']:.4f})")
        print(f"Lowest CKA: {worst_cka['model_name']} ({worst_cka['cka_score']:.4f})")

# Plotting

In [None]:
import matplotlib.pyplot as plt
import matplotlib_inline

# Render inline figures as SVG.
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
# Apply the style sheet "math.mplstyle" (resolved by matplotlib as a style
# name or file path) to all following figures.
plt.style.use("math.mplstyle")

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot_single_dataset_igw(result, figsize=(8, 6)):
    """Plot IGW bounds and estimates against model size for one dataset.

    The figure is saved to ``images/igw_distance_<dataset>.pdf`` and shown.

    Args:
        result: Dict with ``"lower_bounds"``, ``"upper_bounds"``,
            ``"estimates"`` and ``"dataset_name"``. An optional
            ``"model_sizes"`` entry (one size per model, same order as the
            bounds) supplies the x positions; without it a built-in table in
            registry order is used, which is only correct when no model was
            skipped during the analysis.
        figsize: Accepted for compatibility; the figure is created with the
            active style's default size and this value is not applied.
    """
    lower = np.array(result["lower_bounds"])
    upper = np.array(result["upper_bounds"])
    corrected = np.array(result["estimates"])
    dataset_name = result["dataset_name"]

    n_models = len(lower)

    provided_sizes = result["model_sizes"] if "model_sizes" in result else None
    if provided_sizes is not None and len(provided_sizes) == n_models:
        # Sizes recorded with the bounds stay aligned even if models failed.
        sizes = np.asarray(provided_sizes, dtype=float)
    else:
        # Fallback: rounded sizes in registry order.
        sizes = np.array(
            [
                4.4, 9.7, 22.8, 39.2,
                4.8, 11.3, 29.1, 53.4,
                5.2, 12.8, 35.4, 67.5,
                5.6, 14.4, 41.7, 81.7,
                6.0, 16.0, 48.0, 95.9,
                6.4, 17.6, 54.3, 110.1,
            ]
        )[:n_models]

    # Draw the curves from the smallest to the largest model.
    sorted_indices = np.argsort(sizes)
    sizes = sizes[sorted_indices]
    lower = lower[sorted_indices]
    upper = upper[sorted_indices]
    corrected = corrected[sorted_indices]

    plt.figure()

    plt.plot(sizes, upper, "o--", label="Analytic upper bound")

    # Models whose numerical estimate failed are stored as NaN.
    valid_mask = ~np.isnan(corrected)
    if np.any(valid_mask):
        plt.plot(
            sizes[valid_mask],
            corrected[valid_mask],
            "o-",
            label="RGD upper bound",
            markersize=5,
        )

        plt.plot(sizes, lower, "o-", label="Analytic lower bound")

        plt.fill_between(
            sizes[valid_mask],
            lower[valid_mask],
            corrected[valid_mask],
            alpha=0.3,
            label="IGW distance range",
        )

    plt.xlabel("Distillation size (millions of parameters)", fontsize=12)
    plt.ylabel("IGW distance from bert-base", fontsize=12)
    plt.title(f"IGW distance between distillations and bert-base ({dataset_name})", fontsize=14)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xscale("log")

    plt.tight_layout()
    plt.savefig(f"images/igw_distance_{dataset_name}.pdf")
    plt.show()


import pickle
with open('cache/bert_igw_results.pkl', 'rb') as f:
    saved_data = pickle.load(f)

dataset_name = 'ag_news'
if dataset_name in saved_data['datasets']:
    summary_stats = saved_data['datasets'][dataset_name]['summary_stats']
    result = {
        'lower_bounds': summary_stats['lower_bounds'],
        'upper_bounds': summary_stats['upper_bounds'],
        'estimates': summary_stats['estimates'],
        # Sizes stored alongside the bounds; None falls back to the table.
        'model_sizes': summary_stats.get('model_sizes'),
        'dataset_name': dataset_name
    }
    plot_single_dataset_igw(result)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
from scipy.stats import pearsonr, spearmanr


def plot_single_dataset_cka(result, figsize=(8, 6)):
    """Scatter CKA similarity against IGW distance, coloured by model size.

    With more than two points a linear trendline and the Pearson/Spearman
    correlations are added. The figure is saved to
    ``images/cka_vs_igw_<dataset>.pdf`` and shown.

    Each point is coloured by the ``param_counts`` entry of its own row, so
    the colour always matches the model regardless of the row order.

    Args:
        result: Dict with ``"param_counts"``, ``"igw_distances"``,
            ``"cka_scores"`` (aligned, one entry per model) and
            ``"dataset_name"``.
        figsize: Accepted for compatibility; the figure is created with the
            active style's default size and this value is not applied.
    """
    param_counts = np.array(result["param_counts"])
    igw_distances = np.array(result["igw_distances"])
    cka_scores = np.array(result["cka_scores"])
    dataset_name = result["dataset_name"]

    # Order all three arrays by the models' own sizes so they stay aligned.
    sorted_indices = np.argsort(param_counts)
    sizes = param_counts[sorted_indices]
    igw_distances = igw_distances[sorted_indices]
    cka_scores = cka_scores[sorted_indices]

    plt.figure()

    scatter = plt.scatter(
        igw_distances,
        cka_scores,
        c=sizes,
        s=60,
        alpha=0.7,
        cmap="viridis",
        edgecolors="black",
        linewidth=0.5,
    )

    if len(igw_distances) > 2:
        # Least-squares line through the points.
        z = np.polyfit(igw_distances, cka_scores, 1)
        p = np.poly1d(z)
        x_trend = np.linspace(igw_distances.min(), igw_distances.max(), 100)
        plt.plot(
            x_trend,
            p(x_trend),
            "r--",
            alpha=0.8,
            linewidth=2,
            label="Trendline",
        )

        corr_pearson, p_pearson = pearsonr(igw_distances, cka_scores)
        corr_spearman, p_spearman = spearmanr(igw_distances, cka_scores)

        corr_text = f"Pearson $r$: {corr_pearson:.3f} ($p={p_pearson:.3f}$)\nSpearman $\\rho$: {corr_spearman:.3f} ($p={p_spearman:.3f}$)"
        plt.text(
            0.05,
            0.95,
            corr_text,
            transform=plt.gca().transAxes,
            verticalalignment="top",
            fontsize=12,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
        )

    cbar = plt.colorbar(scatter)
    cbar.set_label("Model size (millions of parameters)", fontsize=12)

    plt.xlabel("IGW distance from bert-base (RGD)", fontsize=12)
    plt.ylabel("CKA similarity with bert-base", fontsize=12)
    plt.title(f"IGW distance vs. CKA similarity ({dataset_name})", fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)

    # The y axis is capped at 0.7; points above that are not visible.
    plt.ylim(top=0.7)

    plt.tight_layout()
    plt.savefig(f"images/cka_vs_igw_{dataset_name}.pdf")
    plt.show()


import pickle

# Load the stored CKA results of one dataset and plot the models that were
# scored successfully.
dataset_name = "amazon_polarity"
pickle_filename = f"cache/cka_igw_results_{dataset_name}.pkl"

with open(pickle_filename, "rb") as f:
    cka_data = pickle.load(f)
    results_df = cka_data["results_dataframe"]

# Rows with load_success == False carry a placeholder score; leave them out.
successful_results = results_df[results_df["load_success"]].copy()

if len(successful_results) > 0:
    result = dict(
        param_counts=successful_results["param_count"].values,
        igw_distances=successful_results["igw_distance"].values,
        cka_scores=successful_results["cka_score"].values,
        dataset_name=dataset_name,
    )
    plot_single_dataset_cka(result)