**bold text**

# setup

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
%%capture
!pip install git+https://github.com/EleutherAI/sae.git

In [None]:
# you should load this before cloning repo files
# from .config import SaeConfig
# from .utils import decoder_impl

from sae.config import SaeConfig
from sae.utils import decoder_impl
from sae import Sae

Triton not installed, using eager implementation of SAE decoder.


In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import json
from fnmatch import fnmatch
from pathlib import Path
from typing import NamedTuple, Optional, Callable, Union, List, Tuple
# from jaxtyping import Float, Int

import einops
import torch
from torch import Tensor, nn
from huggingface_hub import snapshot_download
from natsort import natsorted
from safetensors.torch import load_model, save_model

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from collections import Counter
import pandas as pd
from IPython.display import display

## corr fns

In [None]:
def batched_correlation(reshaped_activations_A, reshaped_activations_B, batch_size=100):
    # Ensure tensors are on GPU
    if torch.cuda.is_available():
        reshaped_activations_A = reshaped_activations_A.to('cuda')
        reshaped_activations_B = reshaped_activations_B.to('cuda')

    # Normalize columns of A
    mean_A = reshaped_activations_A.mean(dim=0, keepdim=True)
    std_A = reshaped_activations_A.std(dim=0, keepdim=True)
    normalized_A = (reshaped_activations_A - mean_A) / (std_A + 1e-8)  # Avoid division by zero

    # Normalize columns of B
    mean_B = reshaped_activations_B.mean(dim=0, keepdim=True)
    std_B = reshaped_activations_B.std(dim=0, keepdim=True)
    normalized_B = (reshaped_activations_B - mean_B) / (std_B + 1e-8)  # Avoid division by zero

    num_batches = (normalized_B.shape[1] + batch_size - 1) // batch_size
    max_values = []
    max_indices = []

    for batch in range(num_batches):
        start = batch * batch_size
        end = min(start + batch_size, normalized_B.shape[1])
        batch_corr_matrix = torch.matmul(normalized_A.t(), normalized_B[:, start:end]) / normalized_A.shape[0]
        max_val, max_idx = batch_corr_matrix.max(dim=0)
        max_values.append(max_val)
        max_indices.append(max_idx)

        del batch_corr_matrix
        torch.cuda.empty_cache()

    corr_inds = torch.cat(max_indices).detach().cpu().numpy()
    corr_vals = torch.cat(max_values).detach().cpu().numpy()
    return corr_inds, corr_vals

In [None]:
def filter_corr_pairs(mixed_modA_feats, mixed_modB_feats, kept_modA_feats):
    filt_corr_ind_A = []
    filt_corr_ind_B = []
    seen = set()
    for ind_A, ind_B in zip(mixed_modA_feats, mixed_modB_feats):
        if ind_A in kept_modA_feats:
            filt_corr_ind_A.append(ind_A)
            filt_corr_ind_B.append(ind_B)
        elif ind_A not in seen:  # only keep one if it's over count X
            seen.add(ind_A)
            filt_corr_ind_A.append(ind_A)
            filt_corr_ind_B.append(ind_B)
    num_unq_pairs = len(list(set(filt_corr_ind_A)))
    print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
    print("num 1-1 feats after filt: ", num_unq_pairs )
    return filt_corr_ind_A, filt_corr_ind_B, num_unq_pairs

In [None]:
def get_new_mean_corr(modA_feats, modB_feats, corr_vals):
    new_vals = []
    seen = set()
    for ind_A, ind_B in zip(modA_feats, modB_feats):
        if ind_A not in seen:
            seen.add(ind_A)
            val = corr_vals[ind_B]
            new_vals.append(val)
    new_mean_corr = sum(new_vals) / len(new_vals)
    # print(new_mean_corr)
    return new_mean_corr

## sim fns

In [None]:
import functools
from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
import numpy.typing as npt
import torch


def to_numpy_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[npt.NDArray]:
    def convert(x: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:
        return x if isinstance(x, np.ndarray) else x.numpy()

    return list(map(convert, args))


def to_torch_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[torch.Tensor]:
    def convert(x: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
        return x if isinstance(x, torch.Tensor) else torch.from_numpy(x)

    return list(map(convert, args))


def adjust_dimensionality(
    R: npt.NDArray, Rp: npt.NDArray, strategy="zero_pad"
) -> Tuple[npt.NDArray, npt.NDArray]:
    D = R.shape[1]
    Dp = Rp.shape[1]
    if strategy == "zero_pad":
        if D - Dp == 0:
            return R, Rp
        elif D - Dp > 0:
            return R, np.concatenate((Rp, np.zeros((Rp.shape[0], D - Dp))), axis=1)
        else:
            return np.concatenate((R, np.zeros((R.shape[0], Dp - D))), axis=1), Rp
    else:
        raise NotImplementedError()


def center_columns(R: npt.NDArray) -> npt.NDArray:
    return R - R.mean(axis=0)[None, :]


def normalize_matrix_norm(R: npt.NDArray) -> npt.NDArray:
    return R / np.linalg.norm(R, ord="fro")


def sim_random_baseline(
    rep1: torch.Tensor, rep2: torch.Tensor, sim_func: Callable, n_permutations: int = 10
) -> Dict[str, Any]:
    torch.manual_seed(1234)
    scores = []
    for _ in range(n_permutations):
        perm = torch.randperm(rep1.size(0))

        score = sim_func(rep1[perm, :], rep2)
        score = score if isinstance(score, float) else score["score"]

        scores.append(score)

    return {"baseline_scores": np.array(scores)}


class Pipeline:
    def __init__(
        self,
        preprocess_funcs: List[Callable[[npt.NDArray], npt.NDArray]],
        similarity_func: Callable[[npt.NDArray, npt.NDArray], Dict[str, Any]],
    ) -> None:
        self.preprocess_funcs = preprocess_funcs
        self.similarity_func = similarity_func

    def __call__(self, R: npt.NDArray, Rp: npt.NDArray) -> Dict[str, Any]:
        for preprocess_func in self.preprocess_funcs:
            R = preprocess_func(R)
            Rp = preprocess_func(Rp)
        return self.similarity_func(R, Rp)

    def __str__(self) -> str:
        def func_name(func: Callable) -> str:
            return (
                func.__name__
                if not isinstance(func, functools.partial)
                else func.func.__name__
            )

        def partial_keywords(func: Callable) -> str:
            if not isinstance(func, functools.partial):
                return ""
            else:
                return str(func.keywords)

        return (
            "Pipeline("
            + (
                "+".join(map(func_name, self.preprocess_funcs))
                + "+"
                + func_name(self.similarity_func)
                + partial_keywords(self.similarity_func)
            )
            + ")"
        )

In [None]:
from typing import List, Set, Union

import numpy as np
import numpy.typing as npt
import sklearn.neighbors
import torch

# from llmcomp.measures.utils import to_numpy_if_needed


def _jac_sim_i(idx_R: Set[int], idx_Rp: Set[int]) -> float:
    return len(idx_R.intersection(idx_Rp)) / len(idx_R.union(idx_Rp))


def jaccard_similarity(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    k: int = 10,
    inner: str = "cosine",
    n_jobs: int = 8,
) -> float:
    R, Rp = to_numpy_if_needed(R, Rp)

    indices_R = nn_array_to_setlist(top_k_neighbors(R, k, inner, n_jobs))
    indices_Rp = nn_array_to_setlist(top_k_neighbors(Rp, k, inner, n_jobs))

    return float(
        np.mean(
            [_jac_sim_i(idx_R, idx_Rp) for idx_R, idx_Rp in zip(indices_R, indices_Rp)]
        )
    )


def top_k_neighbors(
    R: npt.NDArray,
    k: int,
    inner: str,
    n_jobs: int,
) -> npt.NDArray:
    # k+1 nearest neighbors, because we pass in all the data, which means that a point
    # will be the nearest neighbor to itself. We remove this point from the results and
    # report only the k nearest neighbors distinct from the point itself.
    nns = sklearn.neighbors.NearestNeighbors(
        n_neighbors=k + 1, metric=inner, n_jobs=n_jobs
    )
    nns.fit(R)
    _, nns = nns.kneighbors(R)
    return nns[:, 1:]


def nn_array_to_setlist(nn: npt.NDArray) -> List[Set[int]]:
    return [set(idx) for idx in nn]

In [None]:
import functools
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import get_args
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import Union

import numpy as np
import numpy.typing as npt
import torch
from einops import rearrange
# from loguru import logger

log = logging.getLogger(__name__)


SHAPE_TYPE = Literal["nd", "ntd", "nchw"]

ND_SHAPE, NTD_SHAPE, NCHW_SHAPE = get_args(SHAPE_TYPE)[0], get_args(SHAPE_TYPE)[1], get_args(SHAPE_TYPE)[2]


class SimilarityFunction(Protocol):
    def __call__(  # noqa: E704
        self,
        R: torch.Tensor | npt.NDArray,
        Rp: torch.Tensor | npt.NDArray,
        shape: SHAPE_TYPE,
    ) -> float: ...


class RSMSimilarityFunction(Protocol):
    def __call__(  # noqa: E704
        self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE, n_jobs: int
    ) -> float: ...


@dataclass
class BaseSimilarityMeasure(ABC):
    larger_is_more_similar: bool
    is_symmetric: bool

    is_metric: bool | None = None
    invariant_to_affine: bool | None = None
    invariant_to_invertible_linear: bool | None = None
    invariant_to_ortho: bool | None = None
    invariant_to_permutation: bool | None = None
    invariant_to_isotropic_scaling: bool | None = None
    invariant_to_translation: bool | None = None
    name: str = field(init=False)

    def __post_init__(self):
        self.name = self.__class__.__name__

    @abstractmethod
    def __call__(self, *args: Any, **kwds: Any) -> Any:
        raise NotImplementedError


class FunctionalSimilarityMeasure(BaseSimilarityMeasure):
    @abstractmethod
    def __call__(self, output_a: torch.Tensor | npt.NDArray, output_b: torch.Tensor | npt.NDArray) -> float:
        raise NotImplementedError


@dataclass(kw_only=True)
class RepresentationalSimilarityMeasure(BaseSimilarityMeasure):
    sim_func: SimilarityFunction

    def __call__(
        self,
        R: torch.Tensor | npt.NDArray,
        Rp: torch.Tensor | npt.NDArray,
        shape: SHAPE_TYPE,
    ) -> float:
        return self.sim_func(R, Rp, shape)


class RSMSimilarityMeasure(RepresentationalSimilarityMeasure):
    sim_func: RSMSimilarityFunction

    @staticmethod
    def estimate_good_number_of_jobs(R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray) -> int:
        # RSMs in are NxN (or DxD) so the number of jobs should roughly scale quadratically with increase in N (or D).
        # False! As long as sklearn-native metrics are used, they will use parallel implementations regardless of job
        # count. Each job would spawn their own threads, which leads to oversubscription of cores and thus slowdown.
        # This seems to be not fully correct (n_jobs=2 seems to actually use two cores), but using n_jobs=1 seems the
        # fastest.
        return 1

    def __call__(
        self,
        R: torch.Tensor | npt.NDArray,
        Rp: torch.Tensor | npt.NDArray,
        shape: SHAPE_TYPE,
        n_jobs: Optional[int] = None,
    ) -> float:
        if n_jobs is None:
            n_jobs = self.estimate_good_number_of_jobs(R, Rp)
        return self.sim_func(R, Rp, shape, n_jobs=n_jobs)


def to_numpy_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[npt.NDArray]:
    def convert(x: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:
        return x if isinstance(x, np.ndarray) else x.numpy()

    return list(map(convert, args))


def to_torch_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[torch.Tensor]:
    def convert(x: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
        return x if isinstance(x, torch.Tensor) else torch.from_numpy(x)

    return list(map(convert, args))


def adjust_dimensionality(R: npt.NDArray, Rp: npt.NDArray, strategy="zero_pad") -> Tuple[npt.NDArray, npt.NDArray]:
    D = R.shape[1]
    Dp = Rp.shape[1]
    if strategy == "zero_pad":
        if D - Dp == 0:
            return R, Rp
        elif D - Dp > 0:
            return R, np.concatenate((Rp, np.zeros((Rp.shape[0], D - Dp))), axis=1)
        else:
            return np.concatenate((R, np.zeros((R.shape[0], Dp - D))), axis=1), Rp
    else:
        raise NotImplementedError()


def center_columns(R: npt.NDArray) -> npt.NDArray:
    return R - R.mean(axis=0)[None, :]


def normalize_matrix_norm(R: npt.NDArray) -> npt.NDArray:
    return R / np.linalg.norm(R, ord="fro")


def normalize_row_norm(R: npt.NDArray) -> npt.NDArray:
    return R / np.linalg.norm(R, ord=2, axis=1, keepdims=True)


def standardize(R: npt.NDArray) -> npt.NDArray:
    return (R - R.mean(axis=0, keepdims=True)) / R.std(axis=0)


def double_center(x: npt.NDArray) -> npt.NDArray:
    return x - x.mean(axis=0, keepdims=True) - x.mean(axis=1, keepdims=True) + x.mean()


def align_spatial_dimensions(R: npt.NDArray, Rp: npt.NDArray) -> Tuple[npt.NDArray, npt.NDArray]:
    """
    Aligns spatial representations by resizing them to the smallest spatial dimension.
    Subsequent aligned spatial representations are flattened, with the spatial aligned representations
    moving into the *sample* dimension.
    """
    R_re, Rp_re = resize_wh_reps(R, Rp)
    R_re = rearrange(R_re, "n c h w -> (n h w) c")
    Rp_re = rearrange(Rp_re, "n c h w -> (n h w) c")
    if R_re.shape[0] > 5000:
        logger.info(f"Got {R_re.shape[0]} samples in N after flattening. Subsampling to reduce compute.")
        subsample = R_re.shape[0] // 5000
        R_re = R_re[::subsample]
        Rp_re = Rp_re[::subsample]

    return R_re, Rp_re


def average_pool_downsample(R, resize: bool, new_size: tuple[int, int]):
    if not resize:
        return R  # do nothing
    else:
        is_numpy = isinstance(R, np.ndarray)
        R_torch = torch.from_numpy(R) if is_numpy else R
        R_torch = torch.nn.functional.adaptive_avg_pool2d(R_torch, new_size)
        return R_torch.numpy() if is_numpy else R_torch


def resize_wh_reps(R: npt.NDArray, Rp: npt.NDArray) -> Tuple[npt.NDArray, npt.NDArray]:
    """
    Function for resizing spatial representations that are not the same size.
    Does through fourier transform and resizing.

    Args:
        R: numpy array of shape  [batch_size, height, width, num_channels]
        RP: numpy array of shape [batch_size, height, width, num_channels]

    Returns:
        fft_acts1: numpy array of shape [batch_size, (new) height, (new) width, num_channels]
        fft_acts2: numpy array of shape [batch_size, (new) height, (new) width, num_channels]

    """
    height1, width1 = R.shape[2], R.shape[3]
    height2, width2 = Rp.shape[2], Rp.shape[3]
    if height1 != height2 or width1 != width2:
        height = min(height1, height2)
        width = min(width1, width2)
        new_size = [height, width]
        resize = True
    else:
        height = height1
        width = width1
        new_size = None
        resize = False

    # resize and preprocess with fft
    avg_ds1 = average_pool_downsample(R, resize=resize, new_size=new_size)
    avg_ds2 = average_pool_downsample(Rp, resize=resize, new_size=new_size)
    return avg_ds1, avg_ds2


def fft_resize(images, resize=False, new_size=None):
    """Function for applying DFT and resizing.

    This function takes in an array of images, applies the 2-d fourier transform
    and resizes them according to new_size, keeping the frequencies that overlap
    between the two sizes.

    Args:
              images: a numpy array with shape
                      [batch_size, height, width, num_channels]
              resize: boolean, whether or not to resize
              new_size: a tuple (size, size), with height and width the same

    Returns:
              im_fft_downsampled: a numpy array with shape
                           [batch_size, (new) height, (new) width, num_channels]
    """
    assert len(images.shape) == 4, "expecting images to be" "[batch_size, height, width, num_channels]"
    if resize:
        # FFT --> remove high frequencies --> inverse FFT
        im_complex = images.astype("complex64")
        im_fft = np.fft.fft2(im_complex, axes=(1, 2))
        im_shifted = np.fft.fftshift(im_fft, axes=(1, 2))

        center_width = im_shifted.shape[2] // 2
        center_height = im_shifted.shape[1] // 2
        half_w = new_size[0] // 2
        half_h = new_size[1] // 2
        cropped_fft = im_shifted[
            :, center_height - half_h : center_height + half_h, center_width - half_w : center_width + half_w, :
        ]
        cropped_fft_shifted_back = np.fft.ifft2(cropped_fft, axes=(1, 2))
        return cropped_fft_shifted_back.real
    else:
        return images


class Pipeline:
    def __init__(
        self,
        preprocess_funcs: List[Callable[[npt.NDArray], npt.NDArray]],
        similarity_func: Callable[[npt.NDArray, npt.NDArray, SHAPE_TYPE], float],
    ) -> None:
        self.preprocess_funcs = preprocess_funcs
        self.similarity_func = similarity_func

    def __call__(self, R: npt.NDArray, Rp: npt.NDArray, shape: SHAPE_TYPE) -> float:
        try:
            for preprocess_func in self.preprocess_funcs:
                R = preprocess_func(R)
                Rp = preprocess_func(Rp)
            return self.similarity_func(R, Rp, shape)
        except ValueError as e:
            log.info(f"Pipeline failed: {e}")
            return np.nan

    def __str__(self) -> str:
        def func_name(func: Callable) -> str:
            return func.__name__ if not isinstance(func, functools.partial) else func.func.__name__

        def partial_keywords(func: Callable) -> str:
            if not isinstance(func, functools.partial):
                return ""
            else:
                return str(func.keywords)

        return (
            "Pipeline("
            + (
                "+".join(map(func_name, self.preprocess_funcs))
                + "+"
                + func_name(self.similarity_func)
                + partial_keywords(self.similarity_func)
            )
            + ")"
        )


def flatten(*args: Union[torch.Tensor, npt.NDArray], shape: SHAPE_TYPE) -> List[Union[torch.Tensor, npt.NDArray]]:
    if shape == "ntd":
        return list(map(flatten_nxtxd_to_ntxd, args))
    elif shape == "nd":
        return list(args)
    elif shape == "nchw":
        return list(map(flatten_nxcxhxw_to_nxchw, args))  # Flattening non-trivial for nchw
    else:
        raise ValueError("Unknown shape of representations. Must be one of 'ntd', 'nchw', 'nd'.")


def flatten_nxtxd_to_ntxd(R: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
    R = to_torch_if_needed(R)[0]
    log.debug("Shape before flattening: %s", str(R.shape))
    R = torch.flatten(R, start_dim=0, end_dim=1)
    log.debug("Shape after flattening: %s", str(R.shape))
    return R


def flatten_nxcxhxw_to_nxchw(R: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
    R = to_torch_if_needed(R)[0]
    log.debug("Shape before flattening: %s", str(R.shape))
    R = torch.reshape(R, (R.shape[0], -1))
    log.debug("Shape after flattening: %s", str(R.shape))
    return R

In [None]:
import scipy.optimize

def permutation_procrustes(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    shape: SHAPE_TYPE,
    optimal_permutation_alignment: Optional[Tuple[npt.NDArray, npt.NDArray]] = None,
) -> float:
    R, Rp = flatten(R, Rp, shape=shape)
    R, Rp = to_numpy_if_needed(R, Rp)
    R, Rp = adjust_dimensionality(R, Rp)

    if not optimal_permutation_alignment:
        PR, PRp = scipy.optimize.linear_sum_assignment(R.T @ Rp, maximize=True)  # returns column assignments
        optimal_permutation_alignment = (PR, PRp)
    PR, PRp = optimal_permutation_alignment
    return float(np.linalg.norm(R[:, PR] - Rp[:, PRp], ord="fro"))

In [None]:
from typing import Optional
from typing import Union

import numpy as np
import numpy.typing as npt
import scipy.spatial.distance
import scipy.stats
import sklearn.metrics
import torch
# from repsim.measures.utils import flatten
# from repsim.measures.utils import RSMSimilarityMeasure
# from repsim.measures.utils import SHAPE_TYPE
# from repsim.measures.utils import to_numpy_if_needed


def representational_similarity_analysis(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    shape: SHAPE_TYPE,
    inner="correlation",
    outer="spearman",
    n_jobs: Optional[int] = None,
) -> float:
    """Representational similarity analysis

    Args:
        R (Union[torch.Tensor, npt.NDArray]): N x D representation
        Rp (Union[torch.Tensor, npt.NDArray]): N x D' representation
        inner (str, optional): inner similarity function for RSM. Must be one of
            scipy.spatial.distance.pdist identifiers . Defaults to "correlation".
        outer (str, optional): outer similarity function that compares RSMs. Defaults to
             "spearman". Must be one of "spearman", "euclidean"

    Returns:
        float: _description_
    """
    R, Rp = flatten(R, Rp, shape=shape)
    R, Rp = to_numpy_if_needed(R, Rp)

    if inner == "correlation":
        # n_jobs only works if metric is in PAIRWISE_DISTANCES as defined in sklearn, i.e., not for correlation.
        # But correlation = 1 - cosine dist of row-centered data, so we use the faster cosine metric and center the data.
        R = R - R.mean(axis=1, keepdims=True)
        S = scipy.spatial.distance.squareform(  # take the lower triangle of RSM
            1 - sklearn.metrics.pairwise_distances(R, metric="cosine", n_jobs=n_jobs),  # type:ignore
            checks=False,
        )
        Rp = Rp - Rp.mean(axis=1, keepdims=True)
        Sp = scipy.spatial.distance.squareform(
            1 - sklearn.metrics.pairwise_distances(Rp, metric="cosine", n_jobs=n_jobs),  # type:ignore
            checks=False,
        )
    elif inner == "euclidean":
        # take the lower triangle of RSM
        S = scipy.spatial.distance.squareform(
            sklearn.metrics.pairwise_distances(R, metric=inner, n_jobs=n_jobs), checks=False
        )
        Sp = scipy.spatial.distance.squareform(
            sklearn.metrics.pairwise_distances(Rp, metric=inner, n_jobs=n_jobs), checks=False
        )
    else:
        raise NotImplementedError(f"{inner=}")

    if outer == "spearman":
        return scipy.stats.spearmanr(S, Sp).statistic  # type:ignore
    elif outer == "euclidean":
        return float(np.linalg.norm(S - Sp, ord=2))
    else:
        raise ValueError(f"Unknown outer similarity function: {outer}")


class RSA(RSMSimilarityMeasure):
    def __init__(self):
        # choice of inner/outer in __call__ if fixed to default values, so these values are always the same
        super().__init__(
            sim_func=representational_similarity_analysis,
            larger_is_more_similar=True,
            is_metric=False,
            is_symmetric=True,
            invariant_to_affine=False,
            invariant_to_invertible_linear=False,
            invariant_to_ortho=False,
            invariant_to_permutation=True,
            invariant_to_isotropic_scaling=True,
            invariant_to_translation=True,
        )

In [None]:
##################################################################################
# Copied from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/cca_core.py
# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The core code for applying Canonical Correlation Analysis to deep networks.

This module contains the core functions to apply canonical correlation analysis
to deep neural networks. The main function is get_cca_similarity, which takes in
two sets of activations, typically the neurons in two layers and their outputs
on all of the datapoints D = [d_1,...,d_m] that have been passed through.

Inputs have shape (num_neurons1, m), (num_neurons2, m). This can be directly
applied used on fully connected networks. For convolutional layers, the 3d block
of neurons can either be flattened entirely, along channels, or alternatively,
the dft_ccas (Discrete Fourier Transform) module can be used.

See:
https://arxiv.org/abs/1706.05806
https://arxiv.org/abs/1806.05759
for full details.

"""
import numpy as np
# from repsim.measures.utils import align_spatial_dimensions

num_cca_trials = 5


def positivedef_matrix_sqrt(array):
    """Stable method for computing matrix square roots, supports complex matrices.

    Args:
              array: A numpy 2d array, can be complex valued that is a positive
                     definite symmetric (or hermitian) matrix

    Returns:
              sqrtarray: The matrix square root of array
    """
    w, v = np.linalg.eigh(array)
    #  A - np.dot(v, np.dot(np.diag(w), v.T))
    wsqrt = np.sqrt(w)
    sqrtarray = np.dot(v, np.dot(np.diag(wsqrt), np.conj(v).T))
    return sqrtarray


def remove_small(sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon):
    """Takes covariance between X, Y, and removes values of small magnitude.

    Args:
              sigma_xx: 2d numpy array, variance matrix for x
              sigma_xy: 2d numpy array, crossvariance matrix for x,y
              sigma_yx: 2d numpy array, crossvariance matrixy for x,y,
                        (conjugate) transpose of sigma_xy
              sigma_yy: 2d numpy array, variance matrix for y
              epsilon : cutoff value for norm below which directions are thrown
                         away

    Returns:
              sigma_xx_crop: 2d array with low x norm directions removed
              sigma_xy_crop: 2d array with low x and y norm directions removed
              sigma_yx_crop: 2d array with low x and y norm directiosn removed
              sigma_yy_crop: 2d array with low y norm directions removed
              x_idxs: indexes of sigma_xx that were removed
              y_idxs: indexes of sigma_yy that were removed
    """

    x_diag = np.abs(np.diagonal(sigma_xx))
    y_diag = np.abs(np.diagonal(sigma_yy))
    x_idxs = x_diag >= epsilon
    y_idxs = y_diag >= epsilon

    sigma_xx_crop = sigma_xx[x_idxs][:, x_idxs]
    sigma_xy_crop = sigma_xy[x_idxs][:, y_idxs]
    sigma_yx_crop = sigma_yx[y_idxs][:, x_idxs]
    sigma_yy_crop = sigma_yy[y_idxs][:, y_idxs]

    return (sigma_xx_crop, sigma_xy_crop, sigma_yx_crop, sigma_yy_crop, x_idxs, y_idxs)


def compute_ccas(sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon, verbose=True):
    """Main cca computation function, takes in variances and crossvariances.

    This function takes in the covariances and cross covariances of X, Y,
    preprocesses them (removing small magnitudes) and outputs the raw results of
    the cca computation, including cca directions in a rotated space, and the
    cca correlation coefficient values.

    Args:
              sigma_xx: 2d numpy array, (num_neurons_x, num_neurons_x)
                        variance matrix for x
              sigma_xy: 2d numpy array, (num_neurons_x, num_neurons_y)
                        crossvariance matrix for x,y
              sigma_yx: 2d numpy array, (num_neurons_y, num_neurons_x)
                        crossvariance matrix for x,y (conj) transpose of sigma_xy
              sigma_yy: 2d numpy array, (num_neurons_y, num_neurons_y)
                        variance matrix for y
              epsilon:  small float to help with stabilizing computations
              verbose:  boolean on whether to print intermediate outputs

    Returns:
              [ux, sx, vx]: [numpy 2d array, numpy 1d array, numpy 2d array]
                            ux and vx are (conj) transposes of each other, being
                            the canonical directions in the X subspace.
                            sx is the set of canonical correlation coefficients-
                            how well corresponding directions in vx, Vy correlate
                            with each other.
              [uy, sy, vy]: Same as above, but for Y space
              invsqrt_xx:   Inverse square root of sigma_xx to transform canonical
                            directions back to original space
              invsqrt_yy:   Same as above but for sigma_yy
              x_idxs:       The indexes of the input sigma_xx that were pruned
                            by remove_small
              y_idxs:       Same as above but for sigma_yy
    """

    (sigma_xx, sigma_xy, sigma_yx, sigma_yy, x_idxs, y_idxs) = remove_small(
        sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon
    )

    numx = sigma_xx.shape[0]
    numy = sigma_yy.shape[0]

    if numx == 0 or numy == 0:
        return (
            [0, 0, 0],
            [0, 0, 0],
            np.zeros_like(sigma_xx),
            np.zeros_like(sigma_yy),
            x_idxs,
            y_idxs,
        )

    if verbose:
        print("adding eps to diagonal and taking inverse")
    sigma_xx += epsilon * np.eye(numx)
    sigma_yy += epsilon * np.eye(numy)
    inv_xx = np.linalg.pinv(sigma_xx)
    inv_yy = np.linalg.pinv(sigma_yy)

    if verbose:
        print("taking square root")
    invsqrt_xx = positivedef_matrix_sqrt(inv_xx)
    invsqrt_yy = positivedef_matrix_sqrt(inv_yy)

    if verbose:
        print("dot products...")
    arr = np.dot(invsqrt_xx, np.dot(sigma_xy, invsqrt_yy))

    if verbose:
        print("trying to take final svd")
    u, s, v = np.linalg.svd(arr)

    if verbose:
        print("computed everything!")

    return [u, np.abs(s), v], invsqrt_xx, invsqrt_yy, x_idxs, y_idxs


def sum_threshold(array, threshold):
    """Computes threshold index of decreasing nonnegative array by summing.

    This function takes in a decreasing array nonnegative floats, and a
    threshold between 0 and 1. It returns the index i at which the sum of the
    array up to i is threshold*total mass of the array.

    Args:
              array: a 1d numpy array of decreasing, nonnegative floats
              threshold: a number between 0 and 1

    Returns:
              i: index at which np.sum(array[:i]) >= threshold
    """
    assert (threshold >= 0) and (threshold <= 1), "print incorrect threshold"

    for i in range(len(array)):
        if np.sum(array[:i]) / np.sum(array) >= threshold:
            return i


def create_zero_dict(compute_dirns, dimension):
    """Outputs a zero dict when neuron activation norms too small.

    This function creates a return_dict with appropriately shaped zero entries
    when all neuron activations are very small.

    Args:
              compute_dirns: boolean, whether to have zero vectors for directions
              dimension: int, defines shape of directions

    Returns:
              return_dict: a dict of appropriately shaped zero entries
    """
    return_dict = {}
    return_dict["mean"] = (np.asarray(0), np.asarray(0))
    return_dict["sum"] = (np.asarray(0), np.asarray(0))
    return_dict["cca_coef1"] = np.asarray(0)
    return_dict["cca_coef2"] = np.asarray(0)
    return_dict["idx1"] = 0
    return_dict["idx2"] = 0

    if compute_dirns:
        return_dict["cca_dirns1"] = np.zeros((1, dimension))
        return_dict["cca_dirns2"] = np.zeros((1, dimension))

    return return_dict


def get_cca_similarity(
    acts1,
    acts2,
    epsilon=0.0,
    threshold=0.98,
    compute_coefs=True,
    compute_dirns=False,
    verbose=True,
):
    """The main function for computing cca similarities.

    This function computes the cca similarity between two sets of activations,
    returning a dict with the cca coefficients, a few statistics of the cca
    coefficients, and (optionally) the actual directions.

    Args:
              acts1: (num_neurons1, data_points) a 2d numpy array of neurons by
                     datapoints where entry (i,j) is the output of neuron i on
                     datapoint j.
              acts2: (num_neurons2, data_points) same as above, but (potentially)
                     for a different set of neurons. Note that acts1 and acts2
                     can have different numbers of neurons, but must agree on the
                     number of datapoints

              epsilon: small float to help stabilize computations

              threshold: float between 0, 1 used to get rid of trailing zeros in
                         the cca correlation coefficients to output more accurate
                         summary statistics of correlations.


              compute_coefs: boolean value determining whether coefficients
                             over neurons are computed. Needed for computing
                             directions

              compute_dirns: boolean value determining whether actual cca
                             directions are computed. (For very large neurons and
                             datasets, may be better to compute these on the fly
                             instead of store in memory.)

              verbose: Boolean, whether intermediate outputs are printed

    Returns:
              return_dict: A dictionary with outputs from the cca computations.
                           Contains neuron coefficients (combinations of neurons
                           that correspond to cca directions), the cca correlation
                           coefficients (how well aligned directions correlate),
                           x and y idxs (for computing cca directions on the fly
                           if compute_dirns=False), and summary statistics. If
                           compute_dirns=True, the cca directions are also
                           computed.
    """

    # assert dimensionality equal
    assert acts1.shape[1] == acts2.shape[1], "dimensions don't match"
    # check that acts1, acts2 are transposition
    assert acts1.shape[0] < acts1.shape[1], "input must be number of neurons" "by datapoints"
    return_dict = {}

    # compute covariance with numpy function for extra stability
    numx = acts1.shape[0]
    numy = acts2.shape[0]

    covariance = np.cov(acts1, acts2)
    sigmaxx = covariance[:numx, :numx]
    sigmaxy = covariance[:numx, numx:]
    sigmayx = covariance[numx:, :numx]
    sigmayy = covariance[numx:, numx:]

    # rescale covariance to make cca computation more stable
    xmax = np.max(np.abs(sigmaxx))
    ymax = np.max(np.abs(sigmayy))
    sigmaxx /= xmax
    sigmayy /= ymax
    sigmaxy /= np.sqrt(xmax * ymax)
    sigmayx /= np.sqrt(xmax * ymax)

    ([u, s, v], invsqrt_xx, invsqrt_yy, x_idxs, y_idxs) = compute_ccas(
        sigmaxx, sigmaxy, sigmayx, sigmayy, epsilon=epsilon, verbose=verbose
    )

    # if x_idxs or y_idxs is all false, return_dict has zero entries
    if (not np.any(x_idxs)) or (not np.any(y_idxs)):
        return create_zero_dict(compute_dirns, acts1.shape[1])

    if compute_coefs:
        # also compute full coefficients over all neurons
        x_mask = np.dot(x_idxs.reshape((-1, 1)), x_idxs.reshape((1, -1)))
        y_mask = np.dot(y_idxs.reshape((-1, 1)), y_idxs.reshape((1, -1)))

        return_dict["coef_x"] = u.T
        return_dict["invsqrt_xx"] = invsqrt_xx
        return_dict["full_coef_x"] = np.zeros((numx, numx))
        np.place(return_dict["full_coef_x"], x_mask, return_dict["coef_x"])
        return_dict["full_invsqrt_xx"] = np.zeros((numx, numx))
        np.place(return_dict["full_invsqrt_xx"], x_mask, return_dict["invsqrt_xx"])

        return_dict["coef_y"] = v
        return_dict["invsqrt_yy"] = invsqrt_yy
        return_dict["full_coef_y"] = np.zeros((numy, numy))
        np.place(return_dict["full_coef_y"], y_mask, return_dict["coef_y"])
        return_dict["full_invsqrt_yy"] = np.zeros((numy, numy))
        np.place(return_dict["full_invsqrt_yy"], y_mask, return_dict["invsqrt_yy"])

        # compute means
        neuron_means1 = np.mean(acts1, axis=1, keepdims=True)
        neuron_means2 = np.mean(acts2, axis=1, keepdims=True)
        return_dict["neuron_means1"] = neuron_means1
        return_dict["neuron_means2"] = neuron_means2

    if compute_dirns:
        # orthonormal directions that are CCA directions
        cca_dirns1 = (
            np.dot(
                np.dot(return_dict["full_coef_x"], return_dict["full_invsqrt_xx"]),
                (acts1 - neuron_means1),
            )
            + neuron_means1
        )
        cca_dirns2 = (
            np.dot(
                np.dot(return_dict["full_coef_y"], return_dict["full_invsqrt_yy"]),
                (acts2 - neuron_means2),
            )
            + neuron_means2
        )

    # get rid of trailing zeros in the cca coefficients
    idx1 = sum_threshold(s, threshold)
    idx2 = sum_threshold(s, threshold)

    return_dict["cca_coef1"] = s
    return_dict["cca_coef2"] = s
    return_dict["x_idxs"] = x_idxs
    return_dict["y_idxs"] = y_idxs
    # summary statistics
    return_dict["mean"] = (np.mean(s[:idx1]), np.mean(s[:idx2]))
    return_dict["sum"] = (np.sum(s), np.sum(s))

    if compute_dirns:
        return_dict["cca_dirns1"] = cca_dirns1
        return_dict["cca_dirns2"] = cca_dirns2

    return return_dict


def robust_cca_similarity(acts1, acts2, threshold=0.98, epsilon=1e-6, compute_dirns=True):
    """Calls get_cca_similarity multiple times while adding noise.

    This function is very similar to get_cca_similarity, and can be used if
    get_cca_similarity doesn't converge for some pair of inputs. This function
    adds some noise to the activations to help convergence.

    Args:
              acts1: (num_neurons1, data_points) a 2d numpy array of neurons by
                     datapoints where entry (i,j) is the output of neuron i on
                     datapoint j.
              acts2: (num_neurons2, data_points) same as above, but (potentially)
                     for a different set of neurons. Note that acts1 and acts2
                     can have different numbers of neurons, but must agree on the
                     number of datapoints

              threshold: float between 0, 1 used to get rid of trailing zeros in
                         the cca correlation coefficients to output more accurate
                         summary statistics of correlations.

              epsilon: small float to help stabilize computations

              compute_dirns: boolean value determining whether actual cca
                             directions are computed. (For very large neurons and
                             datasets, may be better to compute these on the fly
                             instead of store in memory.)

    Returns:
              return_dict: A dictionary with outputs from the cca computations.
                           Contains neuron coefficients (combinations of neurons
                           that correspond to cca directions), the cca correlation
                           coefficients (how well aligned directions correlate),
                           x and y idxs (for computing cca directions on the fly
                           if compute_dirns=False), and summary statistics. If
                           compute_dirns=True, the cca directions are also
                           computed.
    """

    for trial in range(num_cca_trials):
        try:
            return_dict = get_cca_similarity(acts1, acts2, threshold, compute_dirns)
        except np.linalg.LinAlgError:
            acts1 = acts1 * 1e-1 + np.random.normal(size=acts1.shape) * epsilon
            acts2 = acts2 * 1e-1 + np.random.normal(size=acts1.shape) * epsilon
            if trial + 1 == num_cca_trials:
                raise

    return return_dict
    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/cca_core.py


def top_k_pca_comps(singular_values, threshold=0.99):
    total_variance = np.sum(singular_values**2)
    explained_variance = (singular_values**2) / total_variance
    cumulative_variance = np.cumsum(explained_variance)
    return np.argmax(cumulative_variance >= threshold * total_variance) + 1


def _svcca_original(acts1, acts2):
    # Copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/tutorials/001_Introduction.ipynb
    # Modification: get_cca_similarity is in the same file.
    # Modification: top-k PCA component selection s.t. explained variance > 0.99 total variance
    # Mean subtract activations
    cacts1 = acts1 - np.mean(acts1, axis=1, keepdims=True)
    cacts2 = acts2 - np.mean(acts2, axis=1, keepdims=True)

    # Perform SVD
    U1, s1, V1 = np.linalg.svd(cacts1, full_matrices=False)
    U2, s2, V2 = np.linalg.svd(cacts2, full_matrices=False)

    # top-k PCA components only
    k1 = top_k_pca_comps(s1)
    k2 = top_k_pca_comps(s2)

    svacts1 = np.dot(s1[:k1] * np.eye(k1), V1[:k1])
    # can also compute as svacts1 = np.dot(U1.T[:20], cacts1)
    svacts2 = np.dot(s2[:k2] * np.eye(k2), V2[:k2])
    # can also compute as svacts1 = np.dot(U2.T[:20], cacts2)

    svcca_results = get_cca_similarity(svacts1, svacts2, epsilon=1e-10, verbose=False)
    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/tutorials/001_Introduction.ipynb
    return np.mean(svcca_results["cca_coef1"])


# Copied from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/pwcca.py
# Modification: get_cca_similarity is in the same file.
def compute_pwcca(acts1, acts2, epsilon=0.0):
    """Computes projection weighting for weighting CCA coefficients

    Args:
         acts1: 2d numpy array, shaped (neurons, num_datapoints)
         acts2: 2d numpy array, shaped (neurons, num_datapoints)

    Returns:
         Original cca coefficient mean and weighted mean

    """
    sresults = get_cca_similarity(
        acts1,
        acts2,
        epsilon=epsilon,
        compute_dirns=False,
        compute_coefs=True,
        verbose=False,
    )
    if np.sum(sresults["x_idxs"]) <= np.sum(sresults["y_idxs"]):
        dirns = (
            np.dot(
                sresults["coef_x"],
                (acts1[sresults["x_idxs"]] - sresults["neuron_means1"][sresults["x_idxs"]]),
            )
            + sresults["neuron_means1"][sresults["x_idxs"]]
        )
        coefs = sresults["cca_coef1"]
        acts = acts1
        idxs = sresults["x_idxs"]
    else:
        dirns = (
            np.dot(
                sresults["coef_y"],
                (acts1[sresults["y_idxs"]] - sresults["neuron_means2"][sresults["y_idxs"]]),
            )
            + sresults["neuron_means2"][sresults["y_idxs"]]
        )
        coefs = sresults["cca_coef2"]
        acts = acts2
        idxs = sresults["y_idxs"]
    P, _ = np.linalg.qr(dirns.T)
    weights = np.sum(np.abs(np.dot(P.T, acts[idxs].T)), axis=1)
    weights = weights / np.sum(weights)

    return np.sum(weights * coefs), weights, coefs
    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/pwcca.py


##################################################################################

from typing import Union  # noqa:e402

import numpy.typing as npt  # noqa:e402
import torch  # noqa:e402

# from repsim.measures.utils import (
#     SHAPE_TYPE,
#     flatten,
#     resize_wh_reps,
#     to_numpy_if_needed,
#     RepresentationalSimilarityMeasure,
# )  # noqa:e402


def svcca(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    shape: SHAPE_TYPE,
) -> float:
    R, Rp = flatten(R, Rp, shape=shape)
    R, Rp = to_numpy_if_needed(R, Rp)
    return _svcca_original(R.T, Rp.T)


def pwcca(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    shape: SHAPE_TYPE,
) -> float:
    R, Rp = flatten(R, Rp, shape=shape)
    R, Rp = to_numpy_if_needed(R, Rp)
    return compute_pwcca(R.T, Rp.T)[0]


class SVCCA(RepresentationalSimilarityMeasure):
    def __init__(self):
        super().__init__(
            sim_func=svcca,
            larger_is_more_similar=True,
            is_metric=False,
            is_symmetric=True,
            invariant_to_affine=False,
            invariant_to_invertible_linear=False,
            invariant_to_ortho=True,
            invariant_to_permutation=True,
            invariant_to_isotropic_scaling=True,
            invariant_to_translation=True,
        )

    def __call__(self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE) -> float:
        if shape == "nchw":
            # Move spatial dimensions into the sample dimension
            # If not the same spatial dimension, resample via FFT.
            R, Rp = align_spatial_dimensions(R, Rp)
            shape = "nd"

        return self.sim_func(R, Rp, shape)


class PWCCA(RepresentationalSimilarityMeasure):
    def __init__(self):
        super().__init__(
            sim_func=pwcca,
            larger_is_more_similar=True,
            is_metric=False,
            is_symmetric=False,
            invariant_to_affine=False,
            invariant_to_invertible_linear=False,
            invariant_to_ortho=False,
            invariant_to_permutation=False,
            invariant_to_isotropic_scaling=True,
            invariant_to_translation=True,
        )

    def __call__(self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE) -> float:
        if shape == "nchw":
            # Move spatial dimensions into the sample dimension
            # If not the same spatial dimension, resample via FFT.
            R, Rp = align_spatial_dimensions(R, Rp)
            shape = "nd"

        return self.sim_func(R, Rp, shape)

## get rand

In [None]:
def score_rand(num_runs, weight_matrix_np, weight_matrix_2, num_feats, sim_fn, shapereq_bool):
    all_rand_scores = []
    i = 0
    # for i in range(num_runs):
    while i < num_runs:
        try:
            rand_modA_feats = np.random.choice(range(weight_matrix_np.shape[0]), size=num_feats, replace=False).tolist()
            rand_modB_feats = np.random.choice(range(weight_matrix_2.shape[0]), size=num_feats, replace=False).tolist()

            if shapereq_bool:
                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats], "nd")
            else:
                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats])
            all_rand_scores.append(score)
            i += 1
        except:
            continue
    return sum(all_rand_scores) / len(all_rand_scores)

In [None]:
def score_rand_corr(num_runs, weight_matrix_np, weight_matrix_2, num_feats, highest_correlations_indices_AB, sim_fn, shapereq_bool):
    all_rand_scores = []
    i = 0
    # for i in range(num_runs):
    while i < num_runs:
        try:
            rand_modB_feats = np.random.choice(range(weight_matrix_2.shape[0]), size=num_feats, replace=False).tolist()
            rand_modA_feats = [highest_correlations_indices_AB[index] for index in rand_modB_feats]

            if shapereq_bool:
                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats], "nd")
            else:
                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats])
            all_rand_scores.append(score)
            i += 1
        except:
            continue
    # print(sum(all_rand_scores) / len(all_rand_scores))
    # plt.hist(all_rand_scores)
    # plt.show()
    return all_rand_scores

In [None]:
import random
def shuffle_rand(num_runs, weight_matrix_np, weight_matrix_2, num_feats, sim_fn, shapereq_bool):
    all_rand_scores = []
    for i in range(num_runs):
        row_idxs = list(range(num_feats))
        random.shuffle(row_idxs)
        if shapereq_bool:
            score = sim_fn(weight_matrix_np, weight_matrix_2[row_idxs], "nd")
        else:
            score = sim_fn(weight_matrix_np, weight_matrix_2[row_idxs])
        all_rand_scores.append(score)
    # return sum(all_rand_scores) / len(all_rand_scores)
    return all_rand_scores

## plot fns

In [None]:
def plot_svcca_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['svcca_paired'] for i in range(0, 12)]
    unpaired_values = [layer_to_dictscores[i]['svcca_unpaired'] for i in range(0, 12)]

    # Plotting configuration
    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    # Increase figure size
    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')
    rects2 = ax.bar(x + width/2, unpaired_values, width, label='Unpaired')

    # Adding labels, title and custom x-axis tick labels
    ax.set_ylabel('SVCCA')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    # Rotate labels and adjust padding
    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            # if i % 2 == 0:  # Label every other bar above
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12) # , rotation=90
            # else:  # Label every other bar below
            #     ax.text(rect.get_x() + rect.get_width()/2., 0,
            #             f'{height:.3f}',
            #             ha='center', va='top', rotation=90)

    label_bars(rects1)
    label_bars(rects2)

    # Adjust layout to prevent cutting off labels
    plt.tight_layout()

    plt.show()

In [None]:
def plot_rsa_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['rsa_paired'] for i in range(0, 12)]
    unpaired_values = [layer_to_dictscores[i]['rsa_unpaired'] for i in range(0, 12)]

    # Plotting configuration
    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    # Increase figure size
    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')
    rects2 = ax.bar(x + width/2, unpaired_values, width, label='Unpaired')

    # Adding labels, title and custom x-axis tick labels
    ax.set_ylabel('RSA')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    # Rotate labels and adjust padding
    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            # if i % 2 == 0:  # Label every other bar above
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12) # , rotation=90
            # else:  # Label every other bar below
            #     ax.text(rect.get_x() + rect.get_width()/2., 0,
            #             f'{height:.3f}',
            #             ha='center', va='top', rotation=90)

    label_bars(rects1)
    label_bars(rects2)

    # Adjust layout to prevent cutting off labels
    plt.tight_layout()

    plt.show()

In [None]:
def plot_meanCorr_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['mean_actv_corr'] for i in range(0, 12)]
    # unpaired_values = [layer_to_dictscores[i]['svcca_unpaired'] for i in range(0, 12)]

    # Plotting configuration
    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    # Increase figure size
    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')
    # rects2 = ax.bar(x + width/2, unpaired_values, width, label='Unpaired')

    # Adding labels, title and custom x-axis tick labels
    ax.set_ylabel('Corr')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    # Rotate labels and adjust padding
    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            # if i % 2 == 0:  # Label every other bar above
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12) # , rotation=90
            # else:  # Label every other bar below
            #     ax.text(rect.get_x() + rect.get_width()/2., 0,
            #             f'{height:.3f}',
            #             ha='center', va='top', rotation=90)

    label_bars(rects1)
    # label_bars(rects2)

    # Adjust layout to prevent cutting off labels
    plt.tight_layout()

    plt.show()

In [None]:
def plot_meanCorr_filt_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['mean_actv_corr_filt'] for i in range(0, 12)]

    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')

    ax.set_ylabel('Corr')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12)

    label_bars(rects1)
    plt.tight_layout()

    plt.show()

In [None]:
def plot_numFeats_afterFilt_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['num_feat_filt'] for i in range(0, 12)]

    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')

    ax.set_ylabel('Num Feats Kept')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12)

    label_bars(rects1)
    plt.tight_layout()

    plt.show()

In [None]:
# def plot_js_byLayer(layer_to_dictscores):
#     for key, sub_dict in layer_to_dictscores.items():
#         for sub_key, value in sub_dict.items():
#             sub_dict[sub_key] = round(value, 4)

#     layers = [f'L{i}' for i in range(0, 12)]
#     paired_values = [layer_to_dictscores[i]['1-1 jaccard_paired'] for i in range(0, 12)]
#     unpaired_values = [layer_to_dictscores[i]['1-1 jaccard_unpaired'] for i in range(0, 12)]

#     # Plotting configuration
#     x = np.arange(len(layers))  # label locations
#     width = 0.35  # width of the bars

#     # Increase figure size
#     fig, ax = plt.subplots(figsize=(12, 7))  # Slightly increased figure size
#     rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')
#     rects2 = ax.bar(x + width/2, unpaired_values, width, label='Unpaired')

#     # Adding labels, title and custom x-axis tick labels
#     ax.set_ylabel('Jaccard NN')
#     ax.set_title('SAEs comparison by Pythia 70m MLP0 vs 160m MLP Layers')
#     ax.set_xticks(x)
#     ax.set_xticklabels(layers)
#     ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
#     ax.legend()

#     # Label bars with increased font size and different positioning for paired and unpaired
#     def label_bars(rects, is_paired):
#         for rect in rects:
#             height = rect.get_height()
#             label_height = height + 0.05 if is_paired else height + 0.01
#             ax.text(rect.get_x() + rect.get_width()/2., label_height,
#                     f'{height:.3f}',
#                     ha='center', va='bottom', fontsize=9)  # Increased font size to 12

#     label_bars(rects1, True)   # Paired bars
#     label_bars(rects2, False)  # Unpaired bars

#     # Adjust layout to prevent cutting off labels
#     plt.tight_layout()

#     # Increase y-axis limit to accommodate higher labels
#     ax.set_ylim(0, 1.1)  # Increased from 1 to 1.1

#     plt.show()

## interpret fns

In [None]:
def highest_activating_tokens(
    feature_acts,
    feature_idx: int,
    k: int = 10,  # num batch_seq samples
    batch_tokens=None
): # -> Tuple[Int[Tensor, "k 2"], Float[Tensor, "k"]]:
    '''
    Returns the indices & values for the highest-activating tokens in the given batch of data.
    '''
    batch_size, seq_len = batch_tokens.shape

    # Get the top k largest activations for only targeted feature
    # need to flatten (batch,seq) into batch*seq first because it's ANY batch_seq, even if in same batch or same pos
    flattened_feature_acts = feature_acts[:, :, feature_idx].reshape(-1)

    top_acts_values, top_acts_indices = flattened_feature_acts.topk(k)
    # top_acts_values should be 1D
    # top_acts_indices should be also be 1D. Now, turn it back to 2D
    # Convert the indices into (batch, seq) indices
    top_acts_batch = top_acts_indices // seq_len
    top_acts_seq = top_acts_indices % seq_len

    return torch.stack([top_acts_batch, top_acts_seq], dim=-1), top_acts_values

In [None]:
def store_top_toks(top_acts_indices, top_acts_values, batch_tokens):
    feat_samps = []
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        new_str_token = tokenizer.decode(batch_tokens[batch_idx, seq_idx]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
        feat_samps.append(new_str_token)
    return feat_samps

In [None]:
def find_indices_with_keyword(fList, keyword):
    """
    Find all indices of fList which contain the keyword in the string at those indices.

    Args:
    fList (list of str): List of strings to search within.
    keyword (str): Keyword to search for within the strings of fList.

    Returns:
    list of int: List of indices where the keyword is found within the strings of fList.
    """
    index_list = []
    for index, split_list in enumerate(fList):
        no_space_list = [i.replace(' ', '').lower() for i in split_list]
        for tok in no_space_list:
            if keyword.lower() == tok:
                index_list.append(index)
        # if keyword in no_space_list:
            # index_list.append(index)
    return index_list

In [None]:
from rich import print as rprint
def display_top_sequences(top_acts_indices, top_acts_values, batch_tokens):
    s = ""
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        # s += f'{batch_idx}\n'
        s += f'batchID: {batch_idx}, '
        # Get the sequence as a string (with some padding on either side of our sequence)
        seq_start = max(seq_idx - 5, 0)
        seq_end = min(seq_idx + 5, batch_tokens.shape[1])
        seq = ""
        # Loop over the sequence, adding each token to the string (highlighting the token with the large activations)
        for i in range(seq_start, seq_end):
            # new_str_token = model.to_single_str_token(batch_tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            new_str_token = tokenizer.decode([batch_tokens[batch_idx, i].item()]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            if i == seq_idx:
                new_str_token = f"[bold u dark_orange]{new_str_token}[/]"
            seq += new_str_token
        # Print the sequence, and the activation value
        s += f'Act = {value:.2f}, Seq = "{seq}"\n'

    rprint(s)

In [None]:
# def store_top_seqs(top_acts_indices, top_acts_values, batch_tokens):
#     feat_samps = []
#     for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
#         # Get the sequence as a string (with some padding on either side of our sequence)
#         seq_start = max(seq_idx - 5, 0)
#         seq_end = min(seq_idx + 5, batch_tokens.shape[1])
#         seq = ""
#         # Loop over the sequence, adding each token to the string (highlighting the token with the large activations)
#         for i in range(seq_start, seq_end):
#             # new_str_token = model.to_single_str_token(batch_tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
#             new_str_token = tokenizer.decode([batch_tokens[batch_idx, i].item()]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
#             # if i == seq_idx:
#             #     new_str_token = f"[bold u dark_orange]{new_str_token}[/]"
#             seq += new_str_token
#         feat_samps.append(seq)
#     return feat_samps

In [None]:
# def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
#     feat_list = []
#     for feat_ind, top_seqs_lst in enumerate(fList_seqs):
#         for seq in top_seqs_lst:
#             split_list = seq.split(' ')
#             flag = False
#             for tok in split_list:
#                 if keyword.lower() == tok:
#                     feat_list.append(feat_ind)
#                     flag = True
#                     break
#             if flag:
#                 break
#     return feat_list

In [None]:
# def store_top_seqs(top_acts_indices, top_acts_values, batch_tokens):
#     feat_samps = []
#     for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
#         # Get the sequence with padding
#         seq_start = max(seq_idx - 2, 0)
#         seq_end = min(seq_idx + 2, batch_tokens.shape[1])
#         seq = ""
#         tokens = []
#         for i in range(seq_start, seq_end):
#             token_id = batch_tokens[batch_idx, i].item()
#             new_str_token = tokenizer.decode([token_id]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
#             tokens.append((new_str_token, i))
#             seq += new_str_token
#         # Store the sequence, tokens, and seq_idx
#         feat_samps.append((seq, tokens, seq_idx))
#     return feat_samps

# def get_word(tokens, idx_in_tokens):
#     # Initialize the word with the token at seq_idx
#     word_tokens = [tokens[idx_in_tokens][0]]
#     # Move backwards to find the start of the word
#     i = idx_in_tokens - 1
#     while i >= 0:
#         token_str, _ = tokens[i]
#         if token_str.startswith(' '):
#             break
#         word_tokens.insert(0, token_str)
#         i -= 1
#     # Move forwards to find the end of the word
#     i = idx_in_tokens + 1
#     while i < len(tokens):
#         token_str, _ = tokens[i]
#         if token_str.startswith(' '):
#             break
#         word_tokens.append(token_str)
#         i += 1
#     # Reconstruct the word and remove any spaces
#     word = ''.join(word_tokens).replace(' ', '')
#     return word


# def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
#     feat_list = []
#     for feat_ind, top_seqs_lst in enumerate(fList_seqs):
#         for seq, tokens, seq_idx in top_seqs_lst:
#             # Find the position of seq_idx in tokens
#             idx_in_tokens = None
#             for i, (token_str, idx) in enumerate(tokens):
#                 if idx == seq_idx:
#                     idx_in_tokens = i
#                     break
#             if idx_in_tokens is None:
#                 continue  # seq_idx not found in tokens

#             # Reconstruct the word containing seq_idx
#             word = get_word(tokens, idx_in_tokens)

#             # Compare the reconstructed word with the keyword
#             if word.lower() == keyword.lower():
#                 feat_list.append(feat_ind)
#                 break  # Proceed to the next feature index
#     return feat_list


In [None]:
def store_top_seqs(top_acts_indices, top_acts_values, batch_tokens):
    feat_samps = []
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        # Get the sequence as a string (with some padding on either side of our sequence)
        seq_start = max(seq_idx - 2, 0)
        seq_end = min(seq_idx + 2, batch_tokens.shape[1])
        seq = ""
        # Loop over the sequence, adding each token to the string (highlighting the token with the large activations)
        for i in range(seq_start, seq_end):
            # new_str_token = model.to_single_str_token(batch_tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            new_str_token = tokenizer.decode([batch_tokens[batch_idx, i].item()]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            if i == seq_idx:
                topTok = new_str_token
            seq += new_str_token
        feat_samps.append( (seq, topTok) )
    return feat_samps

In [None]:
import pdb

def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
    feat_list = []
    for feat_ind, top_seqs_andToks_lst in enumerate(fList_seqs):
        for top_seqs_andToks in top_seqs_andToks_lst:
            seq = top_seqs_andToks[0]
            topTok = top_seqs_andToks[1].replace(' ', '').lower()
            if keyword.lower() != topTok:
                continue
            split_list = seq.split(' ')
            flag = False
            for word in split_list:
                word = word.replace('.', '').replace(',', '').replace('?', '').replace('!', '').replace('\\n','')

                # pdb.set_trace()
                if keyword.lower() == word:
                    feat_list.append(feat_ind)
                    flag = True
                    break
            if flag:
                break
    return feat_list

## get concept space features

In [None]:
def get_mixed_feats(fList_model_B, corr_inds, keywords):
    mixed_modA_feats = []
    mixed_modB_feats = []
    added_modA_feats = set()  # To track which modA feats have been added
    added_modB_feats = set()  # To track which modB feats have been added

    for kw in keywords:
        modB_feats = find_indices_with_keyword(fList_model_B, kw)
        for index in modB_feats:
            modA_feat = corr_inds[index]
            modB_feat = index

            # Check if the feature has already been added to maintain uniqueness
            if modA_feat not in added_modA_feats and modB_feat not in added_modB_feats:
                mixed_modA_feats.append(modA_feat)
                mixed_modB_feats.append(modB_feat)
                added_modA_feats.add(modA_feat)
                added_modB_feats.add(modB_feat)

    print("Unique modA feats: ", len(mixed_modA_feats))
    print("Unique modB feats: ", len(mixed_modB_feats))
    return mixed_modA_feats, mixed_modB_feats

In [None]:
def get_mixed_feats_with_kwList(fList_model_B, corr_inds, keywords):
    mixed_modA_feats = []
    mixed_modB_feats = []
    added_modA_feats = set()  # To track which modA feats have been added
    added_modB_feats = set()  # To track which modB feats have been added
    keywords_to_feats = {kw: 0 for kw in keywords} # kw : count

    for kw in keywords:
        modB_feats = find_indices_with_keyword(fList_model_B, kw)
        for index in modB_feats:
            modA_feat = corr_inds[index]
            modB_feat = index

            # Check if the feature has already been added to maintain uniqueness
            if modA_feat not in added_modA_feats and modB_feat not in added_modB_feats:
                mixed_modA_feats.append(modA_feat)
                mixed_modB_feats.append(modB_feat)
                added_modA_feats.add(modA_feat)
                added_modB_feats.add(modB_feat)
                keywords_to_feats[kw] += 1

    print("Unique modA feats: ", len(mixed_modA_feats))
    print("Unique modB feats: ", len(mixed_modB_feats))
    return mixed_modA_feats, mixed_modB_feats, keywords_to_feats

## get actv fns

In [None]:
# def get_weights_and_acts(name, cfg_dict, layer_id, outputs):
def get_weights_and_acts(name, layer_id, outputs):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hookpoint = "layers." + str(layer_id)

    sae = Sae.load_from_hub(name, hookpoint=hookpoint, device=device)

    weight_matrix = sae.W_dec.cpu().detach().numpy()

    with torch.inference_mode():
        # reshaped_activations_A = sae.pre_acts(outputs.to("cuda"))
        reshaped_activations = sae.pre_acts(outputs.hidden_states[layer_id].to("cuda"))

    first_dim_reshaped = reshaped_activations.shape[0] * reshaped_activations.shape[1]
    reshaped_activations = reshaped_activations.reshape(first_dim_reshaped, reshaped_activations.shape[-1]).cpu()

    # return weight_matrix_np, reshaped_activations_A, orig
    return weight_matrix, reshaped_activations

In [None]:
def count_zero_columns(tensor):
    # Check if all elements in each column are zero
    zero_columns = np.all(tensor == 0, axis=0)
    # Count True values in the zero_columns array
    zero_cols_indices = np.where(zero_columns)[0]
    return np.sum(zero_columns), zero_cols_indices

## run expm fns

In [None]:
def run_expm(layer_id, outputs, outputs_2):
    layer_to_dictscores = {}

    name = "EleutherAI/sae-pythia-70m-32k"
    cfg_dict = {"expansion_factor": 32, "normalize_decoder": True, "num_latents": 32768, "k": 16, "d_in": 512}
    weight_matrix_np, reshaped_activations_A = get_weights_and_acts(name, cfg_dict, layer_id, outputs)
    # zero_cols_count, zero_cols_indices = count_zero_columns(reshaped_activations_A.cpu().numpy())
    # print("Number of zero columns:", zero_cols_count) #, zero_cols_indices

    name = "EleutherAI/sae-pythia-160m-32k"
    for layerID_2 in range(0, 12): # 0, 12
        dictscores = {}

        # redef
        cfg_dict = {"expansion_factor": 32, "normalize_decoder": True, "num_latents": 32768, "k": 16, "d_in": 768}
        weight_matrix_2, reshaped_activations_B = get_weights_and_acts(name, cfg_dict, layerID_2, outputs_2)


        highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(reshaped_activations_A, reshaped_activations_B)

        num_unq_pairs = len(list(set(highest_correlations_indices_AB)))
        print("% unique: ", num_unq_pairs / len(highest_correlations_indices_AB))

        dictscores["mean_actv_corr"] = sum(highest_correlations_values_AB) / len(highest_correlations_values_AB)

        ###########
        # filter

        sorted_feat_counts = Counter(highest_correlations_indices_AB).most_common()
        # kept_modA_feats = [feat_ID for feat_ID, count in sorted_feat_counts if count <= 100]
        kept_modA_feats = [feat_ID for feat_ID, count in sorted_feat_counts if count <= 1]

        filt_corr_ind_A = []
        filt_corr_ind_B = []
        seen = set()
        for ind_B, ind_A in enumerate(highest_correlations_indices_AB):
            if ind_A in kept_modA_feats:
                filt_corr_ind_A.append(ind_A)
                filt_corr_ind_B.append(ind_B)
            elif ind_A not in seen:  # only keep one if it's over count X
                seen.add(ind_A)
                filt_corr_ind_A.append(ind_A)
                filt_corr_ind_B.append(ind_B)
        # num_unq_pairs = len(list(set(filt_corr_ind_A)))
        # print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
        # print("num feats after filt: ", len(filt_corr_ind_A))

        new_highest_correlations_indices_A = []
        new_highest_correlations_indices_B = []
        new_highest_correlations_values = []

        for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
            val = highest_correlations_values_AB[ind_B]
            if val > 0:
                new_highest_correlations_indices_A.append(ind_A)
                new_highest_correlations_indices_B.append(ind_B)
                new_highest_correlations_values.append(val)

        # print("% unique after rmv 0s: ", num_unq_pairs / len(new_highest_correlations_indices_A))
        # print("num feats after rmv 0s: ", len(new_highest_correlations_indices_A))
        dictscores["num_feat_kept"] = len(new_highest_correlations_indices_A)

        dictscores["mean_actv_corr_filt"] = sum(new_highest_correlations_values) / len(new_highest_correlations_values)

        ###########
        # sim tests

        # # num_feats = len(filt_corr_ind_A)
        num_feats = len(new_highest_correlations_indices_A)

        # dictscores["svcca_paired"] = svcca(weight_matrix_np[filt_corr_ind_A], weight_matrix_2[filt_corr_ind_B], "nd")
        dictscores["svcca_paired"] = svcca(weight_matrix_np[new_highest_correlations_indices_A], weight_matrix_2[new_highest_correlations_indices_B], "nd")

        dictscores["svcca_unpaired"] = score_rand(weight_matrix_np, weight_matrix_2, num_feats,
                                                  svcca, shapereq_bool=True)

        dictscores["rsa_paired"] = representational_similarity_analysis(weight_matrix_np[new_highest_correlations_indices_A], weight_matrix_2[new_highest_correlations_indices_B], "nd")
        dictscores["rsa_unpaired"] = score_rand(weight_matrix_np, weight_matrix_2, num_feats,
                                                  representational_similarity_analysis, shapereq_bool=True)

        print("Layer: " + str(layerID_2))
        for key, value in dictscores.items():
            print(key + ": " + str(value))
        print("\n")

        layer_to_dictscores[layerID_2] = dictscores
    return layer_to_dictscores

# load data

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
tokenizer.pad_token = tokenizer.eos_token

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]



In [None]:
from datasets import load_dataset
# dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
dataset = load_dataset("Skylion007/openwebtext", split="train", streaming=True)

openwebtext.py:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.35k [00:00<?, ?B/s]

The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


In [None]:
batch_size = 100
maxseqlen = 300

def get_next_batch(dataset_iter, batch_size=100):
    batch = []
    for _ in range(batch_size):
        try:
            sample = next(dataset_iter)
            batch.append(sample['text'])
        except StopIteration:
            break
    return batch

dataset_iter = iter(dataset)
batch = get_next_batch(dataset_iter, batch_size)

# Tokenize the batch
inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=maxseqlen)

# load models

In [None]:
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
model_2 = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/375M [00:00<?, ?B/s]

## get LLM actvs

In [None]:
with torch.inference_mode():
    outputs = model(**inputs, output_hidden_states=True)
    outputs_2 = model_2(**inputs, output_hidden_states=True)

# MLP3 vs MLP5

## sae actvs

In [None]:
layer_id = 3
name = "EleutherAI/sae-pythia-70m-32k"
hookpoint = "layers." + str(layer_id)
sae = Sae.load_from_hub(name, hookpoint=hookpoint, device=device)

with torch.inference_mode():
    outputs = model(**inputs, output_hidden_states=True)
    feature_acts_A = sae.pre_acts(outputs.hidden_states[layer_id].to('cuda'))

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

layers.3/cfg.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/134M [00:00<?, ?B/s]



In [None]:
first_dim_reshaped = feature_acts_A.shape[0] * feature_acts_A.shape[1]
reshaped_activations_A = feature_acts_A.reshape(first_dim_reshaped, feature_acts_A.shape[-1]).cpu()

In [None]:
weight_matrix_np = sae.W_dec.cpu().detach().numpy()

In [None]:
layer_id_2 = 5
name = "EleutherAI/sae-pythia-160m-32k"
hookpoint = "layers." + str(layer_id_2)
sae_2 = Sae.load_from_hub(name, hookpoint=hookpoint, device=device)

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

layers.5/cfg.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/201M [00:00<?, ?B/s]



In [None]:
with torch.inference_mode():
    feature_acts_B = sae_2.pre_acts(outputs_2.hidden_states[layer_id_2].to("cuda"))

In [None]:
first_dim_reshaped = feature_acts_B.shape[0] * feature_acts_B.shape[1]
reshaped_activations_B = feature_acts_B.reshape(first_dim_reshaped, feature_acts_B.shape[-1]).cpu()

In [None]:
weight_matrix_2 = sae_2.W_dec.cpu().detach().numpy()

## get labels

In [None]:
# store feature : lst of top strs
fList_model_B = []
samp_m = 5

for feature_idx in range(feature_acts_B.shape[-1]):
    if feature_idx % 5000 == 0:
        print('Feature: ', feature_idx)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx, samp_m, batch_tokens= inputs['input_ids'])
    fList_model_B.append(store_top_toks(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids']) )

Feature:  0
Feature:  5000
Feature:  10000
Feature:  15000
Feature:  20000
Feature:  25000
Feature:  30000


In [None]:
# store feature : lst of top strs
fList_model_A = []
samp_m = 5

for feature_idx in range(feature_acts_A.shape[-1]):
    if feature_idx % 5000 == 0:
        print('Feature: ', feature_idx)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx, samp_m, batch_tokens= inputs['input_ids'])
    fList_model_A.append(store_top_toks(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids']) )

Feature:  0
Feature:  5000
Feature:  10000
Feature:  15000
Feature:  20000
Feature:  25000
Feature:  30000


## corr

In [None]:
# """
# `batched_correlation(reshaped_activations_B, reshaped_activations_A)` : highest_correlations_indices_AB contains modA's feats as inds, and modB's feats as vals. Use the list with smaller number of features (cols) as the second arg
# (reshaped_activations_A, reshaped_activations_B): modB is inds, modA is vals of highest_correlations_indices_AB
# """
# corr_inds, corr_vals = batched_correlation(reshaped_activations_A, reshaped_activations_B)

# num_unq_pairs = len(list(set(corr_inds)))
# print("% unique: ", num_unq_pairs / len(corr_inds))

# sum(corr_vals) / len(corr_vals)

% unique:  0.19976806640625


0.6982761558497259

In [None]:
# sorted_feat_counts = Counter(corr_inds).most_common()
# kept_modA_feats = [feat_ID for feat_ID, count in sorted_feat_counts if count == 1]
# len(kept_modA_feats)

4103

# match by semantic subspace- appr 2

APPR2: match by separately finding semantics from each then trying

## time

In [None]:
new_keywords = [
    "day", "night", "week", "month", "year", "hour", "minute", "second", "now", "soon",
    "later", "early", "late", "morning", "evening", "noon", "midnight", "dawn", "dusk", "past",
    "present", "future", "before", "after", "yesterday", "today", "tomorrow", "next", "previous", "soon",
    "fast", "slow", "quick", "moment", "instant", "era", "age", "decade", "century", "millennium",
    "moment", "while", "pause", "wait", "begin", "start", "end", "finish", "stop", "continue",
    "until", "since", "then", "when", "whenever", "always", "never", "forever", "constant", "frequent",
    "occasion", "season", "spring", "summer", "autumn", "fall", "winter", "anniversary", "deadline", "schedule",
    "calendar", "clock", "date", "duration", "interval", "epoch", "generation", "period", "cycle", "timespan",
    "shift", "quarter", "term", "turn", "phase", "lifetime", "century", "minute", "timeline", "delay",
    "prompt", "timely", "recurrent", "daily", "weekly", "monthly", "yearly", "annual", "biweekly", "timeframe"
]

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method (no mask, confusing inds)

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.2135102533172497
354


0.48596518441801045

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 1933])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 1658])

In [None]:
len(subset_vals)

1658

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


354

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.5169330000196789

In [None]:
len(filt_corr_ind_A)

354

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5169330000196789

### mask method (didnt work)

In [None]:
# total_features_A = reshaped_activations_A.shape[1]
# total_features_B = reshaped_activations_B.shape[1]

# mask_A = np.zeros(total_features_A, dtype=bool)
# mask_B = np.zeros(total_features_B, dtype=bool)
# mask_A[mixed_modA_feats] = True
# mask_B[mixed_modB_feats] = True

# # Mask the activations by setting non-selected features to zero
# masked_activations_A = reshaped_activations_A.clone()
# masked_activations_B = reshaped_activations_B.clone()
# masked_activations_A[:, ~mask_A] = np.nan
# masked_activations_B[:, ~mask_B] = np.nan

# subset_inds, subset_vals = batched_correlation(masked_activations_A, masked_activations_B)

In [None]:
# len(subset_inds)

32768

In [None]:
# np.count_nonzero(subset_inds==0)

32768

In [None]:
# sorted_feat_counts = Counter(subset_inds).most_common()
# kept_modA_feats = [feat_ID for feat_ID, count in sorted_feat_counts if count == 1]
# len(kept_modA_feats)

196

In [None]:
# filt_corr_ind_A = []
# filt_corr_ind_B = []
# seen = set()
# for ind_B, ind_A in enumerate(subset_inds):
#     # if ind_A == 0:
#     #     continue
#     # if ind_B in mixed_modB_feats and
#     if ind_A in kept_modA_feats:
#         filt_corr_ind_A.append(ind_A)
#         filt_corr_ind_B.append(ind_B)
#     elif ind_A not in seen:  # only keep one if it's over count X
#         seen.add(ind_A)
#         filt_corr_ind_A.append(ind_A)
#         filt_corr_ind_B.append(ind_B)
# num_unq_pairs = len(list(set(filt_corr_ind_A)))
# print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
# num_unq_pairs

% unique:  1.0


1

In [None]:
# paired_score = svcca(weight_matrix_np[filt_corr_ind_A], weight_matrix_2[filt_corr_ind_B], "nd")
# paired_score

0.5381225981550624

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.47279903292655945
Model A Feature:  32471


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.39404481649398804
Model A Feature:  27324


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.6927604079246521
Model A Feature:  27324


Model B Feature:  8


--------------------------------------------------
Correlation: 0.24167844653129578
Model A Feature:  23223


Model B Feature:  16393


--------------------------------------------------
Correlation: 0.7174993753433228
Model A Feature:  27324


Model B Feature:  13


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.47279903292655945
Model A Feature:  32471


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.39404481649398804
Model A Feature:  27324


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.24167844653129578
Model A Feature:  23223


Model B Feature:  16393


--------------------------------------------------
Correlation: 0.4513143301010132
Model A Feature:  2482


Model B Feature:  17


--------------------------------------------------
Correlation: 0.26654374599456787
Model A Feature:  460


Model B Feature:  8211


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.04270151959007941

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

## people

In [None]:
new_keywords = [
    "man", "girl", "boy", "kid", "dad", "mom", "son", "sis", "bro",
    "pal", "mate", "boss", "chief", "cop", "guide", "priest", "king",
    "queen", "duke", "lord", "friend", "judge", "clerk", "coach", "team",
    "crew", "staff", "nurse", "doc", "vet", "cook", "maid", "clown",
    "star", "clan", "host", "guest", "peer", "guard", "boss", "spy",
    "fool", "punk", "nerd", "jock", "chief", "folk", "crowd"
]

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method (no mask, confusing inds)

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.421259842519685
107


0.3155307187075455

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 646])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 254])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


107

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.4779091023620769

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.4779091023620769

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.4478982985019684
Model A Feature:  21818


Model B Feature:  10752


--------------------------------------------------
Correlation: 0.1245780810713768
Model A Feature:  3622


Model B Feature:  513


--------------------------------------------------
Correlation: 0.50855553150177
Model A Feature:  21818


Model B Feature:  17922


--------------------------------------------------
Correlation: 0.09056451916694641
Model A Feature:  2204


Model B Feature:  29699


--------------------------------------------------
Correlation: 0.05578344315290451
Model A Feature:  4059


Model B Feature:  26628


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.4478982985019684
Model A Feature:  21818


Model B Feature:  10752


--------------------------------------------------
Correlation: 0.1245780810713768
Model A Feature:  3622


Model B Feature:  513


--------------------------------------------------
Correlation: 0.09056451916694641
Model A Feature:  2204


Model B Feature:  29699


--------------------------------------------------
Correlation: 0.05578344315290451
Model A Feature:  4059


Model B Feature:  26628


--------------------------------------------------
Correlation: 0.09751055389642715
Model A Feature:  289


Model B Feature:  5125


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.4624965353871978

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.4478982985019684
Model A Feature:  21818


Model B Feature:  10752


--------------------------------------------------
Correlation: 0.6679359078407288
Model A Feature:  16127


Model B Feature:  5130


--------------------------------------------------
Correlation: 0.7191584706306458
Model A Feature:  2154


Model B Feature:  14863


--------------------------------------------------
Correlation: 0.2364862710237503
Model A Feature:  22245


Model B Feature:  23572


--------------------------------------------------
Correlation: 0.4013051986694336
Model A Feature:  3297


Model B Feature:  29206


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.0789440502612475

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

## science

In [None]:
new_keywords = [
    "cell", "gene", "nerve", "brain", "blood", "bone", "heart", "lung",
    "star", "space", "light", "mass", "force", "wave", "speed", "sound",
    "time", "power", "heat", "cold", "charge", "spark", "flame", "bond",
    "quark", "atom", "ion", "gas", "wind", "ice", "plant", "rock",
    "probe", "test", "fact", "proof", "code", "law", "rule", "graph",
    "scale", "scope", "lens", "ray", "line", "chart", "flux", "phase",
    "shock", "pulse"]

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.16195372750642673
189


0.485931963315864

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 1176])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 1167])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


189

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.44688192459267656

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.44688192459267656

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.5913411974906921
Model A Feature:  1996


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.4207878112792969
Model A Feature:  15076


Model B Feature:  14339


--------------------------------------------------
Correlation: 0.37127792835235596
Model A Feature:  22963


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.5509901642799377
Model A Feature:  24165


Model B Feature:  24582


--------------------------------------------------
Correlation: 0.17862802743911743
Model A Feature:  10265


Model B Feature:  8199


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.5913411974906921
Model A Feature:  1996


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.4207878112792969
Model A Feature:  15076


Model B Feature:  14339


--------------------------------------------------
Correlation: 0.37127792835235596
Model A Feature:  22963


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.5509901642799377
Model A Feature:  24165


Model B Feature:  24582


--------------------------------------------------
Correlation: 0.17862802743911743
Model A Feature:  10265


Model B Feature:  8199


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.6007909362932458

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.5913411974906921
Model A Feature:  1996


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.4207878112792969
Model A Feature:  15076


Model B Feature:  14339


--------------------------------------------------
Correlation: 0.37127792835235596
Model A Feature:  22963


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.5509901642799377
Model A Feature:  24165


Model B Feature:  24582


--------------------------------------------------
Correlation: 0.5713392496109009
Model A Feature:  25316


Model B Feature:  12296


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.06371084869717532

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

## nature

In [None]:
new_keywords = [
    "tree", "grass", "bush", "plant", "stone", "rock", "cliff", "hill",
    "dirt", "sand", "mud", "wind", "storm", "rain", "cloud", "sun",
    "moon", "star", "leaf", "branch", "twig", "root", "bark", "seed",
    "wave", "tide", "lake", "pond", "creek", "sea", "wood", "field",
    "shore", "snow", "ice", "flame", "fire", "fog", "dew", "hail",
    "sky", "earth", "glade", "cave", "peak", "ridge", "dust", "air",
    "mist", "heat"]

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.3744493392070485
85


0.3138826178296547

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 455])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 227])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


85

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.5359066949196167

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5359066949196167

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.11672232300043106
Model A Feature:  8042


Model B Feature:  27648


--------------------------------------------------
Correlation: 0.487680047750473
Model A Feature:  4001


Model B Feature:  10756


--------------------------------------------------
Correlation: 0.1838884800672531
Model A Feature:  15101


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.22668106853961945
Model A Feature:  15101


Model B Feature:  14343


--------------------------------------------------
Correlation: 0.09755470603704453
Model A Feature:  29222


Model B Feature:  1042


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.11672232300043106
Model A Feature:  8042


Model B Feature:  27648


--------------------------------------------------
Correlation: 0.487680047750473
Model A Feature:  4001


Model B Feature:  10756


--------------------------------------------------
Correlation: 0.1838884800672531
Model A Feature:  15101


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.09755470603704453
Model A Feature:  29222


Model B Feature:  1042


--------------------------------------------------
Correlation: 0.12872393429279327
Model A Feature:  25912


Model B Feature:  16407


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5873188540022541

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.487680047750473
Model A Feature:  4001


Model B Feature:  10756


--------------------------------------------------
Correlation: 0.7750003337860107
Model A Feature:  1906


Model B Feature:  19510


--------------------------------------------------
Correlation: 0.8537799715995789
Model A Feature:  8907


Model B Feature:  13880


--------------------------------------------------
Correlation: 0.594294548034668
Model A Feature:  2946


Model B Feature:  15936


--------------------------------------------------
Correlation: 0.809211790561676
Model A Feature:  17935


Model B Feature:  587


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.08967562798095273

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

## science (rmv words part of compound & common non-category specific words)

In [None]:
new_keywords = [
    "cell", "gene", "nerve", "brain", "blood", "bone", "heart", "lung",
    "star", "space", "light", "force", "wave", "speed",
    "power", "heat", "cold", "charge", "spark", "flame",
    "quark", "atom", "gas", "ice", "plant", "rock",
    "probe", "proof", "code",
    "scale", "scope", "ray", "line", "chart", "flux", "phase",
    "shock", "pulse"]

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.29918032786885246
73


0.37046892284492

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 588])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 244])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


73

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.33796925259405053

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.33796925259405053

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.8664892315864563
Model A Feature:  21238


Model B Feature:  31744


--------------------------------------------------
Correlation: 0.5292837619781494
Model A Feature:  13771


Model B Feature:  17921


--------------------------------------------------
Correlation: 0.4208003878593445
Model A Feature:  15076


Model B Feature:  14339


--------------------------------------------------
Correlation: 0.9693226218223572
Model A Feature:  23701


Model B Feature:  18958


--------------------------------------------------
Correlation: 0.37977468967437744
Model A Feature:  13771


Model B Feature:  15886


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.8664892315864563
Model A Feature:  21238


Model B Feature:  31744


--------------------------------------------------
Correlation: 0.5292837619781494
Model A Feature:  13771


Model B Feature:  17921


--------------------------------------------------
Correlation: 0.4208003878593445
Model A Feature:  15076


Model B Feature:  14339


--------------------------------------------------
Correlation: 0.9693226218223572
Model A Feature:  23701


Model B Feature:  18958


--------------------------------------------------
Correlation: 0.24422234296798706
Model A Feature:  13259


Model B Feature:  5652


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
len(new_highest_correlations_indices_A)

40

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5779577812059175

In [None]:
samp_m = 5
num_samps = 20
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:num_samps],
                                                      new_highest_correlations_indices_B[:num_samps]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.8664892315864563
Model A Feature:  21238


Model B Feature:  31744


--------------------------------------------------
Correlation: 0.5292837619781494
Model A Feature:  13771


Model B Feature:  17921


--------------------------------------------------
Correlation: 0.4208003878593445
Model A Feature:  15076


Model B Feature:  14339


--------------------------------------------------
Correlation: 0.9693226218223572
Model A Feature:  23701


Model B Feature:  18958


--------------------------------------------------
Correlation: 0.24422234296798706
Model A Feature:  13259


Model B Feature:  5652


--------------------------------------------------
Correlation: 0.9168168902397156
Model A Feature:  11098


Model B Feature:  2081


--------------------------------------------------
Correlation: 0.8931164741516113
Model A Feature:  10217


Model B Feature:  5685


--------------------------------------------------
Correlation: 0.8537800312042236
Model A Feature:  8907


Model B Feature:  13880


--------------------------------------------------
Correlation: 0.5022938847541809
Model A Feature:  29895


Model B Feature:  75


--------------------------------------------------
Correlation: 0.26780441403388977
Model A Feature:  25476


Model B Feature:  6245


--------------------------------------------------
Correlation: 0.678051769733429
Model A Feature:  20120


Model B Feature:  30823


--------------------------------------------------
Correlation: 0.6083025336265564
Model A Feature:  18112


Model B Feature:  17527


--------------------------------------------------
Correlation: 0.7381497621536255
Model A Feature:  25912


Model B Feature:  1145


--------------------------------------------------
Correlation: 0.47376367449760437
Model A Feature:  2595


Model B Feature:  17017


--------------------------------------------------
Correlation: 0.7412620782852173
Model A Feature:  4050


Model B Feature:  26767


--------------------------------------------------
Correlation: 0.7796500325202942
Model A Feature:  5984


Model B Feature:  9873


--------------------------------------------------
Correlation: 0.8067306280136108
Model A Feature:  8471


Model B Feature:  17551


--------------------------------------------------
Correlation: 0.6245339512825012
Model A Feature:  21227


Model B Feature:  25237


--------------------------------------------------
Correlation: 0.3493461012840271
Model A Feature:  16526


Model B Feature:  151


--------------------------------------------------
Correlation: 0.5462920665740967
Model A Feature:  4977


Model B Feature:  7321


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.08336819175818187

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

## time (rmv words)

In [None]:
new_keywords = [
    "day", "night", "week", "month", "year", "hour", "minute", "second", "now", "soon",
    "later", "early", "late", "morning", "evening", "noon", "midnight", "dawn", "dusk", "past",
    "present", "future", "before", "after", "yesterday", "today", "tomorrow", "next", "previous", "soon",
    "fast", "slow", "quick", "moment", "instant", "era", "age", "decade", "century", "millennium",
    "moment", "pause", "wait", "begin", "start", "end", "finish", "stop", "continue",
    "until", "since", "then", "when", "whenever", "always", "never", "forever", "constant", "frequent",
    "occasion", "season", "spring", "summer", "autumn", "fall", "winter", "anniversary", "deadline", "schedule",
    "calendar", "clock", "date", "duration", "interval", "epoch", "generation", "period", "cycle", "timespan",
    "shift", "quarter", "term", "turn", "phase", "lifetime", "century", "minute", "timeline", "delay",
    "prompt", "timely", "recurrent", "daily", "weekly", "monthly", "yearly", "annual", "biweekly", "timeframe"
]
len(new_keywords)

99

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.2020460358056266
316


0.48309437328678034

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 1812])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 1564])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


316

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.5040594779090436

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5040594779090436

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.4727990925312042
Model A Feature:  32471


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.39404481649398804
Model A Feature:  27324


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.6927604079246521
Model A Feature:  27324


Model B Feature:  8


--------------------------------------------------
Correlation: 0.7174993753433228
Model A Feature:  27324


Model B Feature:  13


--------------------------------------------------
Correlation: 0.4377692937850952
Model A Feature:  3715


Model B Feature:  17


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.4727990925312042
Model A Feature:  32471


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.39404481649398804
Model A Feature:  27324


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.4377692937850952
Model A Feature:  3715


Model B Feature:  17


--------------------------------------------------
Correlation: 0.26654374599456787
Model A Feature:  460


Model B Feature:  8211


--------------------------------------------------
Correlation: 0.8091992735862732
Model A Feature:  15325


Model B Feature:  28


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.643185238988271

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.4727990925312042
Model A Feature:  32471


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.39404481649398804
Model A Feature:  27324


Model B Feature:  8197


--------------------------------------------------
Correlation: 0.4377692937850952
Model A Feature:  3715


Model B Feature:  17


--------------------------------------------------
Correlation: 0.26654374599456787
Model A Feature:  460


Model B Feature:  8211


--------------------------------------------------
Correlation: 0.8091992735862732
Model A Feature:  15325


Model B Feature:  28


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.0407672619568654

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

## colors

In [None]:
new_keywords = ["red", "blue", "yellow", "green", "brown", "purple", "white", "black", "orange"
]
len(new_keywords)

9

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.34285714285714286
36


0.3117880554426284

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 210])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 105])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


36

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.3735588939061499

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.3735588939061499

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.24818731844425201
Model A Feature:  30108


Model B Feature:  25608


--------------------------------------------------
Correlation: 0.3755578398704529
Model A Feature:  5407


Model B Feature:  6664


--------------------------------------------------
Correlation: 0.20828212797641754
Model A Feature:  15831


Model B Feature:  28697


--------------------------------------------------
Correlation: 0.5495645999908447
Model A Feature:  5407


Model B Feature:  31257


--------------------------------------------------
Correlation: 0.17067460715770721
Model A Feature:  20752


Model B Feature:  24612


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.24818731844425201
Model A Feature:  30108


Model B Feature:  25608


--------------------------------------------------
Correlation: 0.3755578398704529
Model A Feature:  5407


Model B Feature:  6664


--------------------------------------------------
Correlation: 0.20828212797641754
Model A Feature:  15831


Model B Feature:  28697


--------------------------------------------------
Correlation: 0.17067460715770721
Model A Feature:  20752


Model B Feature:  24612


--------------------------------------------------
Correlation: 0.7058087587356567
Model A Feature:  11284


Model B Feature:  27174


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7407997373330671

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.24818731844425201
Model A Feature:  30108


Model B Feature:  25608


--------------------------------------------------
Correlation: 0.3755578398704529
Model A Feature:  5407


Model B Feature:  6664


--------------------------------------------------
Correlation: 0.20828212797641754
Model A Feature:  15831


Model B Feature:  28697


--------------------------------------------------
Correlation: 0.7058087587356567
Model A Feature:  11284


Model B Feature:  27174


--------------------------------------------------
Correlation: 0.21376295387744904
Model A Feature:  23858


Model B Feature:  22577


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.1447330986994762

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

## famous names

In [None]:
new_keywords = [
    "Einstein", "Newton", "Darwin", "Curie", "Galileo", "Tesla", "Hawking", "Feynman", "Pasteur", "Mendel",
    "Shakespeare", "Dickens", "Hemingway", "Austen", "Orwell", "Tolkien", "Faulkner", "Poe", "Joyce", "Steinbeck",
    "Lincoln", "Washington", "Roosevelt", "Churchill", "Gandhi", "Mandela", "Luther", "Kennedy", "Napoleon", "Cromwell",
    "Picasso", "Rembrandt", "Michelangelo", "Da Vinci", "Van Gogh", "Monet", "Dali", "Matisse", "Warhol", "Pollock",
    "Beethoven", "Mozart", "Bach", "Chopin", "Wagner", "Tchaikovsky", "Stravinsky", "Vivaldi", "Verdi", "Debussy",
    "Jobs", "Gates", "Musk", "Zuckerberg", "Buffett", "Bezos", "Branson", "Disney", "Ford", "Rockefeller",
    "Obama", "Clinton", "Reagan", "Thatcher", "Putin", "Merkel", "Castro", "Hitler", "Mao", "Stalin",
    "Freud", "Jung", "Pavlov", "Skinner", "Piaget", "Maslow", "Erikson", "Rogers", "Chomsky", "Bandura",
    "Ali", "Jordan", "Pele", "Ronaldo", "Messi", "Federer", "Williams", "Bolt", "Phelps", "Brady",
    "Chaplin", "Hitchcock", "Kubrick", "Spielberg", "Scorsese", "Tarantino", "Lucas", "Cameron", "Coppola", "Eastwood"
]

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.3968253968253968
50


0.27954385553797084

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 278])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 126])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


50

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.43119215198803607

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.43119215198803607

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.16730208694934845
Model A Feature:  9641


Model B Feature:  19459


--------------------------------------------------
Correlation: 0.14474479854106903
Model A Feature:  6310


Model B Feature:  26634


--------------------------------------------------
Correlation: 0.33182284235954285
Model A Feature:  9424


Model B Feature:  14865


--------------------------------------------------
Correlation: 0.1616276502609253
Model A Feature:  29656


Model B Feature:  21524


--------------------------------------------------
Correlation: 0.1344037652015686
Model A Feature:  2055


Model B Feature:  17444


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.16730208694934845
Model A Feature:  9641


Model B Feature:  19459


--------------------------------------------------
Correlation: 0.14474479854106903
Model A Feature:  6310


Model B Feature:  26634


--------------------------------------------------
Correlation: 0.33182284235954285
Model A Feature:  9424


Model B Feature:  14865


--------------------------------------------------
Correlation: 0.1616276502609253
Model A Feature:  29656


Model B Feature:  21524


--------------------------------------------------
Correlation: 0.1344037652015686
Model A Feature:  2055


Model B Feature:  17444


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.6584466950133537

In [None]:
paired_rsa = representational_similarity_analysis(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_rsa

0.2665452833150212

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.33182284235954285
Model A Feature:  9424


Model B Feature:  14865


--------------------------------------------------
Correlation: 0.4915129244327545
Model A Feature:  30826


Model B Feature:  15416


--------------------------------------------------
Correlation: 0.38749727606773376
Model A Feature:  10292


Model B Feature:  20046


--------------------------------------------------
Correlation: 0.29953935742378235
Model A Feature:  18426


Model B Feature:  13908


--------------------------------------------------
Correlation: 0.5874597430229187
Model A Feature:  11815


Model B Feature:  29274


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.12955559591322904

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

In [None]:
all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          representational_similarity_analysis, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

-0.001906257846888794

In [None]:
np.mean(np.array(all_rand_scores) >= paired_rsa)

0.0

## random keywords

In [None]:
new_keywords = [
    "apple", "bicycle", "cloud", "dog", "elephant", "fountain", "guitar", "honey",
    "iceberg", "jelly", "kangaroo", "laptop", "mountain", "notebook", "ocean",
    "piano", "quartz", "river", "satellite", "tiger", "umbrella", "volcano",
    "whale", "xylophone", "yogurt", "zebra", "balloon", "candle", "desert",
    "engine", "forest", "glove", "hat", "insect", "jungle", "key", "lamp",
    "microscope", "nest", "octopus", "penguin", "quill", "robot", "sandwich",
    "tree", "unicorn", "vase", "window", "yarn", "zipper"
]


In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword(fList_model_B, kw)
    modA_feats = find_indices_with_keyword(fList_model_A, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.26136363636363635
23


0.3063073566352779

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 100])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 88])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


23

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.2524924210624128

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.2524924210624128

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.1712571382522583
Model A Feature:  24540


Model B Feature:  18950


--------------------------------------------------
Correlation: 0.14045624434947968
Model A Feature:  1234


Model B Feature:  4631


--------------------------------------------------
Correlation: 0.1605885773897171
Model A Feature:  1234


Model B Feature:  14876


--------------------------------------------------
Correlation: 0.1736195832490921
Model A Feature:  31050


Model B Feature:  18974


--------------------------------------------------
Correlation: 0.16785818338394165
Model A Feature:  29817


Model B Feature:  9251


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.1712571382522583
Model A Feature:  24540


Model B Feature:  18950


--------------------------------------------------
Correlation: 0.14045624434947968
Model A Feature:  1234


Model B Feature:  4631


--------------------------------------------------
Correlation: 0.1736195832490921
Model A Feature:  31050


Model B Feature:  18974


--------------------------------------------------
Correlation: 0.16785818338394165
Model A Feature:  29817


Model B Feature:  9251


--------------------------------------------------
Correlation: 0.18755601346492767
Model A Feature:  19504


Model B Feature:  18469


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.24090043416491444

In [None]:
paired_rsa = representational_similarity_analysis(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_rsa

0.19487734487734493

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.30862289667129517
Model A Feature:  18351


Model B Feature:  7748


--------------------------------------------------
Correlation: 0.2702399492263794
Model A Feature:  6370


Model B Feature:  10330


--------------------------------------------------
Correlation: 0.2142949104309082
Model A Feature:  10760


Model B Feature:  10364


--------------------------------------------------
Correlation: 0.20555566251277924
Model A Feature:  12359


Model B Feature:  23688


--------------------------------------------------
Correlation: 0.2462807446718216
Model A Feature:  26638


Model B Feature:  11942


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.1742823969784607

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.29

In [None]:
all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          representational_similarity_analysis, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

-0.0061064896532202284

In [None]:
np.mean(np.array(all_rand_scores) >= paired_rsa)

0.0

# check if kw is compound

## get seq list

In [None]:
# store feature : lst of top strs
fList_model_B_seqs = []
samp_m = 5

for feature_idx in range(feature_acts_B.shape[-1]):
# for feature_idx in range(5):
    if feature_idx % 5000 == 0:
        print('Feature: ', feature_idx)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx, samp_m, batch_tokens= inputs['input_ids'])
    fList_model_B_seqs.append(store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids']) )

Feature:  0
Feature:  5000
Feature:  10000
Feature:  15000
Feature:  20000
Feature:  25000
Feature:  30000


In [None]:
# store feature : lst of top strs
fList_model_A_seqs = []
samp_m = 5

for feature_idx in range(feature_acts_A.shape[-1]):
    if feature_idx % 5000 == 0:
        print('Feature: ', feature_idx)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx, samp_m, batch_tokens= inputs['input_ids'])
    fList_model_A_seqs.append(store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids']) )

Feature:  0
Feature:  5000
Feature:  10000
Feature:  15000
Feature:  20000
Feature:  25000
Feature:  30000


In [None]:
# import pdb

# def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
#     feat_list = []
#     for feat_ind, top_seqs_andToks_lst in enumerate(fList_seqs):
#         for top_seqs_andToks in top_seqs_andToks_lst:
#             seq = top_seqs_andToks[0]
#             topTok = top_seqs_andToks[1].replace(' ', '').lower()
#             if keyword.lower() != topTok:
#                 continue
#             split_list = seq.split(' ')
#             flag = False
#             for word in split_list:
#                 word = word.replace('.', '').replace(',', '').replace('?', '').replace('!', '').replace('\\n','')

#                 # pdb.set_trace()
#                 if keyword.lower() == word:
#                     feat_list.append(feat_ind)
#                     flag = True
#                     break
#             if flag:
#                 break
#     return feat_list

In [None]:
keyword = 'man'
modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, keyword)
# index_list = []
# for index, top_seqs_lst in enumerate(fList_model_B_seqs):
#     for seq in top_seqs_lst:
#         split_list = seq.split(' ')
#         for tok in split_list:
#             if keyword.lower() == tok:
#                 index_list.append(index)

In [None]:
modB_feats

[3998, 5125, 8855, 19286, 19325, 26791, 27254]

In [None]:
samp_m = 5
feature_idx_B = modB_feats[0]
print('Model B Feature: ', feature_idx_B)
ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

Model B Feature:  3998


## people

In [None]:
new_keywords = [
    "man", "girl", "boy", "kid", "dad", "mom", "son", "sis", "bro",
    "pal", "mate", "boss", "chief", "cop", "guide", "priest", "king",
    "queen", "duke", "lord", "friend", "judge", "clerk", "coach", "team",
    "crew", "staff", "nurse", "doc", "vet", "cook", "maid", "clown",
    "star", "clan", "host", "guest", "peer", "guard", "boss", "spy",
    "fool", "punk", "nerd", "jock", "chief", "folk", "crowd"
]

In [None]:
# modB_feats = find_indices_with_keyword_bySeqs(fList_model_B, kw)

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method (no mask, confusing inds)

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.3445692883895131
92


0.3585776478219568

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 352])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 267])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


92

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.5756977795306804

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5756977795306804

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.3534761667251587
Model A Feature:  3297


Model B Feature:  11778


--------------------------------------------------
Correlation: 0.09550344198942184
Model A Feature:  16194


Model B Feature:  5125


--------------------------------------------------
Correlation: 0.08233723044395447
Model A Feature:  15230


Model B Feature:  4615


--------------------------------------------------
Correlation: 0.153561532497406
Model A Feature:  11458


Model B Feature:  11272


--------------------------------------------------
Correlation: 0.16526053845882416
Model A Feature:  3297


Model B Feature:  17930


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.3534761667251587
Model A Feature:  3297


Model B Feature:  11778


--------------------------------------------------
Correlation: 0.09550344198942184
Model A Feature:  16194


Model B Feature:  5125


--------------------------------------------------
Correlation: 0.08233723044395447
Model A Feature:  15230


Model B Feature:  4615


--------------------------------------------------
Correlation: 0.153561532497406
Model A Feature:  11458


Model B Feature:  11272


--------------------------------------------------
Correlation: 0.6679359078407288
Model A Feature:  16127


Model B Feature:  5130


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.6318843335104746

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.3534761667251587
Model A Feature:  3297


Model B Feature:  11778


--------------------------------------------------
Correlation: 0.6679359078407288
Model A Feature:  16127


Model B Feature:  5130


--------------------------------------------------
Correlation: 0.7191584706306458
Model A Feature:  2154


Model B Feature:  14863


--------------------------------------------------
Correlation: 0.49463969469070435
Model A Feature:  13268


Model B Feature:  8722


--------------------------------------------------
Correlation: 0.23486579954624176
Model A Feature:  4459


Model B Feature:  1561


--------------------------------------------------


### debug

In [None]:
relevant_keywords = []
for keyword in new_keywords:
    top_seqs_lst = fList_model_A_seqs[4459]
    for seq in top_seqs_lst:
        split_list = seq.split(' ')
        flag = False
        for tok in split_list:
            if keyword.lower() == tok:
                relevant_keywords.append(seq)
                flag = True
                break

In [None]:
relevant_keywords

['man 1946']

1946 Eddie Bockman 1946 Joe Medwick**"

### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.07398644511314345

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

## people (fix false pos compound words)

In [None]:
new_keywords = [
    "man", "girl", "boy", "kid", "dad", "mom", "son", "sis", "bro",
    "pal", "mate", "boss", "chief", "cop", "guide", "priest", "king",
    "queen", "duke", "lord", "friend", "judge", "clerk", "coach", "team",
    "crew", "staff", "nurse", "doc", "vet", "cook", "maid", "clown",
    "star", "clan", "host", "guest", "peer", "guard", "boss", "spy",
    "fool", "punk", "nerd", "jock", "chief", "folk", "crowd"
]

In [None]:
# modB_feats = find_indices_with_keyword_bySeqs(fList_model_B, kw)

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### old method (no mask, confusing inds)

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.43037974683544306
34


0.3704556228145014

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 128])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 79])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


34

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.35565525530787634

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.35565525530787634

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.09550388157367706
Model A Feature:  16194


Model B Feature:  5125


--------------------------------------------------
Correlation: 0.08233728259801865
Model A Feature:  15230


Model B Feature:  4615


--------------------------------------------------
Correlation: 0.6679355502128601
Model A Feature:  16127


Model B Feature:  5130


--------------------------------------------------
Correlation: 0.14263388514518738
Model A Feature:  12551


Model B Feature:  13834


--------------------------------------------------
Correlation: 0.7191584706306458
Model A Feature:  2154


Model B Feature:  14863


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.09550388157367706
Model A Feature:  16194


Model B Feature:  5125


--------------------------------------------------
Correlation: 0.08233728259801865
Model A Feature:  15230


Model B Feature:  4615


--------------------------------------------------
Correlation: 0.6679355502128601
Model A Feature:  16127


Model B Feature:  5130


--------------------------------------------------
Correlation: 0.14263388514518738
Model A Feature:  12551


Model B Feature:  13834


--------------------------------------------------
Correlation: 0.7191584706306458
Model A Feature:  2154


Model B Feature:  14863


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.3209988080154914

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.6679355502128601
Model A Feature:  16127


Model B Feature:  5130


--------------------------------------------------
Correlation: 0.7191584706306458
Model A Feature:  2154


Model B Feature:  14863


--------------------------------------------------
Correlation: 0.6518351435661316
Model A Feature:  9620


Model B Feature:  8734


--------------------------------------------------
Correlation: 0.677170991897583
Model A Feature:  1555


Model B Feature:  1054


--------------------------------------------------
Correlation: 0.3713623285293579
Model A Feature:  3297


Model B Feature:  14883


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.1310768664756002

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.07

## animals

In [None]:
new_keywords = [
    "dog", "cat", "rat", "bat", "pig", "cow", "fox", "wolf", "ram", "eel",
    "ant", "bee", "bug", "cub", "kit", "fawn", "calf", "colt", "foal",
    "hen", "duck", "goat", "bird", "crow", "fish", "frog", "deer", "worm",
    "moth", "gnat", "clam", "crab", "shrimp", "whale", "shark", "squid",
    "pup", "joey", "owl", "hare", "seal", "mule", "toad", "swan", "sow",
    "bull", "stag", "buck", "boar", "kite"
]

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### run

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.375
6


0.30792295024730265

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 19])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 16])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


6

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.3179610447761239

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.3179610447761239

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.7892470359802246
Model A Feature:  23765


Model B Feature:  24961


--------------------------------------------------
Correlation: 0.1722717583179474
Model A Feature:  26638


Model B Feature:  2462


--------------------------------------------------
Correlation: 0.08801285922527313
Model A Feature:  23765


Model B Feature:  19331


--------------------------------------------------
Correlation: 0.24628061056137085
Model A Feature:  26638


Model B Feature:  11942


--------------------------------------------------
Correlation: 0.05735268071293831
Model A Feature:  26638


Model B Feature:  29288


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.7892470359802246
Model A Feature:  23765


Model B Feature:  24961


--------------------------------------------------
Correlation: 0.1722717583179474
Model A Feature:  26638


Model B Feature:  2462


--------------------------------------------------
Correlation: 0.054895468056201935
Model A Feature:  2786


Model B Feature:  4360


--------------------------------------------------
Correlation: 0.40762510895729065
Model A Feature:  27893


Model B Feature:  3947


--------------------------------------------------
Correlation: 0.11722975224256516
Model A Feature:  16811


Model B Feature:  11611


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
len(new_highest_correlations_indices_B)

3

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.9917093445680989

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.7892470359802246
Model A Feature:  23765


Model B Feature:  24961


--------------------------------------------------
Correlation: 0.40762510895729065
Model A Feature:  27893


Model B Feature:  3947


--------------------------------------------------
Correlation: 0.8933273553848267
Model A Feature:  32326


Model B Feature:  18167


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.37605906152196616

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.0

In [None]:
paired_svcca

0.9917093445680989

## colors

In [None]:
new_keywords = ["red", "blue", "yellow", "green", "brown", "purple", "white", "black", "orange"
]

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

### run

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.32
8


0.2371883825957775

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 28])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 25])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


8

In [None]:
X_subset = weight_matrix_np[mixed_modA_feats]
Y_subset = weight_matrix_2[mixed_modB_feats]

paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
paired_svcca

0.6574592919848281

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.6574592919848281

### interpret

In [None]:
samp_m = 5
for subset_feature_idx_B, subset_feature_idx_A in enumerate(subset_inds[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.15349355340003967
Model A Feature:  30108


Model B Feature:  31617


--------------------------------------------------
Correlation: 0.2481871247291565
Model A Feature:  30108


Model B Feature:  25608


--------------------------------------------------
Correlation: 0.0952630415558815
Model A Feature:  9096


Model B Feature:  28942


--------------------------------------------------
Correlation: 0.26997533440589905
Model A Feature:  30108


Model B Feature:  16535


--------------------------------------------------
Correlation: 0.07300806790590286
Model A Feature:  30108


Model B Feature:  31257


--------------------------------------------------


In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:samp_m], filt_corr_ind_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.15349355340003967
Model A Feature:  30108


Model B Feature:  31617


--------------------------------------------------
Correlation: 0.0952630415558815
Model A Feature:  9096


Model B Feature:  28942


--------------------------------------------------
Correlation: 0.7293462753295898
Model A Feature:  13353


Model B Feature:  11434


--------------------------------------------------
Correlation: 0.21376293897628784
Model A Feature:  23858


Model B Feature:  22577


--------------------------------------------------
Correlation: 0.7575472593307495
Model A Feature:  27337


Model B Feature:  27574


--------------------------------------------------


### filter out low corr

In [None]:
new_highest_correlations_indices_A = []
new_highest_correlations_indices_B = []
new_highest_correlations_values = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        new_highest_correlations_indices_A.append(ind_A)
        new_highest_correlations_indices_B.append(ind_B)
        new_highest_correlations_values.append(val)

In [None]:
len(new_highest_correlations_indices_B)

4

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in new_highest_correlations_indices_A]
original_B_indices = [mixed_modB_feats[i] for i in new_highest_correlations_indices_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7040771936056542

In [None]:
samp_m = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(new_highest_correlations_indices_A[:samp_m],
                                                      new_highest_correlations_indices_B[:samp_m]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.7293462753295898
Model A Feature:  13353


Model B Feature:  11434


--------------------------------------------------
Correlation: 0.21376293897628784
Model A Feature:  23858


Model B Feature:  22577


--------------------------------------------------
Correlation: 0.7575472593307495
Model A Feature:  27337


Model B Feature:  27574


--------------------------------------------------
Correlation: 0.2196464091539383
Model A Feature:  3457


Model B Feature:  31046


--------------------------------------------------


### compare to rand

In [None]:
X_subset = weight_matrix_np[filt_corr_ind_A]
Y_subset = weight_matrix_2[filt_corr_ind_B]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
sum(all_rand_scores) / len(all_rand_scores)

0.3000725956813447

In [None]:
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.05

In [None]:
paired_svcca

0.7040771936056542