In [1]:
import torch
from transformers import AutoTokenizer

# Load only the bert-tiny tokenizer from the Hugging Face Hub; the model itself is
# defined separately in the next cell (it is not loaded from a checkpoint here).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="prajjwal1/bert-tiny")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define a small BERT sequence classifier. It is constructed from a config (not loaded with
# from_pretrained), so its weights are randomly initialised.
from transformers import BertConfig, BertForSequenceClassification

# Seed torch's RNG so the random initial weights -- and every file exported from them
# later in this notebook -- are reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

config = BertConfig(
    hidden_size=128,
    num_hidden_layers=4,  # increased from 2 to 4
    num_attention_heads=2,
    intermediate_size=512,
    max_position_embeddings=512,
    num_labels=2,
)

model = BertForSequenceClassification(config)

In [3]:
# Show the module tree; nn.Module defines only __repr__, so this prints the same text as print(model).
print(repr(model))

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [4]:
# Parameter/buffer tensors keyed by name (detached from autograd; keep_vars=False is the default).
state_dict = model.state_dict(keep_vars=False)

In [5]:
#model.eval()

## Layers 0-3: file stems below are shown for layer0_; layers 1-3 use layer1_ ... layer3_
# layer0_attself_query_weight
# layer0_attself_query_bias
# layer0_attself_key_weight
# layer0_attself_key_bias
# layer0_attself_value_weight
# layer0_attself_value_bias

# layer0_selfoutput_weight
# layer0_selfoutput_bias
# layer0_selfoutput_mean
# layer0_selfoutput_normbias
# layer0_selfoutput_vy

# layer0_intermediate_weight1
# layer0_intermediate_weight2
# layer0_intermediate_weight3
# layer0_intermediate_weight4
# layer0_intermediate_bias

# layer0_output_weight1
# layer0_output_weight2
# layer0_output_weight3
# layer0_output_weight4
# layer0_output_bias
# layer0_output_mean
# layer0_output_normbias
# layer0_output_vy

## Pooler
# pooler_dense_weight
# pooler_dense_bias

## Classifier
# classifier_weight
# classifier_bias

In [13]:
# Training split of SST-2; only its 'sentence' column is tokenized and fed to the model below.
from datasets import load_dataset

dataset = load_dataset(path="stanfordnlp/sst2", split="train")

In [14]:
# Tokenize the dataset
def tokenize_function(example):
    """Tokenize a batch of rows (``batched=True`` passes lists), padding/truncating each sentence to 128 tokens."""
    sentences = example['sentence']
    return tokenizer(sentences, padding='max_length', truncation=True, max_length=128)


tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Expose only the model inputs, as torch tensors.
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

Map: 100%|█| 67349/67349 [00:03<00:00, 16878.79


In [15]:
from torch.utils.data import DataLoader

# Batches of 32 tokenized sentences in dataset order (shuffle=False is the DataLoader default).
dataloader = DataLoader(tokenized_dataset, batch_size=32, shuffle=False)

In [19]:
# Extract the weights and biases

In [17]:
import os
import numpy as np

def save_parameter(tensor, filename):
    """Write ``tensor`` to the text file ``filename`` with ``np.savetxt``.

    The tensor is detached from autograd and moved to CPU before conversion, so
    parameters with ``requires_grad=True`` can be saved directly. A 1-D tensor is
    written one value per line; a 2-D tensor one row per line.
    """
    np.savetxt(filename, tensor.detach().to('cpu').numpy())

In [18]:
# Create a directory to save the weights
os.makedirs('weights', exist_ok=True)

# The FFN matrices are exported as square BLOCK_SIZE x BLOCK_SIZE tiles
# (BLOCK_SIZE equals config.hidden_size for this model).
BLOCK_SIZE = 128

# state_dict key (relative to one encoder layer) -> output file stem, for tensors saved unsplit.
UNSPLIT_PARAMETERS = {
    'attention.self.query.weight': 'attself_query_weight',
    'attention.self.query.bias': 'attself_query_bias',
    'attention.self.key.weight': 'attself_key_weight',
    'attention.self.key.bias': 'attself_key_bias',
    'attention.self.value.weight': 'attself_value_weight',
    'attention.self.value.bias': 'attself_value_bias',
    'attention.output.dense.weight': 'selfoutput_weight',
    'attention.output.dense.bias': 'selfoutput_bias',
    'intermediate.dense.bias': 'intermediate_bias',
    'output.dense.bias': 'output_bias',
}


def save_square_blocks(weight, dim, path_stem):
    """Split ``weight`` into BLOCK_SIZE-wide chunks along ``dim`` and save each one.

    Chunk ``i`` (1-based) is written to ``f'{path_stem}{i}.txt'``.

    Raises:
        AssertionError: if any chunk is not BLOCK_SIZE x BLOCK_SIZE, i.e. ``dim`` is
            not the long (multiple-of-BLOCK_SIZE) axis of ``weight``.
    """
    blocks = torch.split(weight, BLOCK_SIZE, dim=dim)
    assert all(tuple(block.shape) == (BLOCK_SIZE, BLOCK_SIZE) for block in blocks), (
        f'{path_stem}: expected {BLOCK_SIZE}x{BLOCK_SIZE} blocks, got '
        f'{[tuple(block.shape) for block in blocks]}'
    )
    for i, block in enumerate(blocks, start=1):
        save_parameter(block, f'{path_stem}{i}.txt')


# For each encoder layer
for layer_num in range(config.num_hidden_layers):
    prefix = f'bert.encoder.layer.{layer_num}.'
    layer_prefix = f'layer{layer_num}_'

    # Self-attention, self-output and FFN biases are saved as-is.
    for key_suffix, file_stem in UNSPLIT_PARAMETERS.items():
        save_parameter(state_dict[prefix + key_suffix], f'weights/{layer_prefix}{file_stem}.txt')

    # nn.Linear stores its weight as (out_features, in_features).
    # intermediate.dense maps 128 -> 512, so its weight is (512, 128): split the rows (dim=0)
    # into four 128 x 128 blocks, one per 128-wide slice of the intermediate output.
    save_square_blocks(state_dict[prefix + 'intermediate.dense.weight'], dim=0,
                       path_stem=f'weights/{layer_prefix}intermediate_weight')

    # output.dense maps 512 -> 128, so its weight is (128, 512): split the columns (dim=1)
    # into four 128 x 128 blocks, one per 128-wide slice of the intermediate activation.
    save_square_blocks(state_dict[prefix + 'output.dense.weight'], dim=1,
                       path_stem=f'weights/{layer_prefix}output_weight')

In [None]:
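# Precompute LayerNorm input statistics over non-padding tokens: mean & inv_sqrt_var (vy & normbias from the parameter-file plan are not produced by the cells that follow)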

In [21]:
# One capture list per LayerNorm, filled by the forward pre-hooks. For every encoder layer
# there are two keys: 'layerN_self_output' (attention.output.LayerNorm) and 'layerN_output'
# (output.LayerNorm), in that order.
layernorm_inputs = {
    f'layer{layer_idx}_{site}': []
    for layer_idx in range(4)
    for site in ('self_output', 'output')
}

In [22]:
# Hook factory: builds a forward pre-hook that records what one LayerNorm receives.
def get_layernorm_input(layer):
    """Return a forward pre-hook that appends the LayerNorm's input to ``layernorm_inputs[layer]``.

    The hook is called with the module and the tuple of positional inputs; the first
    positional input is detached from autograd and stored. ``layernorm_inputs[layer]`` is
    looked up each time the hook fires. The hook returns None, so the LayerNorm's input is
    left unchanged. Every call keeps a full activation tensor in memory.
    """
    def _record_input(_module, args):
        first_positional = args[0]
        layernorm_inputs[layer].append(first_positional.detach())

    return _record_input
# Detach hooks left over from an earlier run of this cell. Otherwise re-running it stacks a
# second pre-hook on every LayerNorm, and each activation would be captured twice.
for _stale_handle in globals().get('layernorm_hook_handles', {}).values():
    _stale_handle.remove()  # RemovableHandle.remove() is safe to call on an already-removed hook

# One pre-hook per LayerNorm: attention.output.LayerNorm (key 'layerN_self_output') and
# output.LayerNorm (key 'layerN_output') of every encoder layer, registered in that order.
layernorm_hook_handles = {}
for layer_num, encoder_layer in enumerate(model.bert.encoder.layer):
    for site, layer_norm in (
        ('self_output', encoder_layer.attention.output.LayerNorm),
        ('output', encoder_layer.output.LayerNorm),
    ):
        hook_name = f'layer{layer_num}_{site}'
        # Fail here, not inside the forward pass, if the capture dict lacks this key.
        if hook_name not in layernorm_inputs:
            raise KeyError(
                f'layernorm_inputs has no entry for {hook_name!r}; re-run the cell that creates it.'
            )
        layernorm_hook_handles[hook_name] = layer_norm.register_forward_pre_hook(
            get_layernorm_input(hook_name)
        )

# Individual handle names, as used by the cleanup code after the capture loop.
(layer0_self_output_hook, layer0_output_hook,
 layer1_self_output_hook, layer1_output_hook,
 layer2_self_output_hook, layer2_output_hook,
 layer3_self_output_hook, layer3_output_hook) = layernorm_hook_handles.values()

In [25]:
from tqdm.auto import tqdm

model.eval()
# Attention masks of every batch, kept so padding tokens can be excluded from the statistics.
attention_mask_list = []
try:
    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']

            # The forward pass is run for its side effect: the LayerNorm pre-hooks record their inputs.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Appended only after a successful forward pass so masks stay aligned with captured inputs.
            attention_mask_list.append(attention_mask.detach())
finally:
    # Clean the hooks even if the loop raised (e.g. a KeyError inside a hook). Otherwise they
    # stay registered, and re-running the hook cell would stack a second copy on every LayerNorm.
    for hook_handle in (
        layer0_self_output_hook, layer0_output_hook,
        layer1_self_output_hook, layer1_output_hook,
        layer2_self_output_hook, layer2_output_hook,
        layer3_self_output_hook, layer3_output_hook,
    ):
        hook_handle.remove()

  0%|                 | 0/2105 [00:00<?, ?it/s]


KeyError: 'layer0_ln1'

In [None]:
# Flatten the per-batch attention masks into one flag per token (batch and sequence merged)
# so padding tokens can be excluded from the statistics.
all_attention_masks = torch.cat(attention_mask_list, dim=0)
all_attention_masks = all_attention_masks.view(-1)
# Boolean token mask; identical for every LayerNorm, so it is built once outside the loop.
valid_token_mask = all_attention_masks.bool()

# Same epsilon the model's LayerNorms use (config.layer_norm_eps, 1e-12 for this config).
epsilon = model.config.layer_norm_eps

# Statistics are saved alongside the exported weight files in weights/.
path = "weights"
os.makedirs(path, exist_ok=True)

# Compute the mean & inverse sqrt variance for each LayerNorm layer
for layer, input_list in layernorm_inputs.items():
    if not input_list:
        raise ValueError(
            f"No LayerNorm inputs were captured for {layer!r}; run the capture cell first."
        )
    # concatenate all inputs for the current layer
    all_inputs = torch.cat(input_list, dim=0)

    # flatten the inputs to merge batch and sequence dimensions
    hidden_size = all_inputs.shape[-1]
    all_inputs = all_inputs.reshape(-1, hidden_size)

    # Captured tokens and mask flags must line up one-to-one. A mismatch means the hooks
    # were registered more than once or the capture run was interrupted.
    if all_inputs.shape[0] != valid_token_mask.shape[0]:
        raise ValueError(
            f"{layer}: captured {all_inputs.shape[0]} tokens but the attention masks cover "
            f"{valid_token_mask.shape[0]}; re-register the hooks and re-run the capture cell."
        )

    # exclude padding tokens
    valid_inputs = all_inputs[valid_token_mask]

    # Per-feature mean and (population) variance across all non-padding tokens
    mean = valid_inputs.mean(dim=0)
    var = valid_inputs.var(dim=0, unbiased=False)

    # Compute the inverse square root of variance + epsilon
    inv_sqrt_var = 1.0 / torch.sqrt(var + epsilon)

    # Save the means & inverse sqrt variance to text files
    np.savetxt(os.path.join(path, f"{layer}_mean.txt"), mean.numpy())
    np.savetxt(os.path.join(path, f"{layer}_inv_sqrt_var.txt"), inv_sqrt_var.numpy())

print("completed.")