**bold text**

# setup

In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
%%capture
!pip install git+https://github.com/EleutherAI/sae.git

In [3]:
# you should load this before cloning repo files
# from .config import SaeConfig
# from .utils import decoder_impl

from sae.config import SaeConfig
from sae.utils import decoder_impl
from sae import Sae

Triton not installed, using eager implementation of SAE decoder.


In [4]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import json
from fnmatch import fnmatch
from pathlib import Path
from typing import NamedTuple, Optional, Callable, Union, List, Tuple
# from jaxtyping import Float, Int

import einops
import torch
from torch import Tensor, nn
from huggingface_hub import snapshot_download
from natsort import natsorted
from safetensors.torch import load_model, save_model

device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
from collections import Counter
import pandas as pd
from IPython.display import display

## corr fns

In [6]:
def batched_correlation(reshaped_activations_A, reshaped_activations_B, batch_size=100):
    # Ensure tensors are on GPU
    if torch.cuda.is_available():
        reshaped_activations_A = reshaped_activations_A.to('cuda')
        reshaped_activations_B = reshaped_activations_B.to('cuda')

    # Normalize columns of A
    mean_A = reshaped_activations_A.mean(dim=0, keepdim=True)
    std_A = reshaped_activations_A.std(dim=0, keepdim=True)
    normalized_A = (reshaped_activations_A - mean_A) / (std_A + 1e-8)  # Avoid division by zero

    # Normalize columns of B
    mean_B = reshaped_activations_B.mean(dim=0, keepdim=True)
    std_B = reshaped_activations_B.std(dim=0, keepdim=True)
    normalized_B = (reshaped_activations_B - mean_B) / (std_B + 1e-8)  # Avoid division by zero

    num_batches = (normalized_B.shape[1] + batch_size - 1) // batch_size
    max_values = []
    max_indices = []

    for batch in range(num_batches):
        start = batch * batch_size
        end = min(start + batch_size, normalized_B.shape[1])
        batch_corr_matrix = torch.matmul(normalized_A.t(), normalized_B[:, start:end]) / normalized_A.shape[0]
        max_val, max_idx = batch_corr_matrix.max(dim=0)
        max_values.append(max_val)
        max_indices.append(max_idx)

        del batch_corr_matrix
        torch.cuda.empty_cache()

    corr_inds = torch.cat(max_indices).detach().cpu().numpy()
    corr_vals = torch.cat(max_values).detach().cpu().numpy()
    return corr_inds, corr_vals

In [7]:
def filter_corr_pairs(mixed_modA_feats, mixed_modB_feats, kept_modA_feats):
    filt_corr_ind_A = []
    filt_corr_ind_B = []
    seen = set()
    for ind_A, ind_B in zip(mixed_modA_feats, mixed_modB_feats):
        if ind_A in kept_modA_feats:
            filt_corr_ind_A.append(ind_A)
            filt_corr_ind_B.append(ind_B)
        elif ind_A not in seen:  # only keep one if it's over count X
            seen.add(ind_A)
            filt_corr_ind_A.append(ind_A)
            filt_corr_ind_B.append(ind_B)
    num_unq_pairs = len(list(set(filt_corr_ind_A)))
    print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
    print("num 1-1 feats after filt: ", num_unq_pairs )
    return filt_corr_ind_A, filt_corr_ind_B, num_unq_pairs

In [8]:
def get_new_mean_corr(modA_feats, modB_feats, corr_vals):
    new_vals = []
    seen = set()
    for ind_A, ind_B in zip(modA_feats, modB_feats):
        if ind_A not in seen:
            seen.add(ind_A)
            val = corr_vals[ind_B]
            new_vals.append(val)
    new_mean_corr = sum(new_vals) / len(new_vals)
    # print(new_mean_corr)
    return new_mean_corr

## sim fns

In [9]:
import functools
from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
import numpy.typing as npt
import torch


def to_numpy_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[npt.NDArray]:
    def convert(x: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:
        return x if isinstance(x, np.ndarray) else x.numpy()

    return list(map(convert, args))


def to_torch_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[torch.Tensor]:
    def convert(x: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
        return x if isinstance(x, torch.Tensor) else torch.from_numpy(x)

    return list(map(convert, args))


def adjust_dimensionality(
    R: npt.NDArray, Rp: npt.NDArray, strategy="zero_pad"
) -> Tuple[npt.NDArray, npt.NDArray]:
    D = R.shape[1]
    Dp = Rp.shape[1]
    if strategy == "zero_pad":
        if D - Dp == 0:
            return R, Rp
        elif D - Dp > 0:
            return R, np.concatenate((Rp, np.zeros((Rp.shape[0], D - Dp))), axis=1)
        else:
            return np.concatenate((R, np.zeros((R.shape[0], Dp - D))), axis=1), Rp
    else:
        raise NotImplementedError()


def center_columns(R: npt.NDArray) -> npt.NDArray:
    return R - R.mean(axis=0)[None, :]


def normalize_matrix_norm(R: npt.NDArray) -> npt.NDArray:
    return R / np.linalg.norm(R, ord="fro")


def sim_random_baseline(
    rep1: torch.Tensor, rep2: torch.Tensor, sim_func: Callable, n_permutations: int = 10
) -> Dict[str, Any]:
    torch.manual_seed(1234)
    scores = []
    for _ in range(n_permutations):
        perm = torch.randperm(rep1.size(0))

        score = sim_func(rep1[perm, :], rep2)
        score = score if isinstance(score, float) else score["score"]

        scores.append(score)

    return {"baseline_scores": np.array(scores)}


class Pipeline:
    def __init__(
        self,
        preprocess_funcs: List[Callable[[npt.NDArray], npt.NDArray]],
        similarity_func: Callable[[npt.NDArray, npt.NDArray], Dict[str, Any]],
    ) -> None:
        self.preprocess_funcs = preprocess_funcs
        self.similarity_func = similarity_func

    def __call__(self, R: npt.NDArray, Rp: npt.NDArray) -> Dict[str, Any]:
        for preprocess_func in self.preprocess_funcs:
            R = preprocess_func(R)
            Rp = preprocess_func(Rp)
        return self.similarity_func(R, Rp)

    def __str__(self) -> str:
        def func_name(func: Callable) -> str:
            return (
                func.__name__
                if not isinstance(func, functools.partial)
                else func.func.__name__
            )

        def partial_keywords(func: Callable) -> str:
            if not isinstance(func, functools.partial):
                return ""
            else:
                return str(func.keywords)

        return (
            "Pipeline("
            + (
                "+".join(map(func_name, self.preprocess_funcs))
                + "+"
                + func_name(self.similarity_func)
                + partial_keywords(self.similarity_func)
            )
            + ")"
        )

In [10]:
from typing import List, Set, Union

import numpy as np
import numpy.typing as npt
import sklearn.neighbors
import torch

# from llmcomp.measures.utils import to_numpy_if_needed


def _jac_sim_i(idx_R: Set[int], idx_Rp: Set[int]) -> float:
    return len(idx_R.intersection(idx_Rp)) / len(idx_R.union(idx_Rp))


def jaccard_similarity(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    k: int = 10,
    inner: str = "cosine",
    n_jobs: int = 8,
) -> float:
    R, Rp = to_numpy_if_needed(R, Rp)

    indices_R = nn_array_to_setlist(top_k_neighbors(R, k, inner, n_jobs))
    indices_Rp = nn_array_to_setlist(top_k_neighbors(Rp, k, inner, n_jobs))

    return float(
        np.mean(
            [_jac_sim_i(idx_R, idx_Rp) for idx_R, idx_Rp in zip(indices_R, indices_Rp)]
        )
    )


def top_k_neighbors(
    R: npt.NDArray,
    k: int,
    inner: str,
    n_jobs: int,
) -> npt.NDArray:
    # k+1 nearest neighbors, because we pass in all the data, which means that a point
    # will be the nearest neighbor to itself. We remove this point from the results and
    # report only the k nearest neighbors distinct from the point itself.
    nns = sklearn.neighbors.NearestNeighbors(
        n_neighbors=k + 1, metric=inner, n_jobs=n_jobs
    )
    nns.fit(R)
    _, nns = nns.kneighbors(R)
    return nns[:, 1:]


def nn_array_to_setlist(nn: npt.NDArray) -> List[Set[int]]:
    return [set(idx) for idx in nn]

In [11]:
import functools
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import get_args
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import Union

import numpy as np
import numpy.typing as npt
import torch
from einops import rearrange
# from loguru import logger

log = logging.getLogger(__name__)


SHAPE_TYPE = Literal["nd", "ntd", "nchw"]

ND_SHAPE, NTD_SHAPE, NCHW_SHAPE = get_args(SHAPE_TYPE)[0], get_args(SHAPE_TYPE)[1], get_args(SHAPE_TYPE)[2]


class SimilarityFunction(Protocol):
    def __call__(  # noqa: E704
        self,
        R: torch.Tensor | npt.NDArray,
        Rp: torch.Tensor | npt.NDArray,
        shape: SHAPE_TYPE,
    ) -> float: ...


class RSMSimilarityFunction(Protocol):
    def __call__(  # noqa: E704
        self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE, n_jobs: int
    ) -> float: ...


@dataclass
class BaseSimilarityMeasure(ABC):
    larger_is_more_similar: bool
    is_symmetric: bool

    is_metric: bool | None = None
    invariant_to_affine: bool | None = None
    invariant_to_invertible_linear: bool | None = None
    invariant_to_ortho: bool | None = None
    invariant_to_permutation: bool | None = None
    invariant_to_isotropic_scaling: bool | None = None
    invariant_to_translation: bool | None = None
    name: str = field(init=False)

    def __post_init__(self):
        self.name = self.__class__.__name__

    @abstractmethod
    def __call__(self, *args: Any, **kwds: Any) -> Any:
        raise NotImplementedError


class FunctionalSimilarityMeasure(BaseSimilarityMeasure):
    @abstractmethod
    def __call__(self, output_a: torch.Tensor | npt.NDArray, output_b: torch.Tensor | npt.NDArray) -> float:
        raise NotImplementedError


@dataclass(kw_only=True)
class RepresentationalSimilarityMeasure(BaseSimilarityMeasure):
    sim_func: SimilarityFunction

    def __call__(
        self,
        R: torch.Tensor | npt.NDArray,
        Rp: torch.Tensor | npt.NDArray,
        shape: SHAPE_TYPE,
    ) -> float:
        return self.sim_func(R, Rp, shape)


class RSMSimilarityMeasure(RepresentationalSimilarityMeasure):
    sim_func: RSMSimilarityFunction

    @staticmethod
    def estimate_good_number_of_jobs(R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray) -> int:
        # RSMs in are NxN (or DxD) so the number of jobs should roughly scale quadratically with increase in N (or D).
        # False! As long as sklearn-native metrics are used, they will use parallel implementations regardless of job
        # count. Each job would spawn their own threads, which leads to oversubscription of cores and thus slowdown.
        # This seems to be not fully correct (n_jobs=2 seems to actually use two cores), but using n_jobs=1 seems the
        # fastest.
        return 1

    def __call__(
        self,
        R: torch.Tensor | npt.NDArray,
        Rp: torch.Tensor | npt.NDArray,
        shape: SHAPE_TYPE,
        n_jobs: Optional[int] = None,
    ) -> float:
        if n_jobs is None:
            n_jobs = self.estimate_good_number_of_jobs(R, Rp)
        return self.sim_func(R, Rp, shape, n_jobs=n_jobs)


def to_numpy_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[npt.NDArray]:
    def convert(x: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:
        return x if isinstance(x, np.ndarray) else x.numpy()

    return list(map(convert, args))


def to_torch_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[torch.Tensor]:
    def convert(x: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
        return x if isinstance(x, torch.Tensor) else torch.from_numpy(x)

    return list(map(convert, args))


def adjust_dimensionality(R: npt.NDArray, Rp: npt.NDArray, strategy="zero_pad") -> Tuple[npt.NDArray, npt.NDArray]:
    D = R.shape[1]
    Dp = Rp.shape[1]
    if strategy == "zero_pad":
        if D - Dp == 0:
            return R, Rp
        elif D - Dp > 0:
            return R, np.concatenate((Rp, np.zeros((Rp.shape[0], D - Dp))), axis=1)
        else:
            return np.concatenate((R, np.zeros((R.shape[0], Dp - D))), axis=1), Rp
    else:
        raise NotImplementedError()


def center_columns(R: npt.NDArray) -> npt.NDArray:
    return R - R.mean(axis=0)[None, :]


def normalize_matrix_norm(R: npt.NDArray) -> npt.NDArray:
    return R / np.linalg.norm(R, ord="fro")


def normalize_row_norm(R: npt.NDArray) -> npt.NDArray:
    return R / np.linalg.norm(R, ord=2, axis=1, keepdims=True)


def standardize(R: npt.NDArray) -> npt.NDArray:
    return (R - R.mean(axis=0, keepdims=True)) / R.std(axis=0)


def double_center(x: npt.NDArray) -> npt.NDArray:
    return x - x.mean(axis=0, keepdims=True) - x.mean(axis=1, keepdims=True) + x.mean()


def align_spatial_dimensions(R: npt.NDArray, Rp: npt.NDArray) -> Tuple[npt.NDArray, npt.NDArray]:
    """
    Aligns spatial representations by resizing them to the smallest spatial dimension.
    Subsequent aligned spatial representations are flattened, with the spatial aligned representations
    moving into the *sample* dimension.
    """
    R_re, Rp_re = resize_wh_reps(R, Rp)
    R_re = rearrange(R_re, "n c h w -> (n h w) c")
    Rp_re = rearrange(Rp_re, "n c h w -> (n h w) c")
    if R_re.shape[0] > 5000:
        logger.info(f"Got {R_re.shape[0]} samples in N after flattening. Subsampling to reduce compute.")
        subsample = R_re.shape[0] // 5000
        R_re = R_re[::subsample]
        Rp_re = Rp_re[::subsample]

    return R_re, Rp_re


def average_pool_downsample(R, resize: bool, new_size: tuple[int, int]):
    if not resize:
        return R  # do nothing
    else:
        is_numpy = isinstance(R, np.ndarray)
        R_torch = torch.from_numpy(R) if is_numpy else R
        R_torch = torch.nn.functional.adaptive_avg_pool2d(R_torch, new_size)
        return R_torch.numpy() if is_numpy else R_torch


def resize_wh_reps(R: npt.NDArray, Rp: npt.NDArray) -> Tuple[npt.NDArray, npt.NDArray]:
    """
    Function for resizing spatial representations that are not the same size.
    Does through fourier transform and resizing.

    Args:
        R: numpy array of shape  [batch_size, height, width, num_channels]
        RP: numpy array of shape [batch_size, height, width, num_channels]

    Returns:
        fft_acts1: numpy array of shape [batch_size, (new) height, (new) width, num_channels]
        fft_acts2: numpy array of shape [batch_size, (new) height, (new) width, num_channels]

    """
    height1, width1 = R.shape[2], R.shape[3]
    height2, width2 = Rp.shape[2], Rp.shape[3]
    if height1 != height2 or width1 != width2:
        height = min(height1, height2)
        width = min(width1, width2)
        new_size = [height, width]
        resize = True
    else:
        height = height1
        width = width1
        new_size = None
        resize = False

    # resize and preprocess with fft
    avg_ds1 = average_pool_downsample(R, resize=resize, new_size=new_size)
    avg_ds2 = average_pool_downsample(Rp, resize=resize, new_size=new_size)
    return avg_ds1, avg_ds2


def fft_resize(images, resize=False, new_size=None):
    """Function for applying DFT and resizing.

    This function takes in an array of images, applies the 2-d fourier transform
    and resizes them according to new_size, keeping the frequencies that overlap
    between the two sizes.

    Args:
              images: a numpy array with shape
                      [batch_size, height, width, num_channels]
              resize: boolean, whether or not to resize
              new_size: a tuple (size, size), with height and width the same

    Returns:
              im_fft_downsampled: a numpy array with shape
                           [batch_size, (new) height, (new) width, num_channels]
    """
    assert len(images.shape) == 4, "expecting images to be" "[batch_size, height, width, num_channels]"
    if resize:
        # FFT --> remove high frequencies --> inverse FFT
        im_complex = images.astype("complex64")
        im_fft = np.fft.fft2(im_complex, axes=(1, 2))
        im_shifted = np.fft.fftshift(im_fft, axes=(1, 2))

        center_width = im_shifted.shape[2] // 2
        center_height = im_shifted.shape[1] // 2
        half_w = new_size[0] // 2
        half_h = new_size[1] // 2
        cropped_fft = im_shifted[
            :, center_height - half_h : center_height + half_h, center_width - half_w : center_width + half_w, :
        ]
        cropped_fft_shifted_back = np.fft.ifft2(cropped_fft, axes=(1, 2))
        return cropped_fft_shifted_back.real
    else:
        return images


class Pipeline:
    def __init__(
        self,
        preprocess_funcs: List[Callable[[npt.NDArray], npt.NDArray]],
        similarity_func: Callable[[npt.NDArray, npt.NDArray, SHAPE_TYPE], float],
    ) -> None:
        self.preprocess_funcs = preprocess_funcs
        self.similarity_func = similarity_func

    def __call__(self, R: npt.NDArray, Rp: npt.NDArray, shape: SHAPE_TYPE) -> float:
        try:
            for preprocess_func in self.preprocess_funcs:
                R = preprocess_func(R)
                Rp = preprocess_func(Rp)
            return self.similarity_func(R, Rp, shape)
        except ValueError as e:
            log.info(f"Pipeline failed: {e}")
            return np.nan

    def __str__(self) -> str:
        def func_name(func: Callable) -> str:
            return func.__name__ if not isinstance(func, functools.partial) else func.func.__name__

        def partial_keywords(func: Callable) -> str:
            if not isinstance(func, functools.partial):
                return ""
            else:
                return str(func.keywords)

        return (
            "Pipeline("
            + (
                "+".join(map(func_name, self.preprocess_funcs))
                + "+"
                + func_name(self.similarity_func)
                + partial_keywords(self.similarity_func)
            )
            + ")"
        )


def flatten(*args: Union[torch.Tensor, npt.NDArray], shape: SHAPE_TYPE) -> List[Union[torch.Tensor, npt.NDArray]]:
    if shape == "ntd":
        return list(map(flatten_nxtxd_to_ntxd, args))
    elif shape == "nd":
        return list(args)
    elif shape == "nchw":
        return list(map(flatten_nxcxhxw_to_nxchw, args))  # Flattening non-trivial for nchw
    else:
        raise ValueError("Unknown shape of representations. Must be one of 'ntd', 'nchw', 'nd'.")


def flatten_nxtxd_to_ntxd(R: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
    R = to_torch_if_needed(R)[0]
    log.debug("Shape before flattening: %s", str(R.shape))
    R = torch.flatten(R, start_dim=0, end_dim=1)
    log.debug("Shape after flattening: %s", str(R.shape))
    return R


def flatten_nxcxhxw_to_nxchw(R: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
    R = to_torch_if_needed(R)[0]
    log.debug("Shape before flattening: %s", str(R.shape))
    R = torch.reshape(R, (R.shape[0], -1))
    log.debug("Shape after flattening: %s", str(R.shape))
    return R

In [12]:
import scipy.optimize

def permutation_procrustes(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    shape: SHAPE_TYPE,
    optimal_permutation_alignment: Optional[Tuple[npt.NDArray, npt.NDArray]] = None,
) -> float:
    R, Rp = flatten(R, Rp, shape=shape)
    R, Rp = to_numpy_if_needed(R, Rp)
    R, Rp = adjust_dimensionality(R, Rp)

    if not optimal_permutation_alignment:
        PR, PRp = scipy.optimize.linear_sum_assignment(R.T @ Rp, maximize=True)  # returns column assignments
        optimal_permutation_alignment = (PR, PRp)
    PR, PRp = optimal_permutation_alignment
    return float(np.linalg.norm(R[:, PR] - Rp[:, PRp], ord="fro"))

In [13]:
from typing import Optional
from typing import Union

import numpy as np
import numpy.typing as npt
import scipy.spatial.distance
import scipy.stats
import sklearn.metrics
import torch
# from repsim.measures.utils import flatten
# from repsim.measures.utils import RSMSimilarityMeasure
# from repsim.measures.utils import SHAPE_TYPE
# from repsim.measures.utils import to_numpy_if_needed


def representational_similarity_analysis(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    shape: SHAPE_TYPE,
    inner="correlation",
    outer="spearman",
    n_jobs: Optional[int] = None,
) -> float:
    """Representational similarity analysis

    Args:
        R (Union[torch.Tensor, npt.NDArray]): N x D representation
        Rp (Union[torch.Tensor, npt.NDArray]): N x D' representation
        inner (str, optional): inner similarity function for RSM. Must be one of
            scipy.spatial.distance.pdist identifiers . Defaults to "correlation".
        outer (str, optional): outer similarity function that compares RSMs. Defaults to
             "spearman". Must be one of "spearman", "euclidean"

    Returns:
        float: _description_
    """
    R, Rp = flatten(R, Rp, shape=shape)
    R, Rp = to_numpy_if_needed(R, Rp)

    if inner == "correlation":
        # n_jobs only works if metric is in PAIRWISE_DISTANCES as defined in sklearn, i.e., not for correlation.
        # But correlation = 1 - cosine dist of row-centered data, so we use the faster cosine metric and center the data.
        R = R - R.mean(axis=1, keepdims=True)
        S = scipy.spatial.distance.squareform(  # take the lower triangle of RSM
            1 - sklearn.metrics.pairwise_distances(R, metric="cosine", n_jobs=n_jobs),  # type:ignore
            checks=False,
        )
        Rp = Rp - Rp.mean(axis=1, keepdims=True)
        Sp = scipy.spatial.distance.squareform(
            1 - sklearn.metrics.pairwise_distances(Rp, metric="cosine", n_jobs=n_jobs),  # type:ignore
            checks=False,
        )
    elif inner == "euclidean":
        # take the lower triangle of RSM
        S = scipy.spatial.distance.squareform(
            sklearn.metrics.pairwise_distances(R, metric=inner, n_jobs=n_jobs), checks=False
        )
        Sp = scipy.spatial.distance.squareform(
            sklearn.metrics.pairwise_distances(Rp, metric=inner, n_jobs=n_jobs), checks=False
        )
    else:
        raise NotImplementedError(f"{inner=}")

    if outer == "spearman":
        return scipy.stats.spearmanr(S, Sp).statistic  # type:ignore
    elif outer == "euclidean":
        return float(np.linalg.norm(S - Sp, ord=2))
    else:
        raise ValueError(f"Unknown outer similarity function: {outer}")


class RSA(RSMSimilarityMeasure):
    def __init__(self):
        # choice of inner/outer in __call__ if fixed to default values, so these values are always the same
        super().__init__(
            sim_func=representational_similarity_analysis,
            larger_is_more_similar=True,
            is_metric=False,
            is_symmetric=True,
            invariant_to_affine=False,
            invariant_to_invertible_linear=False,
            invariant_to_ortho=False,
            invariant_to_permutation=True,
            invariant_to_isotropic_scaling=True,
            invariant_to_translation=True,
        )

In [14]:
##################################################################################
# Copied from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/cca_core.py
# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The core code for applying Canonical Correlation Analysis to deep networks.

This module contains the core functions to apply canonical correlation analysis
to deep neural networks. The main function is get_cca_similarity, which takes in
two sets of activations, typically the neurons in two layers and their outputs
on all of the datapoints D = [d_1,...,d_m] that have been passed through.

Inputs have shape (num_neurons1, m), (num_neurons2, m). This can be directly
applied used on fully connected networks. For convolutional layers, the 3d block
of neurons can either be flattened entirely, along channels, or alternatively,
the dft_ccas (Discrete Fourier Transform) module can be used.

See:
https://arxiv.org/abs/1706.05806
https://arxiv.org/abs/1806.05759
for full details.

"""
import numpy as np
# from repsim.measures.utils import align_spatial_dimensions

num_cca_trials = 5


def positivedef_matrix_sqrt(array):
    """Stable method for computing matrix square roots, supports complex matrices.

    Args:
              array: A numpy 2d array, can be complex valued that is a positive
                     definite symmetric (or hermitian) matrix

    Returns:
              sqrtarray: The matrix square root of array
    """
    w, v = np.linalg.eigh(array)
    #  A - np.dot(v, np.dot(np.diag(w), v.T))
    wsqrt = np.sqrt(w)
    sqrtarray = np.dot(v, np.dot(np.diag(wsqrt), np.conj(v).T))
    return sqrtarray


def remove_small(sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon):
    """Takes covariance between X, Y, and removes values of small magnitude.

    Args:
              sigma_xx: 2d numpy array, variance matrix for x
              sigma_xy: 2d numpy array, crossvariance matrix for x,y
              sigma_yx: 2d numpy array, crossvariance matrixy for x,y,
                        (conjugate) transpose of sigma_xy
              sigma_yy: 2d numpy array, variance matrix for y
              epsilon : cutoff value for norm below which directions are thrown
                         away

    Returns:
              sigma_xx_crop: 2d array with low x norm directions removed
              sigma_xy_crop: 2d array with low x and y norm directions removed
              sigma_yx_crop: 2d array with low x and y norm directiosn removed
              sigma_yy_crop: 2d array with low y norm directions removed
              x_idxs: indexes of sigma_xx that were removed
              y_idxs: indexes of sigma_yy that were removed
    """

    x_diag = np.abs(np.diagonal(sigma_xx))
    y_diag = np.abs(np.diagonal(sigma_yy))
    x_idxs = x_diag >= epsilon
    y_idxs = y_diag >= epsilon

    sigma_xx_crop = sigma_xx[x_idxs][:, x_idxs]
    sigma_xy_crop = sigma_xy[x_idxs][:, y_idxs]
    sigma_yx_crop = sigma_yx[y_idxs][:, x_idxs]
    sigma_yy_crop = sigma_yy[y_idxs][:, y_idxs]

    return (sigma_xx_crop, sigma_xy_crop, sigma_yx_crop, sigma_yy_crop, x_idxs, y_idxs)


def compute_ccas(sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon, verbose=True):
    """Main cca computation function, takes in variances and crossvariances.

    This function takes in the covariances and cross covariances of X, Y,
    preprocesses them (removing small magnitudes) and outputs the raw results of
    the cca computation, including cca directions in a rotated space, and the
    cca correlation coefficient values.

    Args:
              sigma_xx: 2d numpy array, (num_neurons_x, num_neurons_x)
                        variance matrix for x
              sigma_xy: 2d numpy array, (num_neurons_x, num_neurons_y)
                        crossvariance matrix for x,y
              sigma_yx: 2d numpy array, (num_neurons_y, num_neurons_x)
                        crossvariance matrix for x,y (conj) transpose of sigma_xy
              sigma_yy: 2d numpy array, (num_neurons_y, num_neurons_y)
                        variance matrix for y
              epsilon:  small float to help with stabilizing computations
              verbose:  boolean on whether to print intermediate outputs

    Returns:
              [ux, sx, vx]: [numpy 2d array, numpy 1d array, numpy 2d array]
                            ux and vx are (conj) transposes of each other, being
                            the canonical directions in the X subspace.
                            sx is the set of canonical correlation coefficients-
                            how well corresponding directions in vx, Vy correlate
                            with each other.
              [uy, sy, vy]: Same as above, but for Y space
              invsqrt_xx:   Inverse square root of sigma_xx to transform canonical
                            directions back to original space
              invsqrt_yy:   Same as above but for sigma_yy
              x_idxs:       The indexes of the input sigma_xx that were pruned
                            by remove_small
              y_idxs:       Same as above but for sigma_yy
    """

    (sigma_xx, sigma_xy, sigma_yx, sigma_yy, x_idxs, y_idxs) = remove_small(
        sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon
    )

    numx = sigma_xx.shape[0]
    numy = sigma_yy.shape[0]

    if numx == 0 or numy == 0:
        return (
            [0, 0, 0],
            [0, 0, 0],
            np.zeros_like(sigma_xx),
            np.zeros_like(sigma_yy),
            x_idxs,
            y_idxs,
        )

    if verbose:
        print("adding eps to diagonal and taking inverse")
    sigma_xx += epsilon * np.eye(numx)
    sigma_yy += epsilon * np.eye(numy)
    inv_xx = np.linalg.pinv(sigma_xx)
    inv_yy = np.linalg.pinv(sigma_yy)

    if verbose:
        print("taking square root")
    invsqrt_xx = positivedef_matrix_sqrt(inv_xx)
    invsqrt_yy = positivedef_matrix_sqrt(inv_yy)

    if verbose:
        print("dot products...")
    arr = np.dot(invsqrt_xx, np.dot(sigma_xy, invsqrt_yy))

    if verbose:
        print("trying to take final svd")
    u, s, v = np.linalg.svd(arr)

    if verbose:
        print("computed everything!")

    return [u, np.abs(s), v], invsqrt_xx, invsqrt_yy, x_idxs, y_idxs


def sum_threshold(array, threshold):
    """Computes threshold index of decreasing nonnegative array by summing.

    This function takes in a decreasing array nonnegative floats, and a
    threshold between 0 and 1. It returns the index i at which the sum of the
    array up to i is threshold*total mass of the array.

    Args:
              array: a 1d numpy array of decreasing, nonnegative floats
              threshold: a number between 0 and 1

    Returns:
              i: index at which np.sum(array[:i]) >= threshold
    """
    assert (threshold >= 0) and (threshold <= 1), "print incorrect threshold"

    for i in range(len(array)):
        if np.sum(array[:i]) / np.sum(array) >= threshold:
            return i


def create_zero_dict(compute_dirns, dimension):
    """Outputs a zero dict when neuron activation norms too small.

    This function creates a return_dict with appropriately shaped zero entries
    when all neuron activations are very small.

    Args:
              compute_dirns: boolean, whether to have zero vectors for directions
              dimension: int, defines shape of directions

    Returns:
              return_dict: a dict of appropriately shaped zero entries
    """
    return_dict = {}
    return_dict["mean"] = (np.asarray(0), np.asarray(0))
    return_dict["sum"] = (np.asarray(0), np.asarray(0))
    return_dict["cca_coef1"] = np.asarray(0)
    return_dict["cca_coef2"] = np.asarray(0)
    return_dict["idx1"] = 0
    return_dict["idx2"] = 0

    if compute_dirns:
        return_dict["cca_dirns1"] = np.zeros((1, dimension))
        return_dict["cca_dirns2"] = np.zeros((1, dimension))

    return return_dict


def get_cca_similarity(
    acts1,
    acts2,
    epsilon=0.0,
    threshold=0.98,
    compute_coefs=True,
    compute_dirns=False,
    verbose=True,
):
    """The main function for computing cca similarities.

    This function computes the cca similarity between two sets of activations,
    returning a dict with the cca coefficients, a few statistics of the cca
    coefficients, and (optionally) the actual directions.

    Args:
              acts1: (num_neurons1, data_points) a 2d numpy array of neurons by
                     datapoints where entry (i,j) is the output of neuron i on
                     datapoint j.
              acts2: (num_neurons2, data_points) same as above, but (potentially)
                     for a different set of neurons. Note that acts1 and acts2
                     can have different numbers of neurons, but must agree on the
                     number of datapoints

              epsilon: small float to help stabilize computations

              threshold: float between 0, 1 used to get rid of trailing zeros in
                         the cca correlation coefficients to output more accurate
                         summary statistics of correlations.


              compute_coefs: boolean value determining whether coefficients
                             over neurons are computed. Needed for computing
                             directions

              compute_dirns: boolean value determining whether actual cca
                             directions are computed. (For very large neurons and
                             datasets, may be better to compute these on the fly
                             instead of store in memory.)

              verbose: Boolean, whether intermediate outputs are printed

    Returns:
              return_dict: A dictionary with outputs from the cca computations.
                           Contains neuron coefficients (combinations of neurons
                           that correspond to cca directions), the cca correlation
                           coefficients (how well aligned directions correlate),
                           x and y idxs (for computing cca directions on the fly
                           if compute_dirns=False), and summary statistics. If
                           compute_dirns=True, the cca directions are also
                           computed.
    """

    # assert dimensionality equal
    assert acts1.shape[1] == acts2.shape[1], "dimensions don't match"
    # check that acts1, acts2 are transposition
    assert acts1.shape[0] < acts1.shape[1], "input must be number of neurons" "by datapoints"
    return_dict = {}

    # compute covariance with numpy function for extra stability
    numx = acts1.shape[0]
    numy = acts2.shape[0]

    covariance = np.cov(acts1, acts2)
    sigmaxx = covariance[:numx, :numx]
    sigmaxy = covariance[:numx, numx:]
    sigmayx = covariance[numx:, :numx]
    sigmayy = covariance[numx:, numx:]

    # rescale covariance to make cca computation more stable
    xmax = np.max(np.abs(sigmaxx))
    ymax = np.max(np.abs(sigmayy))
    sigmaxx /= xmax
    sigmayy /= ymax
    sigmaxy /= np.sqrt(xmax * ymax)
    sigmayx /= np.sqrt(xmax * ymax)

    ([u, s, v], invsqrt_xx, invsqrt_yy, x_idxs, y_idxs) = compute_ccas(
        sigmaxx, sigmaxy, sigmayx, sigmayy, epsilon=epsilon, verbose=verbose
    )

    # if x_idxs or y_idxs is all false, return_dict has zero entries
    if (not np.any(x_idxs)) or (not np.any(y_idxs)):
        return create_zero_dict(compute_dirns, acts1.shape[1])

    if compute_coefs:
        # also compute full coefficients over all neurons
        x_mask = np.dot(x_idxs.reshape((-1, 1)), x_idxs.reshape((1, -1)))
        y_mask = np.dot(y_idxs.reshape((-1, 1)), y_idxs.reshape((1, -1)))

        return_dict["coef_x"] = u.T
        return_dict["invsqrt_xx"] = invsqrt_xx
        return_dict["full_coef_x"] = np.zeros((numx, numx))
        np.place(return_dict["full_coef_x"], x_mask, return_dict["coef_x"])
        return_dict["full_invsqrt_xx"] = np.zeros((numx, numx))
        np.place(return_dict["full_invsqrt_xx"], x_mask, return_dict["invsqrt_xx"])

        return_dict["coef_y"] = v
        return_dict["invsqrt_yy"] = invsqrt_yy
        return_dict["full_coef_y"] = np.zeros((numy, numy))
        np.place(return_dict["full_coef_y"], y_mask, return_dict["coef_y"])
        return_dict["full_invsqrt_yy"] = np.zeros((numy, numy))
        np.place(return_dict["full_invsqrt_yy"], y_mask, return_dict["invsqrt_yy"])

        # compute means
        neuron_means1 = np.mean(acts1, axis=1, keepdims=True)
        neuron_means2 = np.mean(acts2, axis=1, keepdims=True)
        return_dict["neuron_means1"] = neuron_means1
        return_dict["neuron_means2"] = neuron_means2

    if compute_dirns:
        # orthonormal directions that are CCA directions
        cca_dirns1 = (
            np.dot(
                np.dot(return_dict["full_coef_x"], return_dict["full_invsqrt_xx"]),
                (acts1 - neuron_means1),
            )
            + neuron_means1
        )
        cca_dirns2 = (
            np.dot(
                np.dot(return_dict["full_coef_y"], return_dict["full_invsqrt_yy"]),
                (acts2 - neuron_means2),
            )
            + neuron_means2
        )

    # get rid of trailing zeros in the cca coefficients
    idx1 = sum_threshold(s, threshold)
    idx2 = sum_threshold(s, threshold)

    return_dict["cca_coef1"] = s
    return_dict["cca_coef2"] = s
    return_dict["x_idxs"] = x_idxs
    return_dict["y_idxs"] = y_idxs
    # summary statistics
    return_dict["mean"] = (np.mean(s[:idx1]), np.mean(s[:idx2]))
    return_dict["sum"] = (np.sum(s), np.sum(s))

    if compute_dirns:
        return_dict["cca_dirns1"] = cca_dirns1
        return_dict["cca_dirns2"] = cca_dirns2

    return return_dict


def robust_cca_similarity(acts1, acts2, threshold=0.98, epsilon=1e-6, compute_dirns=True):
    """Calls get_cca_similarity multiple times while adding noise.

    This function is very similar to get_cca_similarity, and can be used if
    get_cca_similarity doesn't converge for some pair of inputs. This function
    adds some noise to the activations to help convergence.

    Args:
              acts1: (num_neurons1, data_points) a 2d numpy array of neurons by
                     datapoints where entry (i,j) is the output of neuron i on
                     datapoint j.
              acts2: (num_neurons2, data_points) same as above, but (potentially)
                     for a different set of neurons. Note that acts1 and acts2
                     can have different numbers of neurons, but must agree on the
                     number of datapoints

              threshold: float between 0, 1 used to get rid of trailing zeros in
                         the cca correlation coefficients to output more accurate
                         summary statistics of correlations.

              epsilon: small float to help stabilize computations

              compute_dirns: boolean value determining whether actual cca
                             directions are computed. (For very large neurons and
                             datasets, may be better to compute these on the fly
                             instead of store in memory.)

    Returns:
              return_dict: A dictionary with outputs from the cca computations.
                           Contains neuron coefficients (combinations of neurons
                           that correspond to cca directions), the cca correlation
                           coefficients (how well aligned directions correlate),
                           x and y idxs (for computing cca directions on the fly
                           if compute_dirns=False), and summary statistics. If
                           compute_dirns=True, the cca directions are also
                           computed.
    """

    for trial in range(num_cca_trials):
        try:
            return_dict = get_cca_similarity(acts1, acts2, threshold, compute_dirns)
        except np.linalg.LinAlgError:
            acts1 = acts1 * 1e-1 + np.random.normal(size=acts1.shape) * epsilon
            acts2 = acts2 * 1e-1 + np.random.normal(size=acts1.shape) * epsilon
            if trial + 1 == num_cca_trials:
                raise

    return return_dict
    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/cca_core.py


def top_k_pca_comps(singular_values, threshold=0.99):
    total_variance = np.sum(singular_values**2)
    explained_variance = (singular_values**2) / total_variance
    cumulative_variance = np.cumsum(explained_variance)
    return np.argmax(cumulative_variance >= threshold * total_variance) + 1


def _svcca_original(acts1, acts2):
    # Copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/tutorials/001_Introduction.ipynb
    # Modification: get_cca_similarity is in the same file.
    # Modification: top-k PCA component selection s.t. explained variance > 0.99 total variance
    # Mean subtract activations
    cacts1 = acts1 - np.mean(acts1, axis=1, keepdims=True)
    cacts2 = acts2 - np.mean(acts2, axis=1, keepdims=True)

    # Perform SVD
    U1, s1, V1 = np.linalg.svd(cacts1, full_matrices=False)
    U2, s2, V2 = np.linalg.svd(cacts2, full_matrices=False)

    # top-k PCA components only
    k1 = top_k_pca_comps(s1)
    k2 = top_k_pca_comps(s2)

    svacts1 = np.dot(s1[:k1] * np.eye(k1), V1[:k1])
    # can also compute as svacts1 = np.dot(U1.T[:20], cacts1)
    svacts2 = np.dot(s2[:k2] * np.eye(k2), V2[:k2])
    # can also compute as svacts1 = np.dot(U2.T[:20], cacts2)

    svcca_results = get_cca_similarity(svacts1, svacts2, epsilon=1e-10, verbose=False)
    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/tutorials/001_Introduction.ipynb
    return np.mean(svcca_results["cca_coef1"])


# Copied from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/pwcca.py
# Modification: get_cca_similarity is in the same file.
def compute_pwcca(acts1, acts2, epsilon=0.0):
    """Computes projection weighting for weighting CCA coefficients

    Args:
         acts1: 2d numpy array, shaped (neurons, num_datapoints)
         acts2: 2d numpy array, shaped (neurons, num_datapoints)

    Returns:
         Original cca coefficient mean and weighted mean

    """
    sresults = get_cca_similarity(
        acts1,
        acts2,
        epsilon=epsilon,
        compute_dirns=False,
        compute_coefs=True,
        verbose=False,
    )
    if np.sum(sresults["x_idxs"]) <= np.sum(sresults["y_idxs"]):
        dirns = (
            np.dot(
                sresults["coef_x"],
                (acts1[sresults["x_idxs"]] - sresults["neuron_means1"][sresults["x_idxs"]]),
            )
            + sresults["neuron_means1"][sresults["x_idxs"]]
        )
        coefs = sresults["cca_coef1"]
        acts = acts1
        idxs = sresults["x_idxs"]
    else:
        dirns = (
            np.dot(
                sresults["coef_y"],
                (acts1[sresults["y_idxs"]] - sresults["neuron_means2"][sresults["y_idxs"]]),
            )
            + sresults["neuron_means2"][sresults["y_idxs"]]
        )
        coefs = sresults["cca_coef2"]
        acts = acts2
        idxs = sresults["y_idxs"]
    P, _ = np.linalg.qr(dirns.T)
    weights = np.sum(np.abs(np.dot(P.T, acts[idxs].T)), axis=1)
    weights = weights / np.sum(weights)

    return np.sum(weights * coefs), weights, coefs
    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/pwcca.py


##################################################################################

from typing import Union  # noqa:e402

import numpy.typing as npt  # noqa:e402
import torch  # noqa:e402

# from repsim.measures.utils import (
#     SHAPE_TYPE,
#     flatten,
#     resize_wh_reps,
#     to_numpy_if_needed,
#     RepresentationalSimilarityMeasure,
# )  # noqa:e402


def svcca(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    shape: SHAPE_TYPE,
) -> float:
    R, Rp = flatten(R, Rp, shape=shape)
    R, Rp = to_numpy_if_needed(R, Rp)
    return _svcca_original(R.T, Rp.T)


def pwcca(
    R: Union[torch.Tensor, npt.NDArray],
    Rp: Union[torch.Tensor, npt.NDArray],
    shape: SHAPE_TYPE,
) -> float:
    R, Rp = flatten(R, Rp, shape=shape)
    R, Rp = to_numpy_if_needed(R, Rp)
    return compute_pwcca(R.T, Rp.T)[0]


class SVCCA(RepresentationalSimilarityMeasure):
    def __init__(self):
        super().__init__(
            sim_func=svcca,
            larger_is_more_similar=True,
            is_metric=False,
            is_symmetric=True,
            invariant_to_affine=False,
            invariant_to_invertible_linear=False,
            invariant_to_ortho=True,
            invariant_to_permutation=True,
            invariant_to_isotropic_scaling=True,
            invariant_to_translation=True,
        )

    def __call__(self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE) -> float:
        if shape == "nchw":
            # Move spatial dimensions into the sample dimension
            # If not the same spatial dimension, resample via FFT.
            R, Rp = align_spatial_dimensions(R, Rp)
            shape = "nd"

        return self.sim_func(R, Rp, shape)


class PWCCA(RepresentationalSimilarityMeasure):
    def __init__(self):
        super().__init__(
            sim_func=pwcca,
            larger_is_more_similar=True,
            is_metric=False,
            is_symmetric=False,
            invariant_to_affine=False,
            invariant_to_invertible_linear=False,
            invariant_to_ortho=False,
            invariant_to_permutation=False,
            invariant_to_isotropic_scaling=True,
            invariant_to_translation=True,
        )

    def __call__(self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE) -> float:
        if shape == "nchw":
            # Move spatial dimensions into the sample dimension
            # If not the same spatial dimension, resample via FFT.
            R, Rp = align_spatial_dimensions(R, Rp)
            shape = "nd"

        return self.sim_func(R, Rp, shape)

## get rand

In [15]:
def score_rand(num_runs, weight_matrix_np, weight_matrix_2, num_feats, sim_fn, shapereq_bool):
    all_rand_scores = []
    i = 0
    # for i in range(num_runs):
    while i < num_runs:
        try:
            rand_modA_feats = np.random.choice(range(weight_matrix_np.shape[0]), size=num_feats, replace=False).tolist()
            rand_modB_feats = np.random.choice(range(weight_matrix_2.shape[0]), size=num_feats, replace=False).tolist()

            if shapereq_bool:
                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats], "nd")
            else:
                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats])
            all_rand_scores.append(score)
            i += 1
        except:
            continue
    return sum(all_rand_scores) / len(all_rand_scores)

In [16]:
def score_rand_corr(num_runs, weight_matrix_np, weight_matrix_2, num_feats, highest_correlations_indices_AB, sim_fn, shapereq_bool):
    all_rand_scores = []
    i = 0
    # for i in range(num_runs):
    while i < num_runs:
        try:
            rand_modB_feats = np.random.choice(range(weight_matrix_2.shape[0]), size=num_feats, replace=False).tolist()
            rand_modA_feats = [highest_correlations_indices_AB[index] for index in rand_modB_feats]

            if shapereq_bool:
                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats], "nd")
            else:
                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats])
            all_rand_scores.append(score)
            i += 1
        except:
            continue
    # print(sum(all_rand_scores) / len(all_rand_scores))
    # plt.hist(all_rand_scores)
    # plt.show()
    return all_rand_scores

In [17]:
import random
def shuffle_rand(num_runs, weight_matrix_np, weight_matrix_2, num_feats, sim_fn, shapereq_bool):
    all_rand_scores = []
    for i in range(num_runs):
        row_idxs = list(range(num_feats))
        random.shuffle(row_idxs)
        if shapereq_bool:
            score = sim_fn(weight_matrix_np, weight_matrix_2[row_idxs], "nd")
        else:
            score = sim_fn(weight_matrix_np, weight_matrix_2[row_idxs])
        all_rand_scores.append(score)
    # return sum(all_rand_scores) / len(all_rand_scores)
    return all_rand_scores

## plot fns

In [18]:
def plot_svcca_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['svcca_paired'] for i in range(0, 12)]
    unpaired_values = [layer_to_dictscores[i]['svcca_unpaired'] for i in range(0, 12)]

    # Plotting configuration
    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    # Increase figure size
    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')
    rects2 = ax.bar(x + width/2, unpaired_values, width, label='Unpaired')

    # Adding labels, title and custom x-axis tick labels
    ax.set_ylabel('SVCCA')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    # Rotate labels and adjust padding
    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            # if i % 2 == 0:  # Label every other bar above
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12) # , rotation=90
            # else:  # Label every other bar below
            #     ax.text(rect.get_x() + rect.get_width()/2., 0,
            #             f'{height:.3f}',
            #             ha='center', va='top', rotation=90)

    label_bars(rects1)
    label_bars(rects2)

    # Adjust layout to prevent cutting off labels
    plt.tight_layout()

    plt.show()

In [19]:
def plot_rsa_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['rsa_paired'] for i in range(0, 12)]
    unpaired_values = [layer_to_dictscores[i]['rsa_unpaired'] for i in range(0, 12)]

    # Plotting configuration
    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    # Increase figure size
    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')
    rects2 = ax.bar(x + width/2, unpaired_values, width, label='Unpaired')

    # Adding labels, title and custom x-axis tick labels
    ax.set_ylabel('RSA')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    # Rotate labels and adjust padding
    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            # if i % 2 == 0:  # Label every other bar above
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12) # , rotation=90
            # else:  # Label every other bar below
            #     ax.text(rect.get_x() + rect.get_width()/2., 0,
            #             f'{height:.3f}',
            #             ha='center', va='top', rotation=90)

    label_bars(rects1)
    label_bars(rects2)

    # Adjust layout to prevent cutting off labels
    plt.tight_layout()

    plt.show()

In [20]:
def plot_meanCorr_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['mean_actv_corr'] for i in range(0, 12)]
    # unpaired_values = [layer_to_dictscores[i]['svcca_unpaired'] for i in range(0, 12)]

    # Plotting configuration
    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    # Increase figure size
    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')
    # rects2 = ax.bar(x + width/2, unpaired_values, width, label='Unpaired')

    # Adding labels, title and custom x-axis tick labels
    ax.set_ylabel('Corr')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    # Rotate labels and adjust padding
    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            # if i % 2 == 0:  # Label every other bar above
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12) # , rotation=90
            # else:  # Label every other bar below
            #     ax.text(rect.get_x() + rect.get_width()/2., 0,
            #             f'{height:.3f}',
            #             ha='center', va='top', rotation=90)

    label_bars(rects1)
    # label_bars(rects2)

    # Adjust layout to prevent cutting off labels
    plt.tight_layout()

    plt.show()

In [21]:
def plot_meanCorr_filt_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['mean_actv_corr_filt'] for i in range(0, 12)]

    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')

    ax.set_ylabel('Corr')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12)

    label_bars(rects1)
    plt.tight_layout()

    plt.show()

In [22]:
def plot_numFeats_afterFilt_byLayer(layer_to_dictscores):
    for key, sub_dict in layer_to_dictscores.items():
        for sub_key, value in sub_dict.items():
            sub_dict[sub_key] = round(value, 4)

    layers = [f'L{i}' for i in range(0, 12)]
    paired_values = [layer_to_dictscores[i]['num_feat_filt'] for i in range(0, 12)]

    x = np.arange(len(layers))  # label locations
    width = 0.35  # width of the bars

    fig, ax = plt.subplots(figsize=(10, 6))
    rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')

    ax.set_ylabel('Num Feats Kept')
    ax.set_title(f'SAEs comparison by Pythia 70m MLP{layer_id} vs 160m MLP Layers')
    ax.set_xticks(x)
    ax.set_xticklabels(layers)
    ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
    ax.legend()

    def label_bars(rects):
        for i, rect in enumerate(rects):
            height = rect.get_height()
            ax.text(rect.get_x() + rect.get_width()/2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=12)

    label_bars(rects1)
    plt.tight_layout()

    plt.show()

In [23]:
# def plot_js_byLayer(layer_to_dictscores):
#     for key, sub_dict in layer_to_dictscores.items():
#         for sub_key, value in sub_dict.items():
#             sub_dict[sub_key] = round(value, 4)

#     layers = [f'L{i}' for i in range(0, 12)]
#     paired_values = [layer_to_dictscores[i]['1-1 jaccard_paired'] for i in range(0, 12)]
#     unpaired_values = [layer_to_dictscores[i]['1-1 jaccard_unpaired'] for i in range(0, 12)]

#     # Plotting configuration
#     x = np.arange(len(layers))  # label locations
#     width = 0.35  # width of the bars

#     # Increase figure size
#     fig, ax = plt.subplots(figsize=(12, 7))  # Slightly increased figure size
#     rects1 = ax.bar(x - width/2, paired_values, width, label='Paired')
#     rects2 = ax.bar(x + width/2, unpaired_values, width, label='Unpaired')

#     # Adding labels, title and custom x-axis tick labels
#     ax.set_ylabel('Jaccard NN')
#     ax.set_title('SAEs comparison by Pythia 70m MLP0 vs 160m MLP Layers')
#     ax.set_xticks(x)
#     ax.set_xticklabels(layers)
#     ax.set_ylim(0, 1)  # Ensuring y-axis is scaled from 0 to 1
#     ax.legend()

#     # Label bars with increased font size and different positioning for paired and unpaired
#     def label_bars(rects, is_paired):
#         for rect in rects:
#             height = rect.get_height()
#             label_height = height + 0.05 if is_paired else height + 0.01
#             ax.text(rect.get_x() + rect.get_width()/2., label_height,
#                     f'{height:.3f}',
#                     ha='center', va='bottom', fontsize=9)  # Increased font size to 12

#     label_bars(rects1, True)   # Paired bars
#     label_bars(rects2, False)  # Unpaired bars

#     # Adjust layout to prevent cutting off labels
#     plt.tight_layout()

#     # Increase y-axis limit to accommodate higher labels
#     ax.set_ylim(0, 1.1)  # Increased from 1 to 1.1

#     plt.show()

## interpret fns

In [24]:
def highest_activating_tokens(
    feature_acts,
    feature_idx: int,
    k: int = 10,  # num batch_seq samples
    batch_tokens=None
): # -> Tuple[Int[Tensor, "k 2"], Float[Tensor, "k"]]:
    '''
    Returns the indices & values for the highest-activating tokens in the given batch of data.
    '''
    batch_size, seq_len = batch_tokens.shape

    # Get the top k largest activations for only targeted feature
    # need to flatten (batch,seq) into batch*seq first because it's ANY batch_seq, even if in same batch or same pos
    flattened_feature_acts = feature_acts[:, :, feature_idx].reshape(-1)

    top_acts_values, top_acts_indices = flattened_feature_acts.topk(k)
    # top_acts_values should be 1D
    # top_acts_indices should be also be 1D. Now, turn it back to 2D
    # Convert the indices into (batch, seq) indices
    top_acts_batch = top_acts_indices // seq_len
    top_acts_seq = top_acts_indices % seq_len

    return torch.stack([top_acts_batch, top_acts_seq], dim=-1), top_acts_values

In [25]:
def store_top_toks(top_acts_indices, top_acts_values, batch_tokens):
    feat_samps = []
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        new_str_token = tokenizer.decode(batch_tokens[batch_idx, seq_idx]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
        feat_samps.append(new_str_token)
    return feat_samps

In [26]:
def find_indices_with_keyword(fList, keyword):
    """
    Find all indices of fList which contain the keyword in the string at those indices.

    Args:
    fList (list of str): List of strings to search within.
    keyword (str): Keyword to search for within the strings of fList.

    Returns:
    list of int: List of indices where the keyword is found within the strings of fList.
    """
    index_list = []
    for index, split_list in enumerate(fList):
        no_space_list = [i.replace(' ', '').lower() for i in split_list]
        for tok in no_space_list:
            if keyword.lower() == tok:
                index_list.append(index)
        # if keyword in no_space_list:
            # index_list.append(index)
    return index_list

In [27]:
from rich import print as rprint
def display_top_sequences(top_acts_indices, top_acts_values, batch_tokens):
    s = ""
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        # s += f'{batch_idx}\n'
        s += f'batchID: {batch_idx}, '
        # Get the sequence as a string (with some padding on either side of our sequence)
        seq_start = max(seq_idx - 5, 0)
        seq_end = min(seq_idx + 5, batch_tokens.shape[1])
        seq = ""
        # Loop over the sequence, adding each token to the string (highlighting the token with the large activations)
        for i in range(seq_start, seq_end):
            # new_str_token = model.to_single_str_token(batch_tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            new_str_token = tokenizer.decode([batch_tokens[batch_idx, i].item()]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            if i == seq_idx:
                new_str_token = f"[bold u dark_orange]{new_str_token}[/]"
            seq += new_str_token
        # Print the sequence, and the activation value
        s += f'Act = {value:.2f}, Seq = "{seq}"\n'

    rprint(s)

In [28]:
# def store_top_seqs(top_acts_indices, top_acts_values, batch_tokens):
#     feat_samps = []
#     for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
#         # Get the sequence as a string (with some padding on either side of our sequence)
#         seq_start = max(seq_idx - 5, 0)
#         seq_end = min(seq_idx + 5, batch_tokens.shape[1])
#         seq = ""
#         # Loop over the sequence, adding each token to the string (highlighting the token with the large activations)
#         for i in range(seq_start, seq_end):
#             # new_str_token = model.to_single_str_token(batch_tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
#             new_str_token = tokenizer.decode([batch_tokens[batch_idx, i].item()]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
#             # if i == seq_idx:
#             #     new_str_token = f"[bold u dark_orange]{new_str_token}[/]"
#             seq += new_str_token
#         feat_samps.append(seq)
#     return feat_samps

In [29]:
# def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
#     feat_list = []
#     for feat_ind, top_seqs_lst in enumerate(fList_seqs):
#         for seq in top_seqs_lst:
#             split_list = seq.split(' ')
#             flag = False
#             for tok in split_list:
#                 if keyword.lower() == tok:
#                     feat_list.append(feat_ind)
#                     flag = True
#                     break
#             if flag:
#                 break
#     return feat_list

In [30]:
# def store_top_seqs(top_acts_indices, top_acts_values, batch_tokens):
#     feat_samps = []
#     for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
#         # Get the sequence with padding
#         seq_start = max(seq_idx - 2, 0)
#         seq_end = min(seq_idx + 2, batch_tokens.shape[1])
#         seq = ""
#         tokens = []
#         for i in range(seq_start, seq_end):
#             token_id = batch_tokens[batch_idx, i].item()
#             new_str_token = tokenizer.decode([token_id]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
#             tokens.append((new_str_token, i))
#             seq += new_str_token
#         # Store the sequence, tokens, and seq_idx
#         feat_samps.append((seq, tokens, seq_idx))
#     return feat_samps

# def get_word(tokens, idx_in_tokens):
#     # Initialize the word with the token at seq_idx
#     word_tokens = [tokens[idx_in_tokens][0]]
#     # Move backwards to find the start of the word
#     i = idx_in_tokens - 1
#     while i >= 0:
#         token_str, _ = tokens[i]
#         if token_str.startswith(' '):
#             break
#         word_tokens.insert(0, token_str)
#         i -= 1
#     # Move forwards to find the end of the word
#     i = idx_in_tokens + 1
#     while i < len(tokens):
#         token_str, _ = tokens[i]
#         if token_str.startswith(' '):
#             break
#         word_tokens.append(token_str)
#         i += 1
#     # Reconstruct the word and remove any spaces
#     word = ''.join(word_tokens).replace(' ', '')
#     return word


# def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
#     feat_list = []
#     for feat_ind, top_seqs_lst in enumerate(fList_seqs):
#         for seq, tokens, seq_idx in top_seqs_lst:
#             # Find the position of seq_idx in tokens
#             idx_in_tokens = None
#             for i, (token_str, idx) in enumerate(tokens):
#                 if idx == seq_idx:
#                     idx_in_tokens = i
#                     break
#             if idx_in_tokens is None:
#                 continue  # seq_idx not found in tokens

#             # Reconstruct the word containing seq_idx
#             word = get_word(tokens, idx_in_tokens)

#             # Compare the reconstructed word with the keyword
#             if word.lower() == keyword.lower():
#                 feat_list.append(feat_ind)
#                 break  # Proceed to the next feature index
#     return feat_list


In [31]:
def store_top_seqs(top_acts_indices, top_acts_values, batch_tokens):
    feat_samps = []
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        # Get the sequence as a string (with some padding on either side of our sequence)
        seq_start = max(seq_idx - 2, 0)
        seq_end = min(seq_idx + 2, batch_tokens.shape[1])
        seq = ""
        # Loop over the sequence, adding each token to the string (highlighting the token with the large activations)
        for i in range(seq_start, seq_end):
            # new_str_token = model.to_single_str_token(batch_tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            new_str_token = tokenizer.decode([batch_tokens[batch_idx, i].item()]).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            if i == seq_idx:
                topTok = new_str_token
            seq += new_str_token
        feat_samps.append( (seq, topTok) )
    return feat_samps

In [32]:
def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
    feat_list = []
    for feat_ind, top_seqs_andToks_lst in enumerate(fList_seqs):
        for top_seqs_andToks in top_seqs_andToks_lst:
            seq = top_seqs_andToks[0]
            topTok = top_seqs_andToks[1].replace(' ', '').lower()
            if keyword.lower() != topTok:
                continue
            split_list = seq.split(' ')
            flag = False
            for word in split_list:
                word = word.replace('.', '').replace(',', '').replace('?', '').replace('!', '').replace('\\n','')

                if keyword.lower() == word.lower():
                    feat_list.append(feat_ind)
                    flag = True
                    break
            if flag:
                break
    return feat_list

## get concept space features

In [33]:
def get_mixed_feats(fList_model_B, corr_inds, keywords):
    mixed_modA_feats = []
    mixed_modB_feats = []
    added_modA_feats = set()  # To track which modA feats have been added
    added_modB_feats = set()  # To track which modB feats have been added

    for kw in keywords:
        modB_feats = find_indices_with_keyword(fList_model_B, kw)
        for index in modB_feats:
            modA_feat = corr_inds[index]
            modB_feat = index

            # Check if the feature has already been added to maintain uniqueness
            if modA_feat not in added_modA_feats and modB_feat not in added_modB_feats:
                mixed_modA_feats.append(modA_feat)
                mixed_modB_feats.append(modB_feat)
                added_modA_feats.add(modA_feat)
                added_modB_feats.add(modB_feat)

    print("Unique modA feats: ", len(mixed_modA_feats))
    print("Unique modB feats: ", len(mixed_modB_feats))
    return mixed_modA_feats, mixed_modB_feats

In [34]:
def get_mixed_feats_with_kwList(fList_model_B, corr_inds, keywords):
    mixed_modA_feats = []
    mixed_modB_feats = []
    added_modA_feats = set()  # To track which modA feats have been added
    added_modB_feats = set()  # To track which modB feats have been added
    keywords_to_feats = {kw: 0 for kw in keywords} # kw : count

    for kw in keywords:
        modB_feats = find_indices_with_keyword(fList_model_B, kw)
        for index in modB_feats:
            modA_feat = corr_inds[index]
            modB_feat = index

            # Check if the feature has already been added to maintain uniqueness
            if modA_feat not in added_modA_feats and modB_feat not in added_modB_feats:
                mixed_modA_feats.append(modA_feat)
                mixed_modB_feats.append(modB_feat)
                added_modA_feats.add(modA_feat)
                added_modB_feats.add(modB_feat)
                keywords_to_feats[kw] += 1

    print("Unique modA feats: ", len(mixed_modA_feats))
    print("Unique modB feats: ", len(mixed_modB_feats))
    return mixed_modA_feats, mixed_modB_feats, keywords_to_feats

## get actv fns

In [35]:
# def get_weights_and_acts(name, cfg_dict, layer_id, outputs):
def get_weights_and_acts(name, layer_id, outputs):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hookpoint = "layers." + str(layer_id)

    sae = Sae.load_from_hub(name, hookpoint=hookpoint, device=device)

    weight_matrix = sae.W_dec.cpu().detach().numpy()

    with torch.inference_mode():
        # reshaped_activations_A = sae.pre_acts(outputs.to("cuda"))
        reshaped_activations = sae.pre_acts(outputs.hidden_states[layer_id].to("cuda"))

    first_dim_reshaped = reshaped_activations.shape[0] * reshaped_activations.shape[1]
    reshaped_activations = reshaped_activations.reshape(first_dim_reshaped, reshaped_activations.shape[-1]).cpu()

    # return weight_matrix_np, reshaped_activations_A, orig
    return weight_matrix, reshaped_activations

In [36]:
def count_zero_columns(tensor):
    # Check if all elements in each column are zero
    zero_columns = np.all(tensor == 0, axis=0)
    # Count True values in the zero_columns array
    zero_cols_indices = np.where(zero_columns)[0]
    return np.sum(zero_columns), zero_cols_indices

## run expm fns

In [37]:
def run_expm(layer_id, outputs, outputs_2):
    layer_to_dictscores = {}

    name = "EleutherAI/sae-pythia-70m-32k"
    cfg_dict = {"expansion_factor": 32, "normalize_decoder": True, "num_latents": 32768, "k": 16, "d_in": 512}
    weight_matrix_np, reshaped_activations_A = get_weights_and_acts(name, cfg_dict, layer_id, outputs)
    # zero_cols_count, zero_cols_indices = count_zero_columns(reshaped_activations_A.cpu().numpy())
    # print("Number of zero columns:", zero_cols_count) #, zero_cols_indices

    name = "EleutherAI/sae-pythia-160m-32k"
    for layerID_2 in range(0, 12): # 0, 12
        dictscores = {}

        # redef
        cfg_dict = {"expansion_factor": 32, "normalize_decoder": True, "num_latents": 32768, "k": 16, "d_in": 768}
        weight_matrix_2, reshaped_activations_B = get_weights_and_acts(name, cfg_dict, layerID_2, outputs_2)


        highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(reshaped_activations_A, reshaped_activations_B)

        num_unq_pairs = len(list(set(highest_correlations_indices_AB)))
        print("% unique: ", num_unq_pairs / len(highest_correlations_indices_AB))

        dictscores["mean_actv_corr"] = sum(highest_correlations_values_AB) / len(highest_correlations_values_AB)

        ###########
        # filter

        sorted_feat_counts = Counter(highest_correlations_indices_AB).most_common()
        # kept_modA_feats = [feat_ID for feat_ID, count in sorted_feat_counts if count <= 100]
        kept_modA_feats = [feat_ID for feat_ID, count in sorted_feat_counts if count <= 1]

        filt_corr_ind_A = []
        filt_corr_ind_B = []
        seen = set()
        for ind_B, ind_A in enumerate(highest_correlations_indices_AB):
            if ind_A in kept_modA_feats:
                filt_corr_ind_A.append(ind_A)
                filt_corr_ind_B.append(ind_B)
            elif ind_A not in seen:  # only keep one if it's over count X
                seen.add(ind_A)
                filt_corr_ind_A.append(ind_A)
                filt_corr_ind_B.append(ind_B)
        # num_unq_pairs = len(list(set(filt_corr_ind_A)))
        # print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
        # print("num feats after filt: ", len(filt_corr_ind_A))

        new_highest_correlations_indices_A = []
        new_highest_correlations_indices_B = []
        new_highest_correlations_values = []

        for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
            val = highest_correlations_values_AB[ind_B]
            if val > 0:
                new_highest_correlations_indices_A.append(ind_A)
                new_highest_correlations_indices_B.append(ind_B)
                new_highest_correlations_values.append(val)

        # print("% unique after rmv 0s: ", num_unq_pairs / len(new_highest_correlations_indices_A))
        # print("num feats after rmv 0s: ", len(new_highest_correlations_indices_A))
        dictscores["num_feat_kept"] = len(new_highest_correlations_indices_A)

        dictscores["mean_actv_corr_filt"] = sum(new_highest_correlations_values) / len(new_highest_correlations_values)

        ###########
        # sim tests

        # # num_feats = len(filt_corr_ind_A)
        num_feats = len(new_highest_correlations_indices_A)

        # dictscores["svcca_paired"] = svcca(weight_matrix_np[filt_corr_ind_A], weight_matrix_2[filt_corr_ind_B], "nd")
        dictscores["svcca_paired"] = svcca(weight_matrix_np[new_highest_correlations_indices_A], weight_matrix_2[new_highest_correlations_indices_B], "nd")

        dictscores["svcca_unpaired"] = score_rand(weight_matrix_np, weight_matrix_2, num_feats,
                                                  svcca, shapereq_bool=True)

        dictscores["rsa_paired"] = representational_similarity_analysis(weight_matrix_np[new_highest_correlations_indices_A], weight_matrix_2[new_highest_correlations_indices_B], "nd")
        dictscores["rsa_unpaired"] = score_rand(weight_matrix_np, weight_matrix_2, num_feats,
                                                  representational_similarity_analysis, shapereq_bool=True)

        print("Layer: " + str(layerID_2))
        for key, value in dictscores.items():
            print(key + ": " + str(value))
        print("\n")

        layer_to_dictscores[layerID_2] = dictscores
    return layer_to_dictscores

# load data

In [38]:
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
tokenizer.pad_token = tokenizer.eos_token

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]



In [39]:
from datasets import load_dataset
# dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
dataset = load_dataset("Skylion007/openwebtext", split="train", streaming=True)

openwebtext.py:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.35k [00:00<?, ?B/s]

The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


In [40]:
batch_size = 100
maxseqlen = 300

def get_next_batch(dataset_iter, batch_size=100):
    batch = []
    for _ in range(batch_size):
        try:
            sample = next(dataset_iter)
            batch.append(sample['text'])
        except StopIteration:
            break
    return batch

dataset_iter = iter(dataset)
batch = get_next_batch(dataset_iter, batch_size)

# Tokenize the batch
inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=maxseqlen)

# load models

In [41]:
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
model_2 = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/375M [00:00<?, ?B/s]

## get LLM actvs

In [42]:
with torch.inference_mode():
    outputs = model(**inputs, output_hidden_states=True)
    outputs_2 = model_2(**inputs, output_hidden_states=True)

# MLP3 vs MLP5

## sae actvs

In [43]:
layer_id = 3
name = "EleutherAI/sae-pythia-70m-32k"
hookpoint = "layers." + str(layer_id)
sae = Sae.load_from_hub(name, hookpoint=hookpoint, device=device)

with torch.inference_mode():
    outputs = model(**inputs, output_hidden_states=True)
    feature_acts_A = sae.pre_acts(outputs.hidden_states[layer_id].to('cuda'))

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

layers.3/cfg.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/134M [00:00<?, ?B/s]



In [44]:
first_dim_reshaped = feature_acts_A.shape[0] * feature_acts_A.shape[1]
reshaped_activations_A = feature_acts_A.reshape(first_dim_reshaped, feature_acts_A.shape[-1]).cpu()

In [45]:
weight_matrix_np = sae.W_dec.cpu().detach().numpy()

In [46]:
layer_id_2 = 5
name = "EleutherAI/sae-pythia-160m-32k"
hookpoint = "layers." + str(layer_id_2)
sae_2 = Sae.load_from_hub(name, hookpoint=hookpoint, device=device)

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

layers.5/cfg.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

sae.safetensors:   0%|          | 0.00/201M [00:00<?, ?B/s]



In [47]:
with torch.inference_mode():
    feature_acts_B = sae_2.pre_acts(outputs_2.hidden_states[layer_id_2].to("cuda"))

In [48]:
first_dim_reshaped = feature_acts_B.shape[0] * feature_acts_B.shape[1]
reshaped_activations_B = feature_acts_B.reshape(first_dim_reshaped, feature_acts_B.shape[-1]).cpu()

In [49]:
weight_matrix_2 = sae_2.W_dec.cpu().detach().numpy()

## get labels

In [None]:
# store feature : lst of top strs
fList_model_B_seqs = []
samp_m = 5

for feature_idx in range(feature_acts_B.shape[-1]):
# for feature_idx in range(5):
    if feature_idx % 5000 == 0:
        print('Feature: ', feature_idx)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx, samp_m, batch_tokens= inputs['input_ids'])
    fList_model_B_seqs.append(store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids']) )

Feature:  0
Feature:  5000
Feature:  10000
Feature:  15000
Feature:  20000
Feature:  25000
Feature:  30000


In [None]:
# store feature : lst of top strs
fList_model_A_seqs = []
samp_m = 5

for feature_idx in range(feature_acts_A.shape[-1]):
    if feature_idx % 5000 == 0:
        print('Feature: ', feature_idx)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx, samp_m, batch_tokens= inputs['input_ids'])
    fList_model_A_seqs.append(store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids']) )

Feature:  0
Feature:  5000
Feature:  10000
Feature:  15000
Feature:  20000
Feature:  25000
Feature:  30000


In [None]:
# import pdb

# def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
#     feat_list = []
#     for feat_ind, top_seqs_andToks_lst in enumerate(fList_seqs):
#         for top_seqs_andToks in top_seqs_andToks_lst:
#             seq = top_seqs_andToks[0]
#             topTok = top_seqs_andToks[1].replace(' ', '').lower()
#             if keyword.lower() != topTok:
#                 continue
#             split_list = seq.split(' ')
#             flag = False
#             for word in split_list:
#                 word = word.replace('.', '').replace(',', '').replace('?', '').replace('!', '').replace('\\n','')

#                 # pdb.set_trace()
#                 if keyword.lower() == word:
#                     feat_list.append(feat_ind)
#                     flag = True
#                     break
#             if flag:
#                 break
#     return feat_list

In [None]:
keyword = 'man'
modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, keyword)
# index_list = []
# for index, top_seqs_lst in enumerate(fList_model_B_seqs):
#     for seq in top_seqs_lst:
#         split_list = seq.split(' ')
#         for tok in split_list:
#             if keyword.lower() == tok:
#                 index_list.append(index)

In [None]:
modB_feats

[3998, 5125, 8855, 19286, 19325, 26791, 27254]

In [None]:
samp_m = 5
feature_idx_B = modB_feats[0]
print('Model B Feature: ', feature_idx_B)
ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

Model B Feature:  3998


## save labels

In [None]:
import pickle
from google.colab import files

modeltype = 'pythia70m'
with open(f'fList_L{layer_id}_{modeltype}.pkl', 'wb') as f:
    pickle.dump(fList_model_A_seqs, f)
files.download(f'fList_L{layer_id}_{modeltype}.pkl')

modeltype = 'pythia160m'
with open(f'fList_L{layer_id_2}_{modeltype}.pkl', 'wb') as f:
    pickle.dump(fList_model_B_seqs, f)
files.download(f'fList_L{layer_id_2}_{modeltype}.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## load labels

In [50]:
import pickle
from google.colab import files

In [51]:
modeltype = 'pythia70m'
with open(f'fList_model_A_L{layer_id}_{modeltype}.pkl', 'rb') as f:
    fList_model_A_seqs = pickle.load(f)

In [52]:
modeltype = 'pythia160m'
with open(f'fList_model_B_L{layer_id_2}_{modeltype}.pkl', 'rb') as f:
    fList_model_B_seqs = pickle.load(f)

# famous names

In [None]:
new_keywords = [
    "Einstein", "Newton", "Darwin", "Curie", "Galileo", "Tesla", "Hawking", "Feynman", "Pasteur", "Mendel",
    "Shakespeare", "Dickens", "Hemingway", "Austen", "Orwell", "Tolkien", "Faulkner", "Poe", "Joyce", "Steinbeck",
    "Churchill", "Gandhi", "Mandela", "Napoleon", "Cromwell",
    "Picasso", "Rembrandt", "Michelangelo", "Da Vinci", "Van Gogh", "Monet", "Dali", "Matisse", "Warhol", "Pollock",
    "Beethoven", "Mozart", "Bach", "Chopin", "Wagner", "Tchaikovsky", "Stravinsky", "Vivaldi", "Verdi", "Debussy",
    "Obama", "Clinton", "Reagan", "Thatcher", "Putin", "Merkel", "Castro", "Hitler", "Mao", "Stalin",
    "Freud", "Jung", "Pavlov", "Skinner", "Piaget", "Maslow", "Erikson", "Rogers", "Chomsky", "Bandura",
    "Pele", "Ronaldo", "Messi", "Federer", "Williams", "Phelps", "Brady",
    "Chaplin", "Hitchcock", "Kubrick", "Spielberg", "Scorsese", "Tarantino", "Lucas", "Cameron", "Coppola", "Eastwood"
]


In [None]:
import pdb

def find_indices_with_keyword_bySeqs(fList_seqs, keyword):
    feat_list = []
    for feat_ind, top_seqs_andToks_lst in enumerate(fList_seqs):
        for top_seqs_andToks in top_seqs_andToks_lst:
            seq = top_seqs_andToks[0]
            topTok = top_seqs_andToks[1].replace(' ', '').lower()
            if keyword.lower() != topTok:
                continue
            split_list = seq.split(' ')
            flag = False
            for word in split_list:
                word = word.replace('.', '').replace(',', '').replace('?', '').replace('!', '').replace('\\n','')
                # if feat_ind == 29656 and word == 'Clinton':
                #     pdb.set_trace()
                if keyword == word: # .lower()
                    feat_list.append(feat_ind)
                    flag = True
                    break
            if flag:
                break
    return feat_list

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

88
46


### run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.4782608695652174
22


0.2918984379781329

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 88])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 46])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


22

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.27428010150962634

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.17448346566995906


0.21

### filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

12

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.537055072242829

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.2391079341155493


0.076

### interpret

In [None]:
samp_m = 5
num_feats = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:num_feats], filt_corr_ind_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.22204378247261047
Model A Feature:  29656


Model B Feature:  3460


--------------------------------------------------
Correlation: 0.17094212770462036
Model A Feature:  28762


Model B Feature:  30853


--------------------------------------------------
Correlation: 0.7401432394981384
Model A Feature:  27915


Model B Feature:  21753


--------------------------------------------------
Correlation: 0.20684514939785004
Model A Feature:  2055


Model B Feature:  651


--------------------------------------------------
Correlation: 0.12329405546188354
Model A Feature:  23151


Model B Feature:  24209


--------------------------------------------------


In [None]:
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.22204378247261047
Model A Feature:  29656


Model B Feature:  3460


--------------------------------------------------
Correlation: 0.7401432394981384
Model A Feature:  27915


Model B Feature:  21753


--------------------------------------------------
Correlation: 0.20684514939785004
Model A Feature:  2055


Model B Feature:  651


--------------------------------------------------
Correlation: 0.28230422735214233
Model A Feature:  19959


Model B Feature:  7445


--------------------------------------------------
Correlation: 0.31189948320388794
Model A Feature:  3847


Model B Feature:  299


--------------------------------------------------


# numbers

In [None]:
new_keywords = [
    "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
]


In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

336
276


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.286231884057971
79


0.4362414755413066

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 336])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 276])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


79

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.07814231151675255

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.08756744847695008


0.47

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

43

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.11785725336290004

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.14278365868811696


0.52

## filter out low corr: 0.4

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

25

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.16599260285236234

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.19243474774769276


0.56

## interpret

In [None]:
samp_m = 5
num_feats = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:num_feats], filt_corr_ind_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.6352710127830505
Model A Feature:  21606


Model B Feature:  20492


--------------------------------------------------
Correlation: 0.16425354778766632
Model A Feature:  10018


Model B Feature:  7181


--------------------------------------------------
Correlation: 0.04314814507961273
Model A Feature:  5576


Model B Feature:  2577


--------------------------------------------------
Correlation: 0.22564321756362915
Model A Feature:  28932


Model B Feature:  2066


--------------------------------------------------
Correlation: 0.23793494701385498
Model A Feature:  31659


Model B Feature:  5656


--------------------------------------------------


In [None]:
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.6352710127830505
Model A Feature:  21606


Model B Feature:  20492


--------------------------------------------------
Correlation: 0.22564321756362915
Model A Feature:  28932


Model B Feature:  2066


--------------------------------------------------
Correlation: 0.23793494701385498
Model A Feature:  31659


Model B Feature:  5656


--------------------------------------------------
Correlation: 0.6850972175598145
Model A Feature:  7002


Model B Feature:  7715


--------------------------------------------------
Correlation: 0.5412429571151733
Model A Feature:  14247


Model B Feature:  25245


--------------------------------------------------


# numerical

In [None]:
new_keywords = [
    "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
    "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen", "seventeen", "eighteen", "nineteen", "twenty",
    "hundred", "thousand", "million", "billion", "trillion",
    "integer", "fraction", "decimal", "percentage", "ratio",
    "numeral", "digit", "prime",
    "sum", "difference", "factor", "multiple",
    "total", "count", "measure", "dozen", "score", "unit"]

In [None]:
len(new_keywords)

43

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

445
349


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.3008595988538682
105


0.43835264851037614

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 445])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 349])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


105

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.044547577974271585

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.08325984621783215


0.71

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

58

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.22527803843444594

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.09560070693014319


0.06

## filter out low corr: 0.3

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

45

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.264247529318192

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.12425018261807211


0.086

## filter out low corr: 0.4

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

34

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.42145435078523086

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(10000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.139084667507807


0.013

## interpret

In [None]:
samp_m = 5
num_feats = 5
for subset_feature_idx_A, subset_feature_idx_B in zip(filt_corr_ind_A[:num_feats], filt_corr_ind_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.635269284248352
Model A Feature:  21606


Model B Feature:  20492


--------------------------------------------------
Correlation: 0.22564373910427094
Model A Feature:  28932


Model B Feature:  2066


--------------------------------------------------
Correlation: 0.6002208590507507
Model A Feature:  30523


Model B Feature:  8231


--------------------------------------------------
Correlation: 0.3215320110321045
Model A Feature:  1996


Model B Feature:  14388


--------------------------------------------------
Correlation: 0.6463798880577087
Model A Feature:  21409


Model B Feature:  16442


--------------------------------------------------


In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.635269284248352
Model A Feature:  21606


Model B Feature:  20492


--------------------------------------------------
Correlation: 0.6002208590507507
Model A Feature:  30523


Model B Feature:  8231


--------------------------------------------------
Correlation: 0.6463798880577087
Model A Feature:  21409


Model B Feature:  16442


--------------------------------------------------
Correlation: 0.593344509601593
Model A Feature:  18487


Model B Feature:  14454


--------------------------------------------------
Correlation: 0.5502880215644836
Model A Feature:  14247


Model B Feature:  124


--------------------------------------------------
Correlation: 0.6651091575622559
Model A Feature:  26187


Model B Feature:  28832


--------------------------------------------------
Correlation: 0.6804549694061279
Model A Feature:  25149


Model B Feature:  4264


--------------------------------------------------
Correlation: 0.4422963559627533
Model A Feature:  10969


Model B Feature:  8370


--------------------------------------------------
Correlation: 0.522854208946228
Model A Feature:  15748


Model B Feature:  26833


--------------------------------------------------
Correlation: 0.4858936667442322
Model A Feature:  11978


Model B Feature:  22875


--------------------------------------------------
Correlation: 0.647077202796936
Model A Feature:  13073


Model B Feature:  6542


--------------------------------------------------
Correlation: 0.4130925238132477
Model A Feature:  18446


Model B Feature:  25231


--------------------------------------------------
Correlation: 0.5227394700050354
Model A Feature:  22093


Model B Feature:  6810


--------------------------------------------------
Correlation: 0.45957663655281067
Model A Feature:  13469


Model B Feature:  760


--------------------------------------------------
Correlation: 0.8934230208396912
Model A Feature:  22674


Model B Feature:  13084


--------------------------------------------------
Correlation: 0.6943512558937073
Model A Feature:  27234


Model B Feature:  31532


--------------------------------------------------
Correlation: 0.9623398184776306
Model A Feature:  3724


Model B Feature:  888


--------------------------------------------------
Correlation: 0.44172680377960205
Model A Feature:  2023


Model B Feature:  29597


--------------------------------------------------
Correlation: 0.8934170603752136
Model A Feature:  23979


Model B Feature:  17340


--------------------------------------------------
Correlation: 0.420509397983551
Model A Feature:  30764


Model B Feature:  19418


--------------------------------------------------
Correlation: 0.40344521403312683
Model A Feature:  24453


Model B Feature:  7151


--------------------------------------------------
Correlation: 0.6650091409683228
Model A Feature:  13031


Model B Feature:  13359


--------------------------------------------------
Correlation: 0.6294023990631104
Model A Feature:  22228


Model B Feature:  15597


--------------------------------------------------
Correlation: 0.4247927665710449
Model A Feature:  16014


Model B Feature:  9470


--------------------------------------------------
Correlation: 0.41809219121932983
Model A Feature:  23025


Model B Feature:  7427


--------------------------------------------------
Correlation: 0.4593101739883423
Model A Feature:  21196


Model B Feature:  32061


--------------------------------------------------
Correlation: 0.6085872054100037
Model A Feature:  5764


Model B Feature:  13632


--------------------------------------------------
Correlation: 0.6840933561325073
Model A Feature:  10548


Model B Feature:  5531


--------------------------------------------------
Correlation: 0.8304700255393982
Model A Feature:  7002


Model B Feature:  28083


--------------------------------------------------
Correlation: 0.9462249875068665
Model A Feature:  10973


Model B Feature:  19967


--------------------------------------------------
Correlation: 0.7489441633224487
Model A Feature:  19848


Model B Feature:  32387


--------------------------------------------------
Correlation: 0.6915963888168335
Model A Feature:  18973


Model B Feature:  18120


--------------------------------------------------
Correlation: 0.8120508193969727
Model A Feature:  2802


Model B Feature:  30584


--------------------------------------------------
Correlation: 0.4436831474304199
Model A Feature:  30638


Model B Feature:  6072


--------------------------------------------------


# people (filter out double meanings)

In [None]:
new_keywords = [
    "man", "girl", "boy", "kid", "dad", "mom", "son", "sis", "bro",
    "chief", "priest", "king", "queen", "duke", "lord", "friend", "clerk", "coach",
    "nurse", "doc", "maid", "clown", "guest", "peer",
    "punk", "nerd", "jock", "chief"]

In [None]:
len(new_keywords)

28

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

193
67


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.4626865671641791
31


0.39420796391456875

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


31

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.4710842687171789

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.14571578517955402


0.007

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

21

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.4992118372596189

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.1779787944609878


0.018

## filter out low corr: 0.3

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

18

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5342001674393785

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.19393906051429885


0.023

## filter out low corr: 0.4

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

16

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5885025975218607

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.20655776063981537


0.015

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.5085517764091492
Model A Feature:  21818


Model B Feature:  17922


--------------------------------------------------
Correlation: 0.8952447772026062
Model A Feature:  7189


Model B Feature:  16006


--------------------------------------------------
Correlation: 0.7079083323478699
Model A Feature:  22947


Model B Feature:  29448


--------------------------------------------------
Correlation: 0.667935311794281
Model A Feature:  16127


Model B Feature:  5130


--------------------------------------------------
Correlation: 0.437524676322937
Model A Feature:  1541


Model B Feature:  26254


--------------------------------------------------
Correlation: 0.7191550731658936
Model A Feature:  2154


Model B Feature:  14863


--------------------------------------------------
Correlation: 0.6518343687057495
Model A Feature:  9620


Model B Feature:  8734


--------------------------------------------------
Correlation: 0.7343308925628662
Model A Feature:  10245


Model B Feature:  30752


--------------------------------------------------
Correlation: 0.867409884929657
Model A Feature:  17268


Model B Feature:  16672


--------------------------------------------------
Correlation: 0.7894371151924133
Model A Feature:  10581


Model B Feature:  29855


--------------------------------------------------
Correlation: 0.7545995116233826
Model A Feature:  15585


Model B Feature:  9129


--------------------------------------------------
Correlation: 0.40771543979644775
Model A Feature:  13991


Model B Feature:  19002


--------------------------------------------------
Correlation: 0.5599328279495239
Model A Feature:  29658


Model B Feature:  13116


--------------------------------------------------
Correlation: 0.8603036403656006
Model A Feature:  9027


Model B Feature:  14526


--------------------------------------------------
Correlation: 0.6108449101448059
Model A Feature:  20545


Model B Feature:  3016


--------------------------------------------------
Correlation: 0.5494723916053772
Model A Feature:  3677


Model B Feature:  29906


--------------------------------------------------


## which feats which keywords

In [None]:
top_toks_afterFilt_A = []
top_toks_afterFilt_B = []
for feat_ind_A, feat_ind_B in zip(original_A_indices, original_B_indices):
    top_toks_afterFilt_A.append( fList_model_A_seqs[feat_ind_A] )
    top_toks_afterFilt_B.append( fList_model_B_seqs[feat_ind_B] )

In [None]:
top_tok_list_A = []
for top_seq_list_feat in top_toks_afterFilt_A:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_A.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_A)

17

In [None]:
sorted_kw_counts = Counter(top_tok_list_A).most_common()
len(sorted_kw_counts) # num unique keywords

12

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  king | Count:  2
keyword:  priest | Count:  2
keyword:  dad | Count:  2
keyword:  boy | Count:  2
keyword:  coach | Count:  2
keyword:  son | Count:  1
keyword:  man | Count:  1
keyword:  girl | Count:  1
keyword:  guest | Count:  1
keyword:  friend | Count:  1
keyword:  peer | Count:  1
keyword:  mom | Count:  1


In [None]:
top_tok_list_B = []
for top_seq_list_feat in top_toks_afterFilt_B:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_B.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_B)

17

In [None]:
sorted_kw_counts = Counter(top_tok_list_B).most_common()
len(sorted_kw_counts) # num unique keywords

11

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  dad | Count:  3
keyword:  king | Count:  2
keyword:  priest | Count:  2
keyword:  mom | Count:  2
keyword:  coach | Count:  2
keyword:  son | Count:  1
keyword:  girl | Count:  1
keyword:  boy | Count:  1
keyword:  guest | Count:  1
keyword:  friend | Count:  1
keyword:  peer | Count:  1


# nature

In [None]:
new_keywords = [
    "tree", "grass", "stone", "rock", "cliff", "hill",
    "dirt", "sand", "mud", "wind", "storm", "rain", "cloud", "sun",
    "moon", "leaf", "branch", "twig", "root", "bark", "seed",
    "tide", "lake", "pond", "creek", "sea", "wood", "field",
    "shore", "snow", "ice", "flame", "fire", "fog", "dew", "hail",
    "sky", "earth", "glade", "cave", "peak", "ridge", "dust", "air",
    "mist", "heat"]

In [None]:
len(new_keywords)

46

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

199
108


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.42592592592592593
46


0.30606323670319935

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


46

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5004843507297421

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.11532482002826273


0.001

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

22

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.3759666881398435

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.1823636678544841


0.111

## filter out low corr: 0.3

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

18

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.32595509070667195

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.1939148028409779


0.184

## filter out low corr: 0.4

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

15

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7892478040221264

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.21889935583890183


0.0

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.7750056982040405
Model A Feature:  1906


Model B Feature:  19510


--------------------------------------------------
Correlation: 0.5942947864532471
Model A Feature:  2946


Model B Feature:  15936


--------------------------------------------------
Correlation: 0.8092045187950134
Model A Feature:  17935


Model B Feature:  587


--------------------------------------------------
Correlation: 0.8928154110908508
Model A Feature:  30783


Model B Feature:  79


--------------------------------------------------
Correlation: 0.5593801736831665
Model A Feature:  855


Model B Feature:  31825


--------------------------------------------------
Correlation: 0.8912466764450073
Model A Feature:  23732


Model B Feature:  21586


--------------------------------------------------
Correlation: 0.6480451822280884
Model A Feature:  25732


Model B Feature:  25713


--------------------------------------------------
Correlation: 0.7742326855659485
Model A Feature:  15598


Model B Feature:  21648


--------------------------------------------------
Correlation: 0.6245279312133789
Model A Feature:  21227


Model B Feature:  25237


--------------------------------------------------
Correlation: 0.5507810115814209
Model A Feature:  31525


Model B Feature:  10911


--------------------------------------------------
Correlation: 0.7045325040817261
Model A Feature:  12889


Model B Feature:  15027


--------------------------------------------------
Correlation: 0.4298824667930603
Model A Feature:  3612


Model B Feature:  4885


--------------------------------------------------
Correlation: 0.5039418935775757
Model A Feature:  27801


Model B Feature:  16666


--------------------------------------------------
Correlation: 0.8092571496963501
Model A Feature:  17172


Model B Feature:  28993


--------------------------------------------------
Correlation: 0.9164037108421326
Model A Feature:  2267


Model B Feature:  18812


--------------------------------------------------


## which feats which keywords

In [None]:
top_toks_afterFilt_A = []
top_toks_afterFilt_B = []
for feat_ind_A, feat_ind_B in zip(original_A_indices, original_B_indices):
    top_toks_afterFilt_A.append( fList_model_A_seqs[feat_ind_A] )
    top_toks_afterFilt_B.append( fList_model_B_seqs[feat_ind_B] )

In [None]:
top_tok_list_A = []
for top_seq_list_feat in top_toks_afterFilt_A:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_A.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_A)

16

In [None]:
sorted_kw_counts = Counter(top_tok_list_A).most_common()
len(sorted_kw_counts) # num unique keywords

12

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  field | Count:  3
keyword:  moon | Count:  3
keyword:  branch | Count:  1
keyword:  fire | Count:  1
keyword:  sand | Count:  1
keyword:  sun | Count:  1
keyword:  ice | Count:  1
keyword:  stone | Count:  1
keyword:  heat | Count:  1
keyword:  hill | Count:  1
keyword:  tree | Count:  1
keyword:  air | Count:  1


In [None]:
top_tok_list_B = []
for top_seq_list_feat in top_toks_afterFilt_B:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_B.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_B)

18

In [None]:
sorted_kw_counts = Counter(top_tok_list_B).most_common()
len(sorted_kw_counts) # num unique keywords

12

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  moon | Count:  4
keyword:  fire | Count:  3
keyword:  field | Count:  2
keyword:  branch | Count:  1
keyword:  ice | Count:  1
keyword:  sand | Count:  1
keyword:  earth | Count:  1
keyword:  sun | Count:  1
keyword:  stone | Count:  1
keyword:  heat | Count:  1
keyword:  tree | Count:  1
keyword:  air | Count:  1


# emotions

In [None]:
new_keywords = [
    "joy", "glee", "pride", "grief", "fear", "hope", "love", "hate", "pain", "shame",
    "bliss", "rage", "calm", "shock", "dread", "guilt", "peace", "trust", "scorn", "doubt",
    "hurt", "wrath", "laugh", "cry", "smile", "frown", "gasp", "blush", "sigh", "grin",
    "woe", "spite", "envy", "glow", "thrill", "mirth", "bored", "cheer", "charm", "grace",
    "shy", "brave", "proud", "glad", "mad", "sad", "tense", "free", "kind"
]

In [None]:
len(new_keywords)

49

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

103
58


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.41379310344827586
24


0.3768650621561141

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


24

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.828463823074054

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.1689040244492055


0.0

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

14

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.9196192159010783

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.22858430147933473


0.0

## filter out low corr: 0.3

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

12

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7929500630713265

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.2618477119143779


0.003

## filter out low corr: 0.4

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

11

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.8418264444155913

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.25982585464713015


0.002

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.8956789970397949
Model A Feature:  16486


Model B Feature:  20096


--------------------------------------------------
Correlation: 0.4450533986091614
Model A Feature:  28794


Model B Feature:  19972


--------------------------------------------------
Correlation: 0.6877174973487854
Model A Feature:  31791


Model B Feature:  5136


--------------------------------------------------
Correlation: 0.5381913185119629
Model A Feature:  13505


Model B Feature:  675


--------------------------------------------------
Correlation: 0.9568026661872864
Model A Feature:  16948


Model B Feature:  5160


--------------------------------------------------
Correlation: 0.6267017722129822
Model A Feature:  8600


Model B Feature:  18220


--------------------------------------------------
Correlation: 0.8715076446533203
Model A Feature:  12761


Model B Feature:  6456


--------------------------------------------------
Correlation: 0.45579788088798523
Model A Feature:  24381


Model B Feature:  29631


--------------------------------------------------
Correlation: 0.7396957278251648
Model A Feature:  8552


Model B Feature:  30922


--------------------------------------------------
Correlation: 0.6099590063095093
Model A Feature:  19593


Model B Feature:  1488


--------------------------------------------------
Correlation: 0.7726646661758423
Model A Feature:  31570


Model B Feature:  28918


--------------------------------------------------


## which feats which keywords

In [None]:
top_toks_afterFilt_A = []
top_toks_afterFilt_B = []
for feat_ind_A, feat_ind_B in zip(original_A_indices, original_B_indices):
    top_toks_afterFilt_A.append( fList_model_A_seqs[feat_ind_A] )
    top_toks_afterFilt_B.append( fList_model_B_seqs[feat_ind_B] )

In [None]:
top_tok_list_A = []
for top_seq_list_feat in top_toks_afterFilt_A:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_A.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_A)

12

In [None]:
sorted_kw_counts = Counter(top_tok_list_A).most_common()
len(sorted_kw_counts) # num unique keywords

9

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  hate | Count:  2
keyword:  calm | Count:  2
keyword:  pain | Count:  2
keyword:  kind | Count:  1
keyword:  free | Count:  1
keyword:  love | Count:  1
keyword:  peace | Count:  1
keyword:  smile | Count:  1
keyword:  joy | Count:  1


In [None]:
top_tok_list_B = []
for top_seq_list_feat in top_toks_afterFilt_B:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_B.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_B)

12

In [None]:
sorted_kw_counts = Counter(top_tok_list_B).most_common()
len(sorted_kw_counts) # num unique keywords

9

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  hate | Count:  2
keyword:  calm | Count:  2
keyword:  pain | Count:  2
keyword:  kind | Count:  1
keyword:  free | Count:  1
keyword:  love | Count:  1
keyword:  peace | Count:  1
keyword:  smile | Count:  1
keyword:  joy | Count:  1


# time v1

In [None]:
new_keywords = [
    "day", "night", "week", "month", "year", "hour", "minute", "second", "now", "soon",
    "later", "early", "late", "morning", "evening", "noon", "midnight", "dawn", "dusk", "past",
    "present", "future", "before", "after", "yesterday", "today", "tomorrow", "next", "previous", "soon",
    "instant", "era", "age", "decade", "century", "millennium",
    "moment", "pause", "wait", "begin", "start", "end", "finish", "stop", "continue",
    "forever", "constant", "frequent",
    "occasion", "season", "spring", "summer", "autumn", "fall", "winter", "anniversary", "deadline", "schedule",
    "calendar", "clock", "duration", "interval", "epoch", "generation", "period", "cycle", "timespan",
    "shift", "quarter", "term", "phase", "lifetime", "century", "minute", "timeline", "delay",
    "prompt", "timely", "recurrent", "daily", "weekly", "monthly", "yearly", "annual", "biweekly", "timeframe"
]

In [None]:
len(new_keywords)

86

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

1170
1144


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.1993006993006993
228


0.4734032724534611

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


228

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5895674735517556

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.051139811373290635


0.0

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

130

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7530502219587717

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.07068091649086641


0.0

## filter out low corr: 0.3

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

103

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7576388455843754

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.08107783157857519


0.0

## filter out low corr: 0.4

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

78

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.8075217758038356

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.09237787870316988


0.0

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.723274290561676
Model A Feature:  5757


Model B Feature:  30722


--------------------------------------------------
Correlation: 0.4727863073348999
Model A Feature:  32471


Model B Feature:  16386


--------------------------------------------------
Correlation: 0.48114705085754395
Model A Feature:  27324


Model B Feature:  2051


--------------------------------------------------
Correlation: 0.437768816947937
Model A Feature:  3715


Model B Feature:  17


--------------------------------------------------
Correlation: 0.8092077374458313
Model A Feature:  15325


Model B Feature:  28


--------------------------------------------------
Correlation: 0.8611755967140198
Model A Feature:  9657


Model B Feature:  28700


--------------------------------------------------
Correlation: 0.4814339876174927
Model A Feature:  26193


Model B Feature:  6177


--------------------------------------------------
Correlation: 0.5083337426185608
Model A Feature:  28819


Model B Feature:  18472


--------------------------------------------------
Correlation: 0.5857028961181641
Model A Feature:  10502


Model B Feature:  8234


--------------------------------------------------
Correlation: 0.4658796489238739
Model A Feature:  2400


Model B Feature:  8268


--------------------------------------------------
Correlation: 0.49525851011276245
Model A Feature:  32067


Model B Feature:  28752


--------------------------------------------------
Correlation: 0.6927803754806519
Model A Feature:  8539


Model B Feature:  14456


--------------------------------------------------
Correlation: 0.4471716582775116
Model A Feature:  32139


Model B Feature:  18559


--------------------------------------------------
Correlation: 0.7301828861236572
Model A Feature:  13803


Model B Feature:  130


--------------------------------------------------
Correlation: 0.46366921067237854
Model A Feature:  12971


Model B Feature:  28811


--------------------------------------------------
Correlation: 0.7555509805679321
Model A Feature:  30956


Model B Feature:  10398


--------------------------------------------------
Correlation: 0.8872765302658081
Model A Feature:  633


Model B Feature:  8350


--------------------------------------------------
Correlation: 0.5471289157867432
Model A Feature:  28732


Model B Feature:  8354


--------------------------------------------------
Correlation: 0.6940833330154419
Model A Feature:  11538


Model B Feature:  28843


--------------------------------------------------
Correlation: 0.9309592843055725
Model A Feature:  24703


Model B Feature:  16556


--------------------------------------------------
Correlation: 0.4071653187274933
Model A Feature:  23528


Model B Feature:  175


--------------------------------------------------
Correlation: 0.551183819770813
Model A Feature:  7176


Model B Feature:  22715


--------------------------------------------------
Correlation: 0.7253245115280151
Model A Feature:  5511


Model B Feature:  4288


--------------------------------------------------
Correlation: 0.5258380174636841
Model A Feature:  15927


Model B Feature:  28867


--------------------------------------------------
Correlation: 0.8712546825408936
Model A Feature:  29069


Model B Feature:  223


--------------------------------------------------
Correlation: 0.7943718433380127
Model A Feature:  19109


Model B Feature:  28943


--------------------------------------------------
Correlation: 0.4307520091533661
Model A Feature:  4309


Model B Feature:  18766


--------------------------------------------------
Correlation: 0.5393396615982056
Model A Feature:  15168


Model B Feature:  335


--------------------------------------------------
Correlation: 0.5702798366546631
Model A Feature:  4464


Model B Feature:  12699


--------------------------------------------------
Correlation: 0.9350628852844238
Model A Feature:  1458


Model B Feature:  2481


--------------------------------------------------
Correlation: 0.6533108949661255
Model A Feature:  5400


Model B Feature:  29120


--------------------------------------------------
Correlation: 0.729668378829956
Model A Feature:  541


Model B Feature:  22980


--------------------------------------------------
Correlation: 0.4513002932071686
Model A Feature:  31487


Model B Feature:  4555


--------------------------------------------------
Correlation: 0.4617210626602173
Model A Feature:  6221


Model B Feature:  2520


--------------------------------------------------
Correlation: 0.898678719997406
Model A Feature:  20144


Model B Feature:  8690


--------------------------------------------------
Correlation: 0.532450258731842
Model A Feature:  3350


Model B Feature:  27141


--------------------------------------------------
Correlation: 0.6773052215576172
Model A Feature:  14439


Model B Feature:  6683


--------------------------------------------------
Correlation: 0.7813739776611328
Model A Feature:  3231


Model B Feature:  31260


--------------------------------------------------
Correlation: 0.7522636651992798
Model A Feature:  10880


Model B Feature:  31299


--------------------------------------------------
Correlation: 0.678644061088562
Model A Feature:  30103


Model B Feature:  16991


--------------------------------------------------
Correlation: 0.4091850221157074
Model A Feature:  32522


Model B Feature:  10853


--------------------------------------------------
Correlation: 0.8153980374336243
Model A Feature:  9389


Model B Feature:  21093


--------------------------------------------------
Correlation: 0.7235473990440369
Model A Feature:  18211


Model B Feature:  6809


--------------------------------------------------
Correlation: 0.677230954170227
Model A Feature:  30874


Model B Feature:  17054


--------------------------------------------------
Correlation: 0.7093968987464905
Model A Feature:  16511


Model B Feature:  23206


--------------------------------------------------
Correlation: 0.9471890330314636
Model A Feature:  27377


Model B Feature:  19127


--------------------------------------------------
Correlation: 0.4278598129749298
Model A Feature:  16611


Model B Feature:  25272


--------------------------------------------------
Correlation: 0.5853958129882812
Model A Feature:  8446


Model B Feature:  29433


--------------------------------------------------
Correlation: 0.9835914373397827
Model A Feature:  6674


Model B Feature:  6928


--------------------------------------------------
Correlation: 0.8954758048057556
Model A Feature:  8560


Model B Feature:  13083


--------------------------------------------------
Correlation: 0.7918717265129089
Model A Feature:  30230


Model B Feature:  31572


--------------------------------------------------
Correlation: 0.592656135559082
Model A Feature:  21145


Model B Feature:  9042


--------------------------------------------------
Correlation: 0.46408432722091675
Model A Feature:  4576


Model B Feature:  29597


--------------------------------------------------
Correlation: 0.5748183131217957
Model A Feature:  13259


Model B Feature:  15303


--------------------------------------------------
Correlation: 0.6871882677078247
Model A Feature:  13782


Model B Feature:  17379


--------------------------------------------------
Correlation: 0.4086652994155884
Model A Feature:  3664


Model B Feature:  19444


--------------------------------------------------
Correlation: 0.635159969329834
Model A Feature:  31352


Model B Feature:  15397


--------------------------------------------------
Correlation: 0.4812500476837158
Model A Feature:  2455


Model B Feature:  25645


--------------------------------------------------
Correlation: 0.9176865816116333
Model A Feature:  12913


Model B Feature:  25664


--------------------------------------------------
Correlation: 0.8095675110816956
Model A Feature:  25308


Model B Feature:  31809


--------------------------------------------------
Correlation: 0.9071242809295654
Model A Feature:  20981


Model B Feature:  13396


--------------------------------------------------
Correlation: 0.8221204280853271
Model A Feature:  7972


Model B Feature:  7270


--------------------------------------------------
Correlation: 0.8672628402709961
Model A Feature:  26609


Model B Feature:  29809


--------------------------------------------------
Correlation: 0.44830188155174255
Model A Feature:  23633


Model B Feature:  21697


--------------------------------------------------
Correlation: 0.521644651889801
Model A Feature:  28011


Model B Feature:  32026


--------------------------------------------------
Correlation: 0.7326370477676392
Model A Feature:  26972


Model B Feature:  25901


--------------------------------------------------
Correlation: 0.6623398065567017
Model A Feature:  12834


Model B Feature:  13764


--------------------------------------------------
Correlation: 0.6369298100471497
Model A Feature:  23531


Model B Feature:  19907


--------------------------------------------------
Correlation: 0.6535170078277588
Model A Feature:  23881


Model B Feature:  7686


--------------------------------------------------
Correlation: 0.7039303183555603
Model A Feature:  11809


Model B Feature:  22052


--------------------------------------------------
Correlation: 0.4377000033855438
Model A Feature:  4536


Model B Feature:  26173


--------------------------------------------------
Correlation: 0.9393486976623535
Model A Feature:  14410


Model B Feature:  32433


--------------------------------------------------
Correlation: 0.6025209426879883
Model A Feature:  16563


Model B Feature:  5861


--------------------------------------------------
Correlation: 0.4669067859649658
Model A Feature:  9120


Model B Feature:  3826


--------------------------------------------------
Correlation: 0.5001254677772522
Model A Feature:  27314


Model B Feature:  5917


--------------------------------------------------
Correlation: 0.7295408844947815
Model A Feature:  16415


Model B Feature:  3905


--------------------------------------------------
Correlation: 0.464631587266922
Model A Feature:  10982


Model B Feature:  6081


--------------------------------------------------
Correlation: 0.7044895887374878
Model A Feature:  26485


Model B Feature:  16330


--------------------------------------------------


## which feats which keywords

In [None]:
top_toks_afterFilt_A = []
top_toks_afterFilt_B = []
for feat_ind_A, feat_ind_B in zip(original_A_indices, original_B_indices):
    top_toks_afterFilt_A.append( fList_model_A_seqs[feat_ind_A] )
    top_toks_afterFilt_B.append( fList_model_B_seqs[feat_ind_B] )

In [None]:
top_tok_list_A = []
for top_seq_list_feat in top_toks_afterFilt_A:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_A.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_A)

92

In [None]:
sorted_kw_counts = Counter(top_tok_list_A).most_common()
len(sorted_kw_counts) # num unique keywords

42

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  after | Count:  12
keyword:  today | Count:  7
keyword:  week | Count:  6
keyword:  winter | Count:  5
keyword:  month | Count:  4
keyword:  before | Count:  4
keyword:  year | Count:  4
keyword:  term | Count:  3
keyword:  later | Count:  3
keyword:  moment | Count:  3
keyword:  summer | Count:  3
keyword:  morning | Count:  2
keyword:  dawn | Count:  2
keyword:  now | Count:  2
keyword:  end | Count:  2
keyword:  second | Count:  2
keyword:  day | Count:  2
keyword:  spring | Count:  2
keyword:  hour | Count:  1
keyword:  yesterday | Count:  1
keyword:  annual | Count:  1
keyword:  monthly | Count:  1
keyword:  soon | Count:  1
keyword:  start | Count:  1
keyword:  fall | Count:  1
keyword:  pause | Count:  1
keyword:  frequent | Count:  1
keyword:  season | Count:  1
keyword:  tomorrow | Count:  1
keyword:  past | Count:  1
keyword:  future | Count:  1
keyword:  next | Count:  1
keyword:  schedule | Count:  1
keyword:  continue | Count:  1
keyword:  begin | Count:  1
keywo

In [None]:
top_tok_list_B = []
for top_seq_list_feat in top_toks_afterFilt_B:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_B.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_B)

95

In [None]:
sorted_kw_counts = Counter(top_tok_list_B).most_common()
len(sorted_kw_counts) # num unique keywords

44

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  after | Count:  10
keyword:  today | Count:  10
keyword:  winter | Count:  10
keyword:  week | Count:  4
keyword:  yesterday | Count:  3
keyword:  night | Count:  3
keyword:  term | Count:  3
keyword:  year | Count:  3
keyword:  day | Count:  3
keyword:  month | Count:  2
keyword:  now | Count:  2
keyword:  end | Count:  2
keyword:  second | Count:  2
keyword:  next | Count:  2
keyword:  moment | Count:  2
keyword:  schedule | Count:  2
keyword:  before | Count:  2
keyword:  quarter | Count:  2
keyword:  summer | Count:  2
keyword:  spring | Count:  2
keyword:  hour | Count:  1
keyword:  noon | Count:  1
keyword:  annual | Count:  1
keyword:  monthly | Count:  1
keyword:  weekly | Count:  1
keyword:  soon | Count:  1
keyword:  start | Count:  1
keyword:  fall | Count:  1
keyword:  dawn | Count:  1
keyword:  pause | Count:  1
keyword:  frequent | Count:  1
keyword:  season | Count:  1
keyword:  tomorrow | Count:  1
keyword:  past | Count:  1
keyword:  continue | Count:  1
keyw

# time v2- rmv after

In [None]:
new_keywords = [
    "day", "night", "week", "month", "year", "hour", "minute", "second", "now", "soon",
    "later", "early", "late", "morning", "evening", "noon", "midnight", "dawn", "dusk", "past",
    "present", "future", "before", "yesterday", "today", "tomorrow", "next", "previous", "soon",
    "instant", "era", "age", "decade", "century", "millennium",
    "moment", "pause", "wait", "begin", "start", "end", "finish", "stop", "continue",
    "forever", "constant", "frequent",
    "occasion", "season", "spring", "summer", "autumn", "fall", "winter", "anniversary", "deadline", "schedule",
    "calendar", "clock", "duration", "interval", "epoch", "generation", "period", "cycle", "timespan",
    "shift", "quarter", "term", "phase", "lifetime", "century", "minute", "timeline", "delay",
    "prompt", "timely", "recurrent", "daily", "weekly", "monthly", "yearly", "annual", "biweekly", "timeframe"
]

In [None]:
len(new_keywords)

85

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

989
1047


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.19579751671442217
205


0.42298319067987467

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


205

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5775936047217759

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.056327765391360224


0.0

# calendar

In [None]:
new_keywords = [
    "day", "night", "week", "month", "year", "hour", "minute", "second",
    "morning", "evening", "noon", "midnight", "dawn", "dusk",
    "yesterday", "today", "tomorrow",
    "decade", "century", "millennium",
    "season", "spring", "summer", "autumn", "fall", "winter",
    "calendar", "clock",
    "century", "minute",
    "daily", "weekly", "monthly", "yearly", "annual", "biweekly", "timeframe"
]

In [None]:
len(new_keywords)

37

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

586
776


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.16237113402061856
126


0.42687703386313985

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


126

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.649097380971248

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.07117647816549452


0.0

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

65

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.8068892095177752

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.09640523623676268


0.0

## filter out low corr: 0.3

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

49

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7600860575602976

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.11582217683274552


0.0

## filter out low corr: 0.4

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

34

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.5858928901963751

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.14083280029830683


0.001

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.7232800126075745
Model A Feature:  5757


Model B Feature:  30722


--------------------------------------------------
Correlation: 0.42780208587646484
Model A Feature:  10502


Model B Feature:  17


--------------------------------------------------
Correlation: 0.8092209100723267
Model A Feature:  15325


Model B Feature:  28


--------------------------------------------------
Correlation: 0.8611764311790466
Model A Feature:  9657


Model B Feature:  28700


--------------------------------------------------
Correlation: 0.4814237654209137
Model A Feature:  26193


Model B Feature:  6177


--------------------------------------------------
Correlation: 0.5083207488059998
Model A Feature:  28819


Model B Feature:  18472


--------------------------------------------------
Correlation: 0.4658823609352112
Model A Feature:  2400


Model B Feature:  8268


--------------------------------------------------
Correlation: 0.6927764415740967
Model A Feature:  8539


Model B Feature:  14456


--------------------------------------------------
Correlation: 0.4157389998435974
Model A Feature:  23241


Model B Feature:  8335


--------------------------------------------------
Correlation: 0.5471159815788269
Model A Feature:  28732


Model B Feature:  8354


--------------------------------------------------
Correlation: 0.6940757632255554
Model A Feature:  11538


Model B Feature:  28843


--------------------------------------------------
Correlation: 0.9309540390968323
Model A Feature:  24703


Model B Feature:  16556


--------------------------------------------------
Correlation: 0.40715301036834717
Model A Feature:  23528


Model B Feature:  175


--------------------------------------------------
Correlation: 0.6782609224319458
Model A Feature:  13803


Model B Feature:  24770


--------------------------------------------------
Correlation: 0.8712375164031982
Model A Feature:  29069


Model B Feature:  223


--------------------------------------------------
Correlation: 0.5393164157867432
Model A Feature:  15168


Model B Feature:  335


--------------------------------------------------
Correlation: 0.9350656867027283
Model A Feature:  1458


Model B Feature:  2481


--------------------------------------------------
Correlation: 0.7296679615974426
Model A Feature:  541


Model B Feature:  22980


--------------------------------------------------
Correlation: 0.4513265788555145
Model A Feature:  31487


Model B Feature:  4555


--------------------------------------------------
Correlation: 0.6786489486694336
Model A Feature:  30103


Model B Feature:  16991


--------------------------------------------------
Correlation: 0.7093779444694519
Model A Feature:  16511


Model B Feature:  23206


--------------------------------------------------
Correlation: 0.9469823241233826
Model A Feature:  10496


Model B Feature:  19127


--------------------------------------------------
Correlation: 0.5854278206825256
Model A Feature:  8446


Model B Feature:  29433


--------------------------------------------------
Correlation: 0.4640791118144989
Model A Feature:  4576


Model B Feature:  29597


--------------------------------------------------
Correlation: 0.5748079419136047
Model A Feature:  13259


Model B Feature:  15303


--------------------------------------------------
Correlation: 0.48125159740448
Model A Feature:  2455


Model B Feature:  25645


--------------------------------------------------
Correlation: 0.9177352786064148
Model A Feature:  12913


Model B Feature:  25664


--------------------------------------------------
Correlation: 0.8221150040626526
Model A Feature:  7972


Model B Feature:  7270


--------------------------------------------------
Correlation: 0.653498649597168
Model A Feature:  23881


Model B Feature:  7686


--------------------------------------------------
Correlation: 0.43771275877952576
Model A Feature:  4536


Model B Feature:  26173


--------------------------------------------------
Correlation: 0.4669012129306793
Model A Feature:  9120


Model B Feature:  3826


--------------------------------------------------
Correlation: 0.5001319050788879
Model A Feature:  27314


Model B Feature:  5917


--------------------------------------------------
Correlation: 0.7295692563056946
Model A Feature:  16415


Model B Feature:  3905


--------------------------------------------------
Correlation: 0.46462079882621765
Model A Feature:  10982


Model B Feature:  6081


--------------------------------------------------


## which feats which keywords

In [None]:
top_toks_afterFilt_A = []
top_toks_afterFilt_B = []
for feat_ind_A, feat_ind_B in zip(original_A_indices, original_B_indices):
    top_toks_afterFilt_A.append( fList_model_A_seqs[feat_ind_A] )
    top_toks_afterFilt_B.append( fList_model_B_seqs[feat_ind_B] )

In [None]:
top_tok_list_A = []
for top_seq_list_feat in top_toks_afterFilt_A:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_A.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_A)

44

In [None]:
sorted_kw_counts = Counter(top_tok_list_A).most_common()
len(sorted_kw_counts) # num unique keywords

18

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  today | Count:  9
keyword:  week | Count:  5
keyword:  month | Count:  4
keyword:  winter | Count:  4
keyword:  year | Count:  4
keyword:  summer | Count:  3
keyword:  dawn | Count:  2
keyword:  day | Count:  2
keyword:  spring | Count:  2
keyword:  hour | Count:  1
keyword:  yesterday | Count:  1
keyword:  morning | Count:  1
keyword:  annual | Count:  1
keyword:  monthly | Count:  1
keyword:  fall | Count:  1
keyword:  second | Count:  1
keyword:  season | Count:  1
keyword:  tomorrow | Count:  1


In [None]:
top_tok_list_B = []
for top_seq_list_feat in top_toks_afterFilt_B:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_B.append(top_tok)
            keywords_in_feat.add(top_tok)

In [None]:
len(top_tok_list_B)

47

In [None]:
sorted_kw_counts = Counter(top_tok_list_B).most_common()
len(sorted_kw_counts) # num unique keywords

20

In [None]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

keyword:  today | Count:  9
keyword:  winter | Count:  6
keyword:  week | Count:  4
keyword:  yesterday | Count:  3
keyword:  night | Count:  3
keyword:  year | Count:  3
keyword:  day | Count:  3
keyword:  month | Count:  2
keyword:  summer | Count:  2
keyword:  spring | Count:  2
keyword:  hour | Count:  1
keyword:  noon | Count:  1
keyword:  annual | Count:  1
keyword:  monthly | Count:  1
keyword:  weekly | Count:  1
keyword:  fall | Count:  1
keyword:  dawn | Count:  1
keyword:  second | Count:  1
keyword:  season | Count:  1
keyword:  tomorrow | Count:  1


# month names

In [53]:
new_keywords = [
    "January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"
]

In [54]:
len(new_keywords)

12

In [55]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [56]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

137
118


## run 1-1

In [57]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.2711864406779661
32


0.32932982871593053

In [58]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


32

In [59]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7190808117142935

In [60]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.14883889364369635


0.0

In [87]:
paired_rsa = representational_similarity_analysis(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_rsa

0.7610013175230566

In [88]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          representational_similarity_analysis, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_rsa)

0.009332542819499341


0.0

## filter out low corr

In [61]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [62]:
len(rmvLow_corr_inds_A)

22

In [63]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7828235112103981

In [64]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.18023171507310531


0.0

## filter out low corr: 0.3

In [65]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [66]:
len(rmvLow_corr_inds_A)

12

In [67]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7519752343035806

In [68]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.24326031643717308


0.008

## filter out low corr: 0.4

In [69]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [70]:
len(rmvLow_corr_inds_A)

10

In [71]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.903605942724504

In [72]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.2714249379988349


0.003

## interpret

In [76]:
samp_m = 5

In [77]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.4818212389945984
Model A Feature:  21658


Model B Feature:  19975


--------------------------------------------------
Correlation: 0.4735376238822937
Model A Feature:  26209


Model B Feature:  4626


--------------------------------------------------
Correlation: 0.8121873736381531
Model A Feature:  18166


Model B Feature:  30931


--------------------------------------------------
Correlation: 0.5839450359344482
Model A Feature:  22164


Model B Feature:  30507


--------------------------------------------------
Correlation: 0.9022921919822693
Model A Feature:  28889


Model B Feature:  7480


--------------------------------------------------
Correlation: 0.682283341884613
Model A Feature:  24913


Model B Feature:  21827


--------------------------------------------------
Correlation: 0.5116301774978638
Model A Feature:  27278


Model B Feature:  23884


--------------------------------------------------
Correlation: 0.5058205723762512
Model A Feature:  9639


Model B Feature:  23889


--------------------------------------------------
Correlation: 0.6141257882118225
Model A Feature:  29710


Model B Feature:  5007


--------------------------------------------------
Correlation: 0.5249927043914795
Model A Feature:  9477


Model B Feature:  2541


--------------------------------------------------


## which feats which keywords

In [78]:
top_toks_afterFilt_A = []
top_toks_afterFilt_B = []
for feat_ind_A, feat_ind_B in zip(original_A_indices, original_B_indices):
    top_toks_afterFilt_A.append( fList_model_A_seqs[feat_ind_A] )
    top_toks_afterFilt_B.append( fList_model_B_seqs[feat_ind_B] )

In [79]:
top_tok_list_A = []
for top_seq_list_feat in top_toks_afterFilt_A:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_A.append(top_tok)
            keywords_in_feat.add(top_tok)

In [80]:
len(top_tok_list_A)

0

In [81]:
sorted_kw_counts = Counter(top_tok_list_A).most_common()
len(sorted_kw_counts) # num unique keywords

0

In [82]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

In [83]:
top_tok_list_B = []
for top_seq_list_feat in top_toks_afterFilt_B:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_B.append(top_tok)
            keywords_in_feat.add(top_tok)

In [84]:
len(top_tok_list_B)

0

In [85]:
sorted_kw_counts = Counter(top_tok_list_B).most_common()
len(sorted_kw_counts) # num unique keywords

0

In [86]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

# countries

In [89]:
keywords = [
    "USA", "Canada", "Brazil", "Mexico", "Germany", "France", "Italy", "Spain", "UK", "Australia",
    "China", "Japan", "India", "Russia", "Korea", "Argentina", "Egypt", "Iran", "Turkey"
]


In [90]:
len(new_keywords)

12

In [91]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [92]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

137
118


## run 1-1

In [93]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.2711864406779661
32


0.32932982871593053

In [94]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


32

In [95]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7190808117142935

In [96]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.14604309513131855


0.0

In [97]:
paired_rsa = representational_similarity_analysis(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_rsa

0.10204845888625988

In [98]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          representational_similarity_analysis, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_rsa)

-0.0009296544663877683


0.026

## filter out low corr

In [99]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [100]:
len(rmvLow_corr_inds_A)

22

In [101]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7828235112103981

In [102]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.17297183271575978


0.0

## filter out low corr: 0.3

In [103]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [104]:
len(rmvLow_corr_inds_A)

12

In [105]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.7519752343035806

In [106]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.24873512294915026


0.011

## filter out low corr: 0.4

In [107]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [108]:
len(rmvLow_corr_inds_A)

10

In [109]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.903605942724504

In [110]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.2665228797517477


0.0

## interpret

In [111]:
samp_m = 5

In [112]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.4818212389945984
Model A Feature:  21658


Model B Feature:  19975


--------------------------------------------------
Correlation: 0.4735376238822937
Model A Feature:  26209


Model B Feature:  4626


--------------------------------------------------
Correlation: 0.8121873736381531
Model A Feature:  18166


Model B Feature:  30931


--------------------------------------------------
Correlation: 0.5839450359344482
Model A Feature:  22164


Model B Feature:  30507


--------------------------------------------------
Correlation: 0.9022921919822693
Model A Feature:  28889


Model B Feature:  7480


--------------------------------------------------
Correlation: 0.682283341884613
Model A Feature:  24913


Model B Feature:  21827


--------------------------------------------------
Correlation: 0.5116301774978638
Model A Feature:  27278


Model B Feature:  23884


--------------------------------------------------
Correlation: 0.5058205723762512
Model A Feature:  9639


Model B Feature:  23889


--------------------------------------------------
Correlation: 0.6141257882118225
Model A Feature:  29710


Model B Feature:  5007


--------------------------------------------------
Correlation: 0.5249927043914795
Model A Feature:  9477


Model B Feature:  2541


--------------------------------------------------


## which feats which keywords

In [113]:
top_toks_afterFilt_A = []
top_toks_afterFilt_B = []
for feat_ind_A, feat_ind_B in zip(original_A_indices, original_B_indices):
    top_toks_afterFilt_A.append( fList_model_A_seqs[feat_ind_A] )
    top_toks_afterFilt_B.append( fList_model_B_seqs[feat_ind_B] )

In [114]:
top_tok_list_A = []
for top_seq_list_feat in top_toks_afterFilt_A:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_A.append(top_tok)
            keywords_in_feat.add(top_tok)

In [115]:
len(top_tok_list_A)

0

In [116]:
sorted_kw_counts = Counter(top_tok_list_A).most_common()
len(sorted_kw_counts) # num unique keywords

0

In [117]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

In [118]:
top_tok_list_B = []
for top_seq_list_feat in top_toks_afterFilt_B:
    # each feature has top 5 samps in top_seq_list_feat
    """
    if a feature contains “king” 3 times, we should only record that it contains “king”.
    this is bc we’re trying to find feature sim based on which keywords they activate on,
    but it doesn’t matter how many times that keyword appears in its top 5.
    """
    keywords_in_feat = set()
    for top_seq in top_seq_list_feat:
        # top_seq is one seq and its top token
        top_tok = top_seq[1].replace(' ', '').lower()
        if top_tok in new_keywords and top_tok not in keywords_in_feat:
            top_tok_list_B.append(top_tok)
            keywords_in_feat.add(top_tok)

In [119]:
len(top_tok_list_B)

0

In [120]:
sorted_kw_counts = Counter(top_tok_list_B).most_common()
len(sorted_kw_counts) # num unique keywords

0

In [121]:
for rankID in range(len(sorted_kw_counts)):
    feat_ID = sorted_kw_counts[rankID][0]
    print("keyword: ", feat_ID, "| Count: ", sorted_kw_counts[rankID][1])

# random keywords

In [None]:
new_keywords = [
    "apple", "bicycle", "cloud", "dog", "elephant", "fountain", "guitar", "honey",
    "iceberg", "jelly", "kangaroo", "laptop", "mountain", "notebook", "ocean",
    "piano", "quartz", "river", "satellite", "tiger", "umbrella", "volcano",
    "whale", "xylophone", "yogurt", "zebra", "balloon", "candle", "desert",
    "engine", "forest", "glove", "hat", "insect", "jungle", "key", "lamp",
    "microscope", "nest", "octopus", "penguin", "quill", "robot", "sandwich",
    "tree", "unicorn", "vase", "window", "yarn", "zipper"
]

In [None]:
len(new_keywords)

50

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

66
51


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.27450980392156865
14


0.3523550065709095

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 66])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 51])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


14

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.6114723277327786

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.21910837397910124


0.01

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

10

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.39894195656995207

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.2702163077336621


0.237

## filter out low corr: 0.3

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.3:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

7

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.4787777650638388

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.3320168294884029


0.252

## filter out low corr: 0.4

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.4:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

6

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.3458088985831677

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.4126546909202855


0.647

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.5891282558441162
Model A Feature:  1234


Model B Feature:  25986


--------------------------------------------------
Correlation: 0.7508215308189392
Model A Feature:  29817


Model B Feature:  27292


--------------------------------------------------
Correlation: 0.789036750793457
Model A Feature:  16342


Model B Feature:  21922


--------------------------------------------------
Correlation: 0.8894489407539368
Model A Feature:  18351


Model B Feature:  8889


--------------------------------------------------
Correlation: 0.4566425085067749
Model A Feature:  13967


Model B Feature:  7105


--------------------------------------------------
Correlation: 0.481752872467041
Model A Feature:  24857


Model B Feature:  30668


--------------------------------------------------


# rand sel features

In [None]:
num_feats = 100
mixed_modA_feats = np.random.choice(range(weight_matrix_np.shape[0]), size=num_feats, replace=False).tolist()
mixed_modB_feats = np.random.choice(range(weight_matrix_2.shape[0]), size=num_feats, replace=False).tolist()

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

100
100


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.21
21


0.3272296492755413

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 100])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 100])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


21

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.003018856567061415

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.18789391569800096


0.99

## interpret

In [None]:
num_feats = 5
for feature_idx_A, feature_idx_B in zip(original_A_indices[:num_feats],
                                                      original_B_indices[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.028716091066598892
Model A Feature:  19532


Model B Feature:  5969


--------------------------------------------------
Correlation: 0.028716091066598892
Model A Feature:  17014


Model B Feature:  31248


--------------------------------------------------
Correlation: 0.028716091066598892
Model A Feature:  8016


Model B Feature:  28269


--------------------------------------------------
Correlation: 0.028716091066598892
Model A Feature:  22006


Model B Feature:  29280


--------------------------------------------------
Correlation: 0.028716091066598892
Model A Feature:  11814


Model B Feature:  29101


--------------------------------------------------


# unrelated

In [None]:
new_keywords = [
    "apple", "bicycle", "cloud", "dog", "fountain", "guitar",
    "iceberg", "laptop", "mountain", "notebook",
    "quartz", "satellite", "umbrella",
    "xylophone", "yogurt", "balloon", "candle", "desert",
    "engine", "glove", "key", "lamp",
    "microscope", "nest", "quill", "robot", "sandwich",
    "unicorn", "vase", "window", "yarn", "zipper"
]

In [None]:
len(new_keywords)

32

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

49
34


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.2647058823529412
9


0.33283989898422184

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 49])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 34])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


9

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.034821249017866145

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.31352187610388993


0.97

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

5

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.9477397815324711

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.42704869315906735


0.026

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.5891275405883789
Model A Feature:  1234


Model B Feature:  25986


--------------------------------------------------
Correlation: 0.3224208652973175
Model A Feature:  1158


Model B Feature:  22794


--------------------------------------------------
Correlation: 0.7508212327957153
Model A Feature:  29817


Model B Feature:  27292


--------------------------------------------------
Correlation: 0.4566418528556824
Model A Feature:  13967


Model B Feature:  7105


--------------------------------------------------
Correlation: 0.4817536473274231
Model A Feature:  24857


Model B Feature:  30668


--------------------------------------------------


# unrelated v2

In [None]:
new_keywords = [
    "bicycle", "cloud", "dog", "fountain", "guitar",
    "iceberg", "laptop", "mountain", "notebook",
    "quartz", "satellite", "umbrella",
    "xylophone", "yogurt", "balloon", "candle", "desert",
    "engine", "glove", "key", "lamp",
    "microscope", "nest", "quill", "robot", "sandwich",
    "unicorn", "vase", "window", "yarn", "zipper"
]

In [None]:
len(new_keywords)

31

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

44
31


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.22580645161290322
7


0.3108127307026617

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


7

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.3760137001257752

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.3690036529705287


0.45

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

3

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.9871874734192366

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.6477896099220646


0.312

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.5891284942626953
Model A Feature:  1234


Model B Feature:  25986


--------------------------------------------------
Correlation: 0.32242056727409363
Model A Feature:  1158


Model B Feature:  22794


--------------------------------------------------
Correlation: 0.7508208155632019
Model A Feature:  29817


Model B Feature:  27292


--------------------------------------------------


# unrelated v3

In [None]:
new_keywords = [
    "bicycle", "dog", "fountain", "guitar",
    "iceberg", "laptop", "mountain", "notebook",
    "quartz", "satellite", "umbrella",
    "xylophone", "yogurt", "balloon", "candle", "desert",
    "engine", "glove", "key", "lamp",
    "microscope", "nest", "quill", "robot",
    "unicorn", "vase", "window", "yarn", "zipper"
]

In [None]:
len(new_keywords)

29

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

3
5


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.2
1


0.30409437492489816

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 3])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 5])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


1

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.24628084897994995


IndexError: list index out of range

# unrelated v4

In [None]:
new_keywords = [
    "apple", "one", "man", "ocean", "sad"
]

In [None]:
len(new_keywords)

5

In [None]:
mixed_modA_feats = set()
mixed_modB_feats = set()
for kw in new_keywords:
    modB_feats = find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw)
    modA_feats = find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw)
    mixed_modA_feats.update(modA_feats)
    mixed_modB_feats.update(modB_feats)

mixed_modA_feats = list(mixed_modA_feats)
mixed_modB_feats = list(mixed_modB_feats)

In [None]:
print(len(mixed_modA_feats))
print(len(mixed_modB_feats))

146
76


## run 1-1

In [None]:
subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],
                                               reshaped_activations_B[:, mixed_modB_feats])

num_unq_pairs = len(list(set(subset_inds)))
print("% unique: ", num_unq_pairs / len(subset_inds))
print(num_unq_pairs)
sum(subset_vals) / len(subset_vals)

% unique:  0.3684210526315789
28


0.47825957349452536

In [None]:
reshaped_activations_A[:, mixed_modA_feats].shape

torch.Size([30000, 146])

In [None]:
reshaped_activations_B[:, mixed_modB_feats].shape

torch.Size([30000, 76])

In [None]:
subset_sorted_feat_counts = Counter(subset_inds).most_common()
subset_kept_modA_feats = [feat_ID for feat_ID, count in subset_sorted_feat_counts if count == 1]

filt_corr_ind_A = []
filt_corr_ind_B = []
seen = set()
for ind_B, ind_A in enumerate(subset_inds):
    if ind_A in subset_kept_modA_feats:
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
    elif ind_A not in seen:  # only keep one if it's over count X
        seen.add(ind_A)
        filt_corr_ind_A.append(ind_A)
        filt_corr_ind_B.append(ind_B)
num_unq_pairs = len(list(set(filt_corr_ind_A)))
print("% unique: ", num_unq_pairs / len(filt_corr_ind_A))
num_unq_pairs

% unique:  1.0


28

In [None]:
# X_subset = weight_matrix_np[mixed_modA_feats]
# Y_subset = weight_matrix_2[mixed_modB_feats]

# paired_svcca = svcca(X_subset[filt_corr_ind_A], Y_subset[filt_corr_ind_B], "nd")
# paired_svcca

original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]
original_B_indices = [mixed_modB_feats[i] for i in filt_corr_ind_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.21546125301779254

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]
# weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(100, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.15143268664138437


0.26

## filter out low corr

In [None]:
rmvLow_corr_inds_A = []
rmvLow_corr_inds_B = []
rmvLow_corr_vals = []

for ind_A, ind_B in zip(filt_corr_ind_A, filt_corr_ind_B):
    val = subset_vals[ind_B]
    if val > 0.2:
        rmvLow_corr_inds_A.append(ind_A)
        rmvLow_corr_inds_B.append(ind_B)
        rmvLow_corr_vals.append(val)

In [None]:
len(rmvLow_corr_inds_A)

18

In [None]:
original_A_indices = [mixed_modA_feats[ind] for ind in rmvLow_corr_inds_A]
original_B_indices = [mixed_modB_feats[i] for i in rmvLow_corr_inds_B]

paired_svcca = svcca(weight_matrix_np[original_A_indices], weight_matrix_2[original_B_indices], "nd")
paired_svcca

0.051789917384082514

In [None]:
X_subset = weight_matrix_np[original_A_indices]
Y_subset = weight_matrix_2[original_B_indices]

all_rand_scores = shuffle_rand(1000, X_subset, Y_subset, Y_subset.shape[0],
                                          svcca, shapereq_bool=True)
print( sum(all_rand_scores) / len(all_rand_scores) )
np.mean(np.array(all_rand_scores) >= paired_svcca)

0.1994542648944047


0.863

## interpret

In [None]:
num_feats = len(rmvLow_corr_inds_A)
for subset_feature_idx_A, subset_feature_idx_B in zip(rmvLow_corr_inds_A[:num_feats],
                                                      rmvLow_corr_inds_B[:num_feats]):
    print(f'Correlation: {subset_vals[subset_feature_idx_B]}')
    feature_idx_A = mixed_modA_feats[subset_feature_idx_A]
    feature_idx_B = mixed_modB_feats[subset_feature_idx_B]
    print('Model A Feature: ', feature_idx_A)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx_A, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('Model B Feature: ', feature_idx_B)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_B, feature_idx_B, samp_m, batch_tokens=inputs["input_ids"])
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=inputs["input_ids"])

    print('-'*50)

Correlation: 0.5952286124229431
Model A Feature:  21606


Model B Feature:  18176


--------------------------------------------------
Correlation: 0.6470754742622375
Model A Feature:  13073


Model B Feature:  6542


--------------------------------------------------
Correlation: 0.4130881130695343
Model A Feature:  18446


Model B Feature:  25231


--------------------------------------------------
Correlation: 0.551276683807373
Model A Feature:  26222


Model B Feature:  5395


--------------------------------------------------
Correlation: 0.25650888681411743
Model A Feature:  287


Model B Feature:  13077


--------------------------------------------------
Correlation: 0.6840993165969849
Model A Feature:  10548


Model B Feature:  5531


--------------------------------------------------
Correlation: 0.4871455430984497
Model A Feature:  3724


Model B Feature:  10651


--------------------------------------------------
Correlation: 0.4417263865470886
Model A Feature:  2023


Model B Feature:  29597


--------------------------------------------------
Correlation: 0.6651054620742798
Model A Feature:  26187


Model B Feature:  28832


--------------------------------------------------
Correlation: 0.27913549542427063
Model A Feature:  30330


Model B Feature:  25018


--------------------------------------------------
Correlation: 0.45931100845336914
Model A Feature:  21196


Model B Feature:  32061


--------------------------------------------------
Correlation: 0.6085864901542664
Model A Feature:  5764


Model B Feature:  13632


--------------------------------------------------
Correlation: 0.3523541986942291
Model A Feature:  5732


Model B Feature:  3018


--------------------------------------------------
Correlation: 0.48175281286239624
Model A Feature:  24857


Model B Feature:  30668


--------------------------------------------------
Correlation: 0.9031506180763245
Model A Feature:  2457


Model B Feature:  3196


--------------------------------------------------
Correlation: 0.7025443315505981
Model A Feature:  8783


Model B Feature:  28378


--------------------------------------------------
Correlation: 0.43087121844291687
Model A Feature:  6802


Model B Feature:  27878


--------------------------------------------------
Correlation: 0.42479127645492554
Model A Feature:  16014


Model B Feature:  9470


--------------------------------------------------
