# Get data: stream OpenWebText and extract Pythia-70m activations

In [None]:
%%capture
!pip install datasets

In [None]:
from datasets import load_dataset
# trust_remote_code=True answers the interactive "run the custom code? [y/N]" prompt
# up front so the notebook can run unattended (Restart & Run All).
# SECURITY: this executes the dataset's loading script from the Hub; review it first at
# https://hf.co/datasets/Skylion007/openwebtext
dataset = load_dataset("Skylion007/openwebtext", split="train", streaming=True, trust_remote_code=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.35k [00:00<?, ?B/s]

The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


In [None]:
import gc

# Function to get next batch
def get_next_batch(dataset_iter, batch_size=100):
    """Pull up to ``batch_size`` records from ``dataset_iter`` and return their 'text' fields.

    Fewer items (possibly none) are returned when the iterator runs out first.
    The iterator is advanced in place, so successive calls yield successive batches.
    """
    exhausted = object()  # sentinel distinguishing "no more items" from any real record
    texts = []
    while len(texts) < batch_size:
        record = next(dataset_iter, exhausted)
        if record is exhausted:
            break
        texts.append(record['text'])
    return texts

# Create dataset iterator
dataset_iter = iter(dataset)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the Pythia-70m tokenizer and causal-LM model from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
# Reuse EOS as the padding token so the batched tokenizer calls below can pad.
# NOTE(review): presumably the tokenizer ships without a pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token
# Put the model on the GPU when one is available, otherwise on the CPU.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m").to('cuda' if torch.cuda.is_available() else 'cpu')
# Evaluation mode: the model is only used for inference in this notebook.
model.eval()

tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXSdpaAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
        

In [None]:
layer_id = 0

In [None]:
# Collect hidden states of layer `layer_id` for `num_batches` streamed batches.
# Every batch is padded/truncated to exactly `maxseqlen` tokens so that all batches
# share the same sequence length and can be concatenated along dim 0.
# NOTE(review): padded positions (eos is used as the pad token) are kept, so they
# contribute to the statistics computed below; the attention mask is not applied.
# NOTE(review): in HF transformers, hidden_states[0] is the embedding-layer output —
# confirm that layer_id=0 is meant to select embeddings.
batch_size = 50
maxseqlen = 100
num_batches = 1  # raise (or loop until exhaustion) to use more of the dataset

collected_hidden_states = []
for _ in range(num_batches):
    batch = get_next_batch(dataset_iter, batch_size)
    if not batch:
        break  # Stop if there are no more batches

    # Tokenize the batch and move it to the model's device
    inputs = tokenizer(batch, return_tensors="pt", padding="max_length", truncation=True, max_length=maxseqlen)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = model(**inputs, output_hidden_states=True)
        collected_hidden_states.append(outputs.hidden_states[layer_id])

    # Release per-batch tensors to keep peak memory low
    del inputs, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

# Concatenate once instead of re-allocating with torch.cat on every iteration.
accumulated_outputs = torch.cat(collected_hidden_states, dim=0) if collected_hidden_states else None

In [None]:
# outputs.hidden_states[layer_id].shape

In [None]:
accumulated_outputs.shape

torch.Size([50, 100, 512])

In [None]:
# Flatten (batch, seq, hidden) -> (batch * seq, hidden) so every token position is one
# sample, and move to CPU. Using -1 keeps the cell idempotent: re-running it on the
# already-flattened 2-D tensor is a no-op instead of a reshape error.
accumulated_outputs = accumulated_outputs.reshape(-1, accumulated_outputs.shape[-1]).cpu()

In [None]:
accumulated_outputs.shape

torch.Size([5000, 512])

# small data

In [None]:
import torch
testdata = torch.tensor([[1, 2, 1],
[3, 3, 5],
[2, 1, 2]]).float()

In [None]:
testdata.shape

torch.Size([3, 3])

# Correlation functions

In [None]:
def batched_correlation(reshaped_activations_A, reshaped_activations_B, batch_size=100, verbose=True):
    """For every feature (column) of B, find the most Pearson-correlated feature of A.

    The A-by-B correlation matrix is computed in batches of B's columns, so only a
    (num_features_A, batch_size) slice exists at any time.

    Args:
        reshaped_activations_A: 2-D tensor (num_samples, num_features_A).
        reshaped_activations_B: 2-D tensor (num_samples, num_features_B); rows must be
            aligned with the rows of A.
        batch_size: number of B columns processed per batch.
        verbose: if True, print the starting column of each batch.

    Returns:
        (max_indices, max_values): 1-D tensors of length num_features_B. max_indices[j]
        is the column of A most correlated with column j of B, and max_values[j] is
        that correlation. NaN correlations are treated as -inf so they are never chosen.
    """
    # Use the GPU when available
    if torch.cuda.is_available():
        reshaped_activations_A = reshaped_activations_A.to('cuda')
        reshaped_activations_B = reshaped_activations_B.to('cuda')

    # Standardize columns of A
    mean_A = reshaped_activations_A.mean(dim=0, keepdim=True)
    std_A = reshaped_activations_A.std(dim=0, keepdim=True)
    normalized_A = (reshaped_activations_A - mean_A) / (std_A + 1e-8)  # Avoid division by zero

    # Standardize columns of B
    mean_B = reshaped_activations_B.mean(dim=0, keepdim=True)
    std_B = reshaped_activations_B.std(dim=0, keepdim=True)
    normalized_B = (reshaped_activations_B - mean_B) / (std_B + 1e-8)  # Avoid division by zero

    # std() applies Bessel's correction (n - 1), so the dot product of standardized
    # columns must also be divided by (n - 1); dividing by n caps correlations at (n-1)/n.
    denom = max(normalized_A.shape[0] - 1, 1)
    num_features_B = normalized_B.shape[1]
    max_values = []
    max_indices = []

    for start in range(0, num_features_B, batch_size):
        end = min(start + batch_size, num_features_B)
        if verbose:
            print(start)
        batch_corr_matrix = torch.matmul(normalized_A.t(), normalized_B[:, start:end]) / denom
        batch_corr_matrix = batch_corr_matrix.masked_fill(torch.isnan(batch_corr_matrix), float('-inf'))
        # Rows index ALL of A's columns; only B is sliced, so the argmax needs no offset.
        max_val, max_idx = batch_corr_matrix.max(dim=0)
        max_values.append(max_val)
        max_indices.append(max_idx)
        del batch_corr_matrix

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return torch.cat(max_indices), torch.cat(max_values)

## small data

In [None]:
highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(testdata, testdata)
highest_correlations_indices_AB = highest_correlations_indices_AB.detach().cpu().numpy()
highest_correlations_values_AB = highest_correlations_values_AB.detach().cpu().numpy()

0


In [None]:
highest_correlations_indices_AB

array([0, 1, 2])

In [None]:
highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(testdata, testdata[range(testdata.shape[0])[::-1]])
highest_correlations_indices_AB = highest_correlations_indices_AB.detach().cpu().numpy()
highest_correlations_values_AB = highest_correlations_values_AB.detach().cpu().numpy()

0


In [None]:
highest_correlations_indices_AB

array([1, 0, 1])

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
cosine_similarity(testdata, testdata[range(testdata.shape[0])[::-1]])

array([[0.8164966, 0.8716019, 0.9999999],
       [0.9658243, 1.0000001, 0.8716019],
       [1.       , 0.9658243, 0.8164966]], dtype=float32)

In [None]:
highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(testdata, testdata[:, range(testdata.shape[1])[::-1]])
highest_correlations_indices_AB = highest_correlations_indices_AB.detach().cpu().numpy()
highest_correlations_values_AB = highest_correlations_values_AB.detach().cpu().numpy()

0


In [None]:
highest_correlations_indices_AB

array([2, 1, 0])

## large data

In [None]:
# accumulated_outputs[0] == accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]][-1]

In [None]:
mean_A = accumulated_outputs.mean(dim=0, keepdim=True)
std_A = accumulated_outputs.std(dim=0, keepdim=True)
normalized_A = (accumulated_outputs - mean_A) / (std_A + 1e-8)  # Avoid division by zero

In [None]:
mean_B = accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]].mean(dim=0, keepdim=True)
std_B = accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]].std(dim=0, keepdim=True)
normalized_B = (accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]] - mean_B) / (std_B + 1e-8)  # Avoid division by zero

In [None]:
# normalized_A[0] == normalized_B[-1]

In [None]:
# batch_corr_matrix = torch.matmul(normalized_A.t(), normalized_B[:, 0:100]) / normalized_A.shape[0]
batch_corr_matrix = torch.matmul(normalized_A.t(), normalized_B[:, -100:]) / normalized_A.shape[0]

In [None]:
batch_corr_matrix.shape

torch.Size([512, 100])

In [None]:
highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(accumulated_outputs, accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]])
# highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(accumulated_outputs, accumulated_outputs)
highest_correlations_indices_AB = highest_correlations_indices_AB.detach().cpu().numpy()
highest_correlations_values_AB = highest_correlations_values_AB.detach().cpu().numpy()

0
100
200
300
400
500


In [None]:
num_unq_pairs = len(list(set(highest_correlations_indices_AB)))
print("% unique: ", num_unq_pairs / len(highest_correlations_indices_AB))

% unique:  0.6171875


In [None]:
sum(highest_correlations_values_AB) / len(highest_correlations_values_AB)

0.0126833261892898

# Don't normalize (raw dot products instead of correlations)

In [None]:
def batched_correlation(reshaped_activations_A, reshaped_activations_B, batch_size=100, verbose=True):
    """Unnormalized variant: for each column of B, find the column of A with the largest
    raw dot product.

    NOTE: despite the name, no centering or scaling is applied, so the scores are raw,
    scale-dependent inner products rather than correlations. This definition reuses the
    name of the earlier, normalized ``batched_correlation``, so cells run after this
    one get this variant.

    Args:
        reshaped_activations_A: 2-D tensor (num_samples, num_features_A).
        reshaped_activations_B: 2-D tensor (num_samples, num_features_B), rows aligned with A.
        batch_size: number of B columns processed per batch.
        verbose: if True, print the starting column of each batch.

    Returns:
        (max_indices, max_values): 1-D tensors of length num_features_B with, for each
        column of B, the index of the best-scoring column of A and its dot product.
    """
    # Use the GPU when available
    if torch.cuda.is_available():
        reshaped_activations_A = reshaped_activations_A.to('cuda')
        reshaped_activations_B = reshaped_activations_B.to('cuda')

    num_features_B = reshaped_activations_B.shape[1]
    max_values = []
    max_indices = []

    for start in range(0, num_features_B, batch_size):
        end = min(start + batch_size, num_features_B)
        if verbose:
            print(start)
        batch_scores = torch.matmul(reshaped_activations_A.t(), reshaped_activations_B[:, start:end])
        # Rows index ALL of A's columns; only B is sliced, so the argmax needs no offset.
        max_val, max_idx = batch_scores.max(dim=0)
        max_values.append(max_val)
        max_indices.append(max_idx)
        del batch_scores

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return torch.cat(max_indices), torch.cat(max_values)

In [None]:
highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(accumulated_outputs, accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]])
# highest_correlations_indices_AB, highest_correlations_values_AB = batched_correlation(accumulated_outputs, accumulated_outputs)
highest_correlations_indices_AB = highest_correlations_indices_AB.detach().cpu().numpy()
highest_correlations_values_AB = highest_correlations_values_AB.detach().cpu().numpy()

0
100
200
300
400
500


In [None]:
num_unq_pairs = len(list(set(highest_correlations_indices_AB)))
print("% unique: ", num_unq_pairs / len(highest_correlations_indices_AB))

% unique:  0.02734375


In [None]:
sum(highest_correlations_values_AB) / len(highest_correlations_values_AB)

13.206137143366504

# Don't batch (compute the full correlation matrix at once)

In [None]:
def find_all_highest_correlations(reshaped_activations_A, reshaped_activations_B):
    """For every feature (column) of B, find the most Pearson-correlated feature of A.

    Args:
        reshaped_activations_A: 2-D tensor (num_samples, num_features_A).
        reshaped_activations_B: 2-D tensor (num_samples, num_features_B), rows aligned with A.

    Returns:
        (highest_correlations_indices, highest_correlations_values, correlation_matrix):
        the first two are NumPy arrays of length num_features_B (best A column per B
        column and its correlation); correlation_matrix is the full
        (num_features_A, num_features_B) tensor, left on the compute device, with NaNs
        replaced by -inf.
    """
    # Use the GPU when available
    if torch.cuda.is_available():
        reshaped_activations_A = reshaped_activations_A.to('cuda')
        reshaped_activations_B = reshaped_activations_B.to('cuda')

    # Standardize columns of A
    mean_A = reshaped_activations_A.mean(dim=0, keepdim=True)
    std_A = reshaped_activations_A.std(dim=0, keepdim=True)
    normalized_A = (reshaped_activations_A - mean_A) / (std_A + 1e-8)  # Avoid division by zero

    # Standardize columns of B
    mean_B = reshaped_activations_B.mean(dim=0, keepdim=True)
    std_B = reshaped_activations_B.std(dim=0, keepdim=True)
    normalized_B = (reshaped_activations_B - mean_B) / (std_B + 1e-8)  # Avoid division by zero

    # std() applies Bessel's correction (n - 1), so divide the product by (n - 1) too;
    # dividing by n capped every correlation at (n - 1) / n (e.g. 0.9998 for n = 5000).
    denom = max(normalized_A.shape[0] - 1, 1)
    correlation_matrix = torch.matmul(normalized_A.t(), normalized_B) / denom

    # NaNs would poison max(); make them lose every comparison instead
    correlation_matrix = correlation_matrix.masked_fill(torch.isnan(correlation_matrix), float('-inf'))

    # Best A column for each B column
    highest_correlations_values, highest_correlations_indices = correlation_matrix.max(dim=0)

    return (
        highest_correlations_indices.cpu().numpy(),
        highest_correlations_values.cpu().numpy(),
        correlation_matrix,
    )

In [None]:
# find_all_highest_correlations returns (indices, values, correlation_matrix)
highest_correlations_indices_AB, highest_correlations_values_AB, corrmat = find_all_highest_correlations(accumulated_outputs, accumulated_outputs)

In [None]:
sum(highest_correlations_values_AB) / len(highest_correlations_values_AB)

0.999799283221364

In [None]:
highest_correlations_values_AB[0]

0.9997986

In [None]:
# find_all_highest_correlations returns (indices, values, correlation_matrix)
highest_correlations_indices_AB, highest_correlations_values_AB, corrmat = find_all_highest_correlations(accumulated_outputs, accumulated_outputs.flip(0))

In [None]:
num_unq_pairs = len(list(set(highest_correlations_indices_AB)))
print("% unique: ", num_unq_pairs / len(highest_correlations_indices_AB))

% unique:  0.625


In [None]:
sum(highest_correlations_values_AB) / len(highest_correlations_values_AB)

0.043040892043791246

In [None]:
# find_all_highest_correlations returns (indices, values, correlation_matrix)
highest_correlations_indices_AB, highest_correlations_values_AB, corrmat = find_all_highest_correlations(
    accumulated_outputs.flip(0),
    accumulated_outputs,
)

In [None]:
num_unq_pairs = len(list(set(highest_correlations_indices_AB)))
print("% unique: ", num_unq_pairs / len(highest_correlations_indices_AB))

% unique:  0.625


In [None]:
sum(highest_correlations_values_AB) / len(highest_correlations_values_AB)

0.0430968898217543

In [None]:
range(accumulated_outputs.shape[0])[::-1]

range(4999, -1, -1)

In [None]:
accumulated_outputs.shape[0]

5000

In [None]:
mean_A = accumulated_outputs.mean(dim=0, keepdim=True)
std_A = accumulated_outputs.std(dim=0, keepdim=True)
normalized_A = (accumulated_outputs - mean_A) / (std_A + 1e-8)  # Avoid division by zero

In [None]:
torch.matmul(normalized_A[0].t(), normalized_A[0]) / normalized_A.shape[0]

tensor(0.1412)

In [None]:
torch.matmul(normalized_A[0], normalized_A[0]) / normalized_A.shape[0]

tensor(0.1412)

In [None]:
torch.matmul(accumulated_outputs[0].t(), accumulated_outputs[0])

tensor(0.5132)

In [None]:
torch.matmul(accumulated_outputs[0].t(), accumulated_outputs[0].t())

tensor(0.5132)

In [None]:
torch.matmul(accumulated_outputs[0].t(), accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]][-1])

tensor(0.5132)

## small data

In [None]:
# find_all_highest_correlations returns (indices, values, correlation_matrix)
highest_correlations_indices_AB, highest_correlations_values_AB, corrmat = find_all_highest_correlations(testdata, testdata)

In [None]:
highest_correlations_indices_AB

array([0, 1, 2])

In [None]:
highest_correlations_indices_AB, highest_correlations_values_AB, corrmat = find_all_highest_correlations(testdata, testdata[range(testdata.shape[0])[::-1]])

In [None]:
corrmat

tensor([[0.3333, 0.6667, 0.4804],
        [0.6667, 0.3333, 0.6405],
        [0.4804, 0.6405, 0.5897]], device='cuda:0')

In [None]:
highest_correlations_indices_AB, highest_correlations_values_AB, corrmat = find_all_highest_correlations(testdata, testdata[range(testdata.shape[0])[::-1]])

In [None]:
highest_correlations_indices_AB, highest_correlations_values_AB, corrmat = find_all_highest_correlations(testdata, testdata[:, range(testdata.shape[1])[::-1]])

In [None]:
highest_correlations_indices_AB

array([2, 1, 0])

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
cosine_similarity(testdata, testdata[range(testdata.shape[0])[::-1]])

array([[0.8164966, 0.8716019, 0.9999999],
       [0.9658243, 1.0000001, 0.8716019],
       [1.       , 0.9658243, 0.8164966]], dtype=float32)

# cosine sim

In [None]:
import numpy as np

def find_all_highest_cosine_similarities(reshaped_activations_A, reshaped_activations_B):
    """For every ROW of A, find the most cosine-similar ROW of B.

    Unlike the correlation helpers, which compare feature columns and pick the best A
    column for each B column, this compares rows and picks the best B row for each A
    row. Transpose the inputs to compare features instead.

    Args:
        reshaped_activations_A: 2-D array-like (num_rows_A, dim); CPU torch tensors are
            accepted and converted with ``np.asarray``.
        reshaped_activations_B: 2-D array-like (num_rows_B, dim).

    Returns:
        (highest_cosine_indices, highest_cosine_values): NumPy arrays of length
        num_rows_A with the index of the best B row and its cosine similarity.
    """
    # Convert explicitly so the math below is pure NumPy rather than implicit torch/NumPy interop
    rows_A = np.asarray(reshaped_activations_A)
    rows_B = np.asarray(reshaped_activations_B)

    # Scale each row to unit L2 norm
    unit_A = rows_A / (np.linalg.norm(rows_A, axis=1, keepdims=True) + 1e-8)  # Avoid division by zero
    unit_B = rows_B / (np.linalg.norm(rows_B, axis=1, keepdims=True) + 1e-8)  # Avoid division by zero

    # (num_rows_A, num_rows_B) cosine similarity matrix
    cosine_similarity_matrix = unit_A @ unit_B.T

    # One argmax pass; read the values at those positions instead of a second max() pass
    highest_cosine_indices = np.argmax(cosine_similarity_matrix, axis=1)
    highest_cosine_values = np.take_along_axis(
        cosine_similarity_matrix, highest_cosine_indices[:, np.newaxis], axis=1
    )[:, 0]

    return highest_cosine_indices, highest_cosine_values


In [None]:
highest_cosine_indices, highest_cosine_values = find_all_highest_cosine_similarities(accumulated_outputs,
                                                                                                accumulated_outputs)

In [None]:
highest_cosine_indices[:5]

array([0, 1, 2, 1, 4])

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
cosine_similarity(accumulated_outputs[0].reshape(1, -1), accumulated_outputs[0].reshape(1, -1))

array([[1.]], dtype=float32)

In [None]:
highest_cosine_indices, highest_cosine_values = find_all_highest_cosine_similarities(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]],
                                                                                                accumulated_outputs)

In [None]:
highest_cosine_indices

array([3062, 4998,   94, ...,    2, 2183,    0])

In [None]:
sum(highest_cosine_values) / len(highest_cosine_values)

1.0000000455021858

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
cosmat = cosine_similarity(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]].t(), accumulated_outputs.t())

In [None]:
cosmat.shape

(512, 512)

In [None]:
highest_cosine_values = np.max(cosmat, axis=1)
highest_cosine_indices = np.argmax(cosmat, axis=1)

In [None]:
highest_cosine_values[0]

0.044939704

In [None]:
highest_cosine_indices[0]

159

In [None]:
highest_cosine_values = np.max(cosmat, axis=0)
highest_cosine_indices = np.argmax(cosmat, axis=0)

In [None]:
highest_cosine_indices[0]

159

In [None]:
highest_cosine_values[0]

0.04493963

In [None]:
cosine_similarity(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]][159].reshape(1, -1), accumulated_outputs[0].reshape(1, -1))

array([[-0.01252146]], dtype=float32)

In [None]:
cosine_similarity(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]][0].reshape(1, -1), accumulated_outputs[159].reshape(1, -1))

array([[0.04139965]], dtype=float32)

In [None]:
cosine_similarity(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]][-1].reshape(1, -1), accumulated_outputs[0].reshape(1, -1))

array([[1.]], dtype=float32)

In [None]:
cosmat = cosine_similarity(accumulated_outputs.t(), accumulated_outputs.t())

In [None]:
highest_cosine_indices, highest_cosine_values = find_all_highest_cosine_similarities(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]],
                                                                                                accumulated_outputs)
sum(highest_cosine_values) / len(highest_cosine_values)

1.0000000455021858

In [None]:
highest_cosine_indices

array([3062, 4998,   94, ...,    2, 2183,    0])

In [None]:
cosmat = cosine_similarity(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]].t(), accumulated_outputs.t())
highest_cosine_indices, highest_cosine_values = find_all_highest_cosine_similarities(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]],
                                                                                                accumulated_outputs)
sum(highest_cosine_values) / len(highest_cosine_values)

1.0000000455021858

In [None]:
cosmat[0][-1]

0.020572796

In [None]:
cosmat[0][159]

0.044939704

In [None]:
highest_cosine_values[0]

0.9999999

In [None]:
highest_cosine_indices

array([3062, 4998,   94, ...,    2, 2183,    0])

In [None]:
cosmat = cosine_similarity(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]].t(), accumulated_outputs.t())
highest_cosine_indices, highest_cosine_values = find_all_highest_cosine_similarities(accumulated_outputs[range(accumulated_outputs.shape[0])[::-1]],
                                                                                                accumulated_outputs)
sum(highest_cosine_values) / len(highest_cosine_values)

# Pearson's r with SciPy

In [None]:
from scipy import stats
stats.pearsonr(accumulated_outputs, accumulated_outputs)

ValueError: shapes (60000,512) and (60000,512) not aligned: 512 (dim 1) != 60000 (dim 0)

# Batched correlation v2 (batch over the features of both A and B)

In [None]:
def batched_correlation_2(reshaped_activations_A, reshaped_activations_B, batch_size=100):
    """For every feature (column) of B, find the most Pearson-correlated feature of A,
    batching over the features of BOTH A and B.

    Only a (batch_size, batch_size) block of the correlation matrix exists at a time.
    For each block of B columns, a running maximum over the blocks of A columns is kept.

    Args:
        reshaped_activations_A: 2-D tensor (num_samples, num_features_A).
        reshaped_activations_B: 2-D tensor (num_samples, num_features_B), rows aligned with A.
        batch_size: number of columns per block, used for both A and B.

    Returns:
        (max_indices, max_values): 1-D tensors of length num_features_B with, for each
        B column, the index of the most correlated A column and that correlation.
        NaN correlations are treated as -inf so they are never chosen.

    Raises:
        ValueError: if A has no feature columns (there is nothing to take a max over).
    """
    # Use the GPU when available
    if torch.cuda.is_available():
        reshaped_activations_A = reshaped_activations_A.to('cuda')
        reshaped_activations_B = reshaped_activations_B.to('cuda')

    # Standardize columns of A
    mean_A = reshaped_activations_A.mean(dim=0, keepdim=True)
    std_A = reshaped_activations_A.std(dim=0, keepdim=True)
    normalized_A = (reshaped_activations_A - mean_A) / (std_A + 1e-8)  # Avoid division by zero

    # Standardize columns of B
    mean_B = reshaped_activations_B.mean(dim=0, keepdim=True)
    std_B = reshaped_activations_B.std(dim=0, keepdim=True)
    normalized_B = (reshaped_activations_B - mean_B) / (std_B + 1e-8)  # Avoid division by zero

    # std() uses Bessel's correction (n - 1); divide the products by the same factor
    denom = max(normalized_A.shape[0] - 1, 1)
    A_feature_size = normalized_A.shape[1]
    B_feature_size = normalized_B.shape[1]
    if A_feature_size == 0:
        raise ValueError("reshaped_activations_A has no feature columns")

    max_values = []
    max_indices = []
    for B_start in range(0, B_feature_size, batch_size):
        B_slice = normalized_B[:, B_start:B_start + batch_size]

        best_val = None
        best_idx = None
        for A_start in range(0, A_feature_size, batch_size):
            A_slice_t = normalized_A[:, A_start:A_start + batch_size].t()
            corr_block = torch.matmul(A_slice_t, B_slice) / denom
            corr_block = corr_block.masked_fill(torch.isnan(corr_block), float('-inf'))

            block_val, block_idx = corr_block.max(dim=0)
            block_idx = block_idx + A_start  # local row in this A block -> global A column

            if best_val is None:
                best_val, best_idx = block_val, block_idx
            else:
                # Strict '>' keeps the earlier A block on exact ties
                improved = block_val > best_val
                best_val = torch.where(improved, block_val, best_val)
                best_idx = torch.where(improved, block_idx, best_idx)

        max_values.append(best_val)
        max_indices.append(best_idx)

    return torch.cat(max_indices), torch.cat(max_values)

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 100 but got size 12 for tensor number 1 in the list.

In [None]:
row_A_corrs.shape

torch.Size([500, 100])

In [None]:
(torch.matmul(A_slice, B_slice) / normalized_A.shape[0]).shape

torch.Size([100, 12])

In [None]:
highest_correlations_indices_AB

array([  8,   8,  37, ..., 511, 510, 511])

In [None]:
len(highest_correlations_indices_AB)

3072