In [171]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

# ============================
# Step 1: Load the Model and Dataset
# ============================

device = "cpu"
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
# transformers warns that the classifier head of this base checkpoint is newly initialised.
# That is harmless here: the strict load_state_dict below replaces every parameter.
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Load fine-tuned model.
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered
# checkpoint cannot execute code; it also silences the torch.load FutureWarning.
# NOTE(review): assumes the .bin file is a plain state dict (name -> tensor) - confirm.
trained = torch.load(
    './notebooks/SST-2-BERT-tiny.bin',
    map_location=torch.device(device),
    weights_only=True,
)
trained.pop('bert.embeddings.position_ids', None)  # Remove unexpected keys if any
model.load_state_dict(trained, strict=True)

# Set the model to evaluation mode (disables dropout)
model.eval()

# Load the SST-2 training and validation splits
train_dataset = load_dataset("stanfordnlp/sst2", split="train")
valid_dataset = load_dataset("stanfordnlp/sst2", split="validation")


def tokenize_function(example):
    """Tokenize the 'sentence' field, truncating/padding every example to 55 tokens."""
    return tokenizer(example['sentence'], truncation=True, padding='max_length', max_length=55)


# Apply tokenization to the training split and expose the model inputs as torch tensors
tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

# Create a DataLoader for batching
train_dataloader = DataLoader(tokenized_train_dataset, batch_size=32)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trained = torch.load('./notebooks/SST-2-BERT-tiny.bin', map_location=torch.device(device))


In [11]:
# Inspect the shape of the tokenized training set as (rows, columns).
tokenized_train_dataset.shape

(67349, 6)

In [135]:
# Peek at the first raw (untokenized) training sentence.
train_dataset['sentence'][0]

'hide new secretions from the parental units '

In [136]:
# Number of sentences in the training split.
len(train_dataset['sentence'])

67349

In [None]:
import math

# Per-sentence statistics of the tensor that enters the first attention-output LayerNorm
# (attention output projection + residual). Each list gets one numpy array per sentence,
# holding one value per token.
mean_distribution = []
var_distribution = []  # NOTE: stores 1 / sqrt(variance), not the variance itself

# Loop-invariant: transposed weight and bias of layer 0's attention output projection.
# The sub-modules live under `model.bert`; BertForSequenceClassification has no top-level
# `embeddings` / `encoder` attributes.
attention_output_dense = model.bert.encoder.layer[0].attention.output.dense
w_output_dense = attention_output_dense.weight.clone().detach().double().transpose(0, 1)
b_output_dense = attention_output_dense.bias.clone().detach().double()

# Only statistics are extracted, so autograd bookkeeping is not needed.
with torch.no_grad():
    for sentence in train_dataset['sentence']:
        text = "[CLS] " + sentence + " [SEP]"

        tokenized_text = tokenizer.tokenize(text)
        # NOTE(review): segment id 1 is used for every token; single-sentence BERT inputs
        # usually use 0 - confirm this matches how the fine-tuned model is meant to be fed.
        segments_ids = [1] * len(tokenized_text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])

        # Embeddings are computed once and reused as the residual branch.
        embeddings = model.bert.embeddings(tokens_tensor, segments_tensors)
        xx = model.bert.encoder.layer[0].attention.self(embeddings)[0].double()

        xx = torch.matmul(xx, w_output_dense) + b_output_dense
        xx = xx + embeddings

        means = []
        variances = []

        # Iterate over the tokens actually present: sequences are not padded here, so a
        # fixed range(55) would index past the end of shorter sentences.
        for xi in xx.squeeze(0):
            means.append(torch.mean(xi).item())
            # torch.var defaults to the unbiased (n - 1) estimator.
            variances.append(1 / math.sqrt(torch.var(xi).item()))

        mean_distribution.append(np.array(means))
        var_distribution.append(np.array(variances))

In [146]:
# Display the self-attention sub-module of the first encoder layer.
model.bert.encoder.layer[0].attention.self

BertSdpaSelfAttention(
  (query): Linear(in_features=128, out_features=128, bias=True)
  (key): Linear(in_features=128, out_features=128, bias=True)
  (value): Linear(in_features=128, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [191]:
# Cast the model's floating-point parameters to float64 in place (the cell echoes the module repr).
# NOTE(review): the manual forward pass further down feeds float64 activations into
# layer[1].attention.self, so it appears to rely on this cast having been run first.
model.double()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [192]:
import math

# Per-sentence statistics of the tensor that enters layer 0's attention-output LayerNorm.
# Each list gets one numpy array per sentence, holding one value per token.
mean_distribution = []
var_distribution = []  # NOTE: stores 1 / sqrt(variance), not the variance itself

# Loop-invariant parameters of layer 0's attention output block, cast to float64 once.
attention_output = model.bert.encoder.layer[0].attention.output
w_output_dense = attention_output.dense.weight.clone().detach().double().transpose(0, 1)
b_output_dense = attention_output.dense.bias.clone().detach().double()
w_output_layernorm = attention_output.LayerNorm.weight.clone().detach().double()
b_output_layernorm = attention_output.LayerNorm.bias.clone().detach().double()

# Only statistics are extracted, so autograd bookkeeping is not needed.
with torch.no_grad():
    for sentence in train_dataset['sentence']:
        text = "[CLS] " + sentence + " [SEP]"

        tokenized_text = tokenizer.tokenize(text)
        # NOTE(review): segment id 1 is used for every token; single-sentence BERT inputs
        # usually use 0 - confirm this matches how the fine-tuned model is meant to be fed.
        segments_ids = [1] * len(tokenized_text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])

        # Embeddings (computed once; reused below as the residual branch)
        xx = model.bert.embeddings(tokens_tensor, segments_tensors)

        # Self-Attention
        fin = model.bert.encoder.layer[0].attention.self(xx)[0].double()

        # Attention output projection plus residual connection
        fin2_backup = torch.matmul(fin, w_output_dense) + b_output_dense + xx

        mean_0_0 = []
        var_0_0 = []

        fin3_whole = []
        for fin2 in fin2_backup.squeeze(0):
            current_mean = torch.mean(fin2).item()
            # torch.var defaults to the unbiased (n - 1) estimator.
            current_var = 1 / math.sqrt(torch.var(fin2).item())

            # save mean and inverse standard deviation
            mean_0_0.append(current_mean)
            var_0_0.append(current_var)

            # Manual LayerNorm: normalise, then apply the affine weight and bias.
            # The weight must be multiplied in; a comma here would build a tuple instead.
            fin3_corr = (fin2 - current_mean) * current_var
            fin3_corr = fin3_corr * w_output_layernorm + b_output_layernorm
            fin3_whole.append(fin3_corr)

        mean_distribution.append(np.array(mean_0_0))
        var_distribution.append(np.array(var_0_0))

KeyboardInterrupt: 

In [243]:
import math


def _dense(hidden, linear):
    """Apply `linear` to `hidden` as an explicit float64 matmul plus bias."""
    weight = linear.weight.detach().double().transpose(0, 1)
    bias = linear.bias.detach().double()
    return torch.matmul(hidden, weight) + bias


def _layer_norm_with_stats(hidden, layer_norm):
    """Normalise each token of `hidden` by hand and record the statistics used.

    Args:
        hidden: tensor whose leading dimension has size 1 and is squeezed away, leaving one
            row per token.
        layer_norm: module whose `weight` and `bias` give the affine part.

    Returns:
        (normalised, means, inv_stds): `normalised` is the re-stacked rows with the leading
        dimension restored; `means` and `inv_stds` are numpy arrays with one entry per token,
        holding the row mean and 1 / sqrt(row variance).

    NOTE(review): torch.var defaults to the unbiased (n - 1) estimator and no epsilon is
    added, whereas torch.nn.LayerNorm uses the biased estimator plus eps. Kept as is so the
    recorded statistics are unchanged - confirm this is what the reference files expect.
    """
    weight = layer_norm.weight.detach().double()
    bias = layer_norm.bias.detach().double()

    means, inv_stds, rows = [], [], []
    for row in hidden.squeeze(0):
        row_mean = torch.mean(row).item()
        row_inv_std = 1 / math.sqrt(torch.var(row).item())
        means.append(row_mean)
        inv_stds.append(row_inv_std)
        rows.append((row - row_mean) * row_inv_std * weight + bias)

    return torch.stack(rows).unsqueeze(0), np.array(means), np.array(inv_stds)


def _encoder_layer_with_stats(hidden, layer):
    """Run one encoder layer by hand, returning its output and both LayerNorm statistics.

    Returns:
        (output, attention_stats, ffn_stats), where each `*_stats` is a (means, inv_stds)
        pair from `_layer_norm_with_stats`: first for the attention-output LayerNorm, then
        for the feed-forward output LayerNorm.
    """
    # Self-attention, output projection, residual, LayerNorm
    context = layer.attention.self(hidden)[0].double()
    attention_sum = _dense(context, layer.attention.output.dense) + hidden
    attention_out, attention_means, attention_inv_stds = _layer_norm_with_stats(
        attention_sum, layer.attention.output.LayerNorm
    )

    # Feed-forward (dense -> GELU -> dense), residual, LayerNorm
    intermediate = torch.nn.functional.gelu(_dense(attention_out, layer.intermediate.dense))
    ffn_sum = _dense(intermediate, layer.output.dense) + attention_out
    layer_out, ffn_means, ffn_inv_stds = _layer_norm_with_stats(ffn_sum, layer.output.LayerNorm)

    return layer_out, (attention_means, attention_inv_stds), (ffn_means, ffn_inv_stds)


# One numpy array per validation sentence, one value per token.
# Suffix <layer>_<norm>: norm 0 = attention-output LayerNorm, norm 1 = feed-forward output
# LayerNorm. The var_* lists store 1 / sqrt(variance), not the variance itself.
mean_distribution_0_0 = []
var_distribution_0_0 = []
mean_distribution_0_1 = []
var_distribution_0_1 = []
mean_distribution_1_0 = []
var_distribution_1_0 = []
mean_distribution_1_1 = []
var_distribution_1_1 = []

# Only statistics are extracted, so autograd bookkeeping is not needed.
# The attention modules receive float64 activations, so the model must already be float64
# (model.double() is called in an earlier cell).
with torch.no_grad():
    for sentence in tqdm(valid_dataset['sentence']):
        text = "[CLS] " + sentence + " [SEP]"

        tokenized_text = tokenizer.tokenize(text)
        # NOTE(review): segment id 1 is used for every token; single-sentence BERT inputs
        # usually use 0 - confirm this matches how the fine-tuned model is meant to be fed.
        segments_ids = [1] * len(tokenized_text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])

        # Embeddings
        hidden = model.bert.embeddings(tokens_tensor, segments_tensors).double()

        # Encoder layer 0
        hidden, (mean_0_0, var_0_0), (mean_0_1, var_0_1) = _encoder_layer_with_stats(
            hidden, model.bert.encoder.layer[0]
        )
        mean_distribution_0_0.append(mean_0_0)
        var_distribution_0_0.append(var_0_0)
        mean_distribution_0_1.append(mean_0_1)
        var_distribution_0_1.append(var_0_1)

        # Encoder layer 1
        hidden, (mean_1_0, var_1_0), (mean_1_1, var_1_1) = _encoder_layer_with_stats(
            hidden, model.bert.encoder.layer[1]
        )
        mean_distribution_1_0.append(mean_1_0)
        var_distribution_1_0.append(var_1_0)
        mean_distribution_1_1.append(mean_1_1)
        var_distribution_1_1.append(var_1_1)

100%|██████████| 872/872 [00:05<00:00, 153.31it/s]


In [195]:
# Number of sentences for which per-token means were collected.
len(mean_distribution)

17616

In [196]:
# Number of tokens (including the added [CLS] and [SEP]) recorded for the second sentence.
len(mean_distribution[1])

11

In [197]:
# Average, over all collected sentences, of the recorded mean at token position 0
# (the "[CLS]" token that is prepended to every sentence).
# The accumulator is deliberately not named `sum`: in a notebook that would shadow the
# builtin for every later cell of the kernel session.
cls_mean_total = 0
for sample_means in mean_distribution:
    cls_mean_total += sample_means[0]

print(cls_mean_total / len(mean_distribution))

-0.03460026379046557


In [232]:
import numpy as np

def precision(correct, approx):
    """Return 1 minus the relative mean absolute error of `approx` against `correct`.

    Args:
        correct: reference values (any array-like: list, tuple or ndarray).
        approx: approximated values, broadcastable against `correct`.

    Returns:
        1 - (mean |correct - approx|) / (mean |correct|). A value of 1.0 means an exact
        match. The result is not finite when `correct` is all zeros.
    """
    # Convert both operands, so plain lists/tuples work on either side.
    correct = np.asarray(correct)
    approx = np.asarray(approx)
    absolute = np.sum(np.abs(correct - approx)) / len(correct)
    relative = absolute / (np.sum(np.abs(correct)) / len(correct))
    return 1 - relative

In [235]:
# Reference means loaded from text files, named real_mean_<layer>_<norm>:
# the "selfoutput" files map to *_0 and the "output" files map to *_1.
# NOTE(review): presumably one value per token position, produced by the same pipeline as
# the mean_distribution_* lists - confirm against whatever generated ./weights-sst2.
real_mean_0_0 = np.loadtxt("./weights-sst2/layer0_selfoutput_mean.txt")
real_mean_0_1 = np.loadtxt("./weights-sst2/layer0_output_mean.txt")
real_mean_1_0 = np.loadtxt("./weights-sst2/layer1_selfoutput_mean.txt")
real_mean_1_1 = np.loadtxt("./weights-sst2/layer1_output_mean.txt")

In [245]:
# Per-position average of the recorded means, taken over every sample long enough to
# contain that position.
max_length = max(map(len, mean_distribution_0_1))
total_means = np.zeros(max_length)
counts = np.zeros(max_length)

# Each sample contributes to its first len(sample) positions only.
for token_means in mean_distribution_0_1:
    n_tokens = len(token_means)
    total_means[:n_tokens] += token_means
    counts[:n_tokens] += 1

# Sum divided by the number of contributing samples, position by position
average_means = total_means / counts

# Report one line per position
for i in range(max_length):
    mean = average_means[i]
    print(f"Position {i} mean: {mean}")

Position 0 mean: -0.09405281883323419
Position 1 mean: 0.03478227947344718
Position 2 mean: 0.03955231243180054
Position 3 mean: 0.04100700399849076
Position 4 mean: 0.044454166752222586
Position 5 mean: 0.04935611358633444
Position 6 mean: 0.04950300960598524
Position 7 mean: 0.04804027485051588
Position 8 mean: 0.04926838412151814
Position 9 mean: 0.04887909554475611
Position 10 mean: 0.047896674133086015
Position 11 mean: 0.051112126845139064
Position 12 mean: 0.051096849669458266
Position 13 mean: 0.04947306479410882
Position 14 mean: 0.048564172826746514
Position 15 mean: 0.04752711313564156
Position 16 mean: 0.045470358193168754
Position 17 mean: 0.045431951713177085
Position 18 mean: 0.04673999271555818
Position 19 mean: 0.05176381057983726
Position 20 mean: 0.049251440846861526
Position 21 mean: 0.04900623448904927
Position 22 mean: 0.0484905850906113
Position 23 mean: 0.049154945098170706
Position 24 mean: 0.05004565285994915
Position 25 mean: 0.051870537056842095
Position 26 

In [250]:
import numpy as np

# Average the recorded means per token position across all samples, then score the result
# against the reference vector.
# The number of positions is taken from the reference (the original hard-coded 128), so the
# two arrays passed to precision() always have matching lengths.
num_positions = real_mean_0_0.shape[-1]
total_means = np.zeros(num_positions)
counts = np.zeros(num_positions)

for sample in mean_distribution_0_0:
    # Each sample contributes to its own positions only; anything beyond the reference
    # length is ignored.
    n_tokens = min(len(sample), num_positions)
    total_means[:n_tokens] += sample[:n_tokens]
    counts[:n_tokens] += 1

# Positions that no sample reached keep an average of 0 instead of dividing by zero.
average_means = np.divide(
    total_means, counts, out=np.zeros(num_positions), where=counts > 0
)

# Calculate precision
precision_value = precision(real_mean_0_0, average_means)
print(f"Precision: {precision_value}")

Precision: 0.9999998748343766
