In [15]:
import torch
# Load fine-tuned model
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")

In [16]:
# In your Python script that defines the BERT model
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig(
    hidden_size=128,
    num_hidden_layers=2,  # Increase from 2 to 4
    num_attention_heads=2,
    intermediate_size=512,
    max_position_embeddings=512,
    num_labels=2,
)

model = BertForSequenceClassification(config)

In [17]:
print(model)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [18]:
state_dict = model.state_dict()

In [19]:
#model.eval()

## Layer 0-3
# layer0_attself_query_weight
# layer0_attself_query_bias
# layer0_attself_key_weight
# layer0_attself_key_bias
# layer0_attself_value_weight
# layer0_attself_value_bias

# layer0_selfoutput_weight
# layer0_selfoutput_bias
# layer0_selfoutput_mean
# layer0_selfoutput_normbias
# layer0_selfoutput_vy

# layer0_intermediate_weight1
# layer0_intermediate_weight2
# layer0_intermediate_weight3
# layer0_intermediate_weight4
# layer0_intermediate_bias

# layer0_output_weight1
# layer0_output_weight2
# layer0_output_weight3
# layer0_output_weight4
# layer0_output_bias
# layer0_output_mean
# layer0_output_normbias
# layer0_output_vy

## Pooler
# pooler_dense_weight
# pooler_dense_bias

## Classifier
# classifier_weight
# classifier_bias

In [23]:
# Load training dataset
from datasets import load_dataset

train_dataset = load_dataset("stanfordnlp/sst2", split="train")

In [24]:
from torch.utils.data import DataLoader

# Tokenize the dataset
tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

# Create a dataloader
train_dataloader = DataLoader(tokenized_train_dataset, batch_size=32)

In [25]:
# Create a dictionary to store LayerNorm Inputs
layernorm_inputs = {
    'layer0_self_output': [],
    'layer0_output': [],
    'layer1_self_output': [],
    'layer1_output': []
}

# Hook function to capture the inputs from each LayerNorm layer
def get_layernorm_input(layer):
    def hook(module, input):
        layernorm_inputs[layer].append(input[0].detach().cpu())
    return hook
layer0_self_output_hook = model.bert.encoder.layer[0].attention.output.LayerNorm.register_forward_pre_hook(
    get_layernorm_input('layer0_self_output')
)
layer0_output_hook = model.bert.encoder.layer[0].output.LayerNorm.register_forward_pre_hook(
    get_layernorm_input('layer0_output')
)
layer1_self_output_hook = model.bert.encoder.layer[1].attention.output.LayerNorm.register_forward_pre_hook(
    get_layernorm_input('layer1_self_output')
)
layer1_output_hook = model.bert.encoder.layer[1].output.LayerNorm.register_forward_pre_hook(
    get_layernorm_input('layer1_output')
)

In [27]:
from tqdm import tqdm

# Process
attention_mask_list = [] # to excludle padding tokens
with torch.no_grad():
    for batch in tqdm(train_dataloader):
        input_ids = batch['input_ids'].to("cpu")
        attention_mask = batch['attention_mask'].to("cpu")

        attention_mask_list.append(attention_mask.detach())

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

# Clean the hooks
layer0_self_output_hook.remove()
layer0_output_hook.remove()
layer1_self_output_hook.remove()
layer1_output_hook.remove()

100%|██████████| 2105/2105 [01:08<00:00, 30.89it/s]


In [35]:
import os
import numpy as np

# to excludle padding tokens
all_attention_masks = torch.cat(attention_mask_list, dim=0)
all_attention_masks = all_attention_masks.view(-1)

# Compute the mean & inverse sqrt variance for each LayerNorm layer
for layer, input_list in layernorm_inputs.items():
    # concatenate all inputs for the current layer
    all_inputs = torch.cat(input_list, dim=0)

    # flatten the inputs to merge batch and sequence dimensions
    total_samples, seq_length, hidden_size = all_inputs.shape
    all_inputs = all_inputs.view(-1, hidden_size)

    # exclude padding tokens
    valid_indices = all_attention_masks.nonzero(as_tuple=False).squeeze()
    valid_inputs = all_inputs[valid_indices]

    # Compute mean and variance across all tokens and samples for each feature
    mean = valid_inputs.mean(dim=0).double()
    var = valid_inputs.var(dim=0, unbiased=False).double()

    # Compute the inverse square root of variance + epsilon
    epsilon = 1e-12
    inv_sqrt_var = 1.0 / torch.sqrt(var + epsilon)
    
    #print(layer)
    ln = None
    # Compute vy & normbias
    if (layer == "layer0_self_output"):
        ln = model.bert.encoder.layer[0].attention.output.LayerNorm
    elif (layer == "layer0_output"):
        ln = model.bert.encoder.layer[0].output.LayerNorm
    elif (layer == "layer1_self_output"):
        ln = model.bert.encoder.layer[1].attention.output.LayerNorm
    elif (layer == "layer1_output"):
        ln = model.bert.encoder.layer[1].output.LayerNorm
    
    gamma = ln.weight.clone().detach().double()
    beta = ln.bias.clone().detach().double()

    path = "./weightss"
    if not (os.path.exists(path)):
        os.makedirs(path)

    # Compute vy
    vy = gamma * inv_sqrt_var
    normbias = beta

    # Modify layer name for filenames
    layer_filename = layer.replace('self_output', 'selfoutput')

    # Save vy and normbias to text files
    np.savetxt(f"{path}/{layer_filename}_vy.txt", vy.numpy(), delimiter=',', fmt='%.18e')
    np.savetxt(f"{path}/{layer_filename}_normbias.txt", normbias.numpy(), delimiter=',', fmt='%.18e')

    # Save the means & inverse sqrt variance to text files (optional, for debugging)
    np.savetxt(f"{path}/{layer_filename}_mean.txt", mean[:hidden_size].numpy(), delimiter=',', fmt='%.18e')
    np.savetxt(f"{path}/{layer_filename}_inv_sqrt_var.txt", inv_sqrt_var[:hidden_size].numpy(), delimiter=',', fmt='%.18e')


print("completed.")

completed.


In [37]:
gamma = model.bert.encoder.layer[0].attention.output.LayerNorm.weight.clone().detach().double()
print(gamma.shape)

torch.Size([128])


In [36]:
import numpy as np

# Check sst2 vy
vy = np.loadtxt(f"./weightss/layer0_selfoutput_vy.txt", delimiter=',')

print(vy.shape)

# Count non-zero values
non_zero_count = np.count_nonzero(vy)
print(f"Number of non-zero values: {non_zero_count}")

print(non_zero_count/(128*128))

(128,)
Number of non-zero values: 128
0.0078125
