In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, logging
logging.set_verbosity_error()  # Silence transformers' warning-level log messages (only errors are shown)

# Base BERT-tiny tokenizer and architecture; the weights are replaced below by
# the SST-2 fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the checkpoint cannot execute arbitrary code; it also avoids torch's
# FutureWarning about the weights_only default.
trained = torch.load(
    './notebooks/SST-2-BERT-tiny.bin',
    map_location=torch.device('cpu'),
    weights_only=True,
)
# Drop the 'bert.embeddings.position_ids' entry if present: strict loading below
# would otherwise reject it as an unexpected key.
trained.pop('bert.embeddings.position_ids', None)
model.load_state_dict(trained, strict=True)

# Inference mode (e.g. disables dropout); the repr shows the module paths the
# LayerNorm hooks attach to later.
model.eval()

  from .autonotebook import tqdm as notebook_tqdm
  trained = torch.load('./notebooks/SST-2-BERT-tiny.bin', map_location=torch.device('cpu'))


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [2]:
# SST-2 training split, tokenized and batched for a single inference pass.
from datasets import load_dataset
from torch.utils.data import DataLoader

train_dataset = load_dataset("stanfordnlp/sst2", split="train")


def tokenize_function(example):
    """Tokenize a batch of SST-2 rows to exactly 128 tokens each.

    Longer sentences are truncated and shorter ones padded up to 128, so every
    batch produced below has the same (batch, 128) token shape.
    """
    return tokenizer(
        example['sentence'],
        max_length=128,
        padding='max_length',
        truncation=True,
    )


tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
# Expose only the model inputs (plus the label) as torch tensors.
tokenized_train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

# Batches of 32 in dataset order (shuffle=False is the DataLoader default).
train_dataloader = DataLoader(tokenized_train_dataset, batch_size=32, shuffle=False)


In [3]:
# Captured LayerNorm inputs: one list per hooked LayerNorm, and each forward
# pass through that LayerNorm appends one tensor to its list.
layernorm_inputs = {
    key: []
    for key in (
        'layer0_self_output',
        'layer0_output',
        'layer1_self_output',
        'layer1_output',
    )
}


def get_layernorm_input(layer):
    """Return a forward-pre hook that records a LayerNorm's first input.

    The hook appends a detached, CPU-resident version of the module's first
    positional input to ``layernorm_inputs[layer]`` (the dict entry is looked up
    each time the hook fires).
    """
    def hook(module, args):
        layernorm_inputs[layer].append(args[0].detach().cpu())
    return hook


encoder_layers = model.bert.encoder.layer


def _attach(name, layernorm):
    """Register the capture hook for ``name`` on ``layernorm`` and return its removable handle."""
    return layernorm.register_forward_pre_hook(get_layernorm_input(name))


# Each encoder layer has two hooked LayerNorms: attention.output.LayerNorm and
# output.LayerNorm. The handles are kept so the hooks can be removed later.
layer0_self_output_hook = _attach('layer0_self_output', encoder_layers[0].attention.output.LayerNorm)
layer0_output_hook = _attach('layer0_output', encoder_layers[0].output.LayerNorm)
layer1_self_output_hook = _attach('layer1_self_output', encoder_layers[1].attention.output.LayerNorm)
layer1_output_hook = _attach('layer1_output', encoder_layers[1].output.LayerNorm)

In [4]:
from tqdm import tqdm

# Run every training batch through the model once. The forward-pre hooks
# registered in the previous cell record each LayerNorm input as a side effect.
attention_mask_list = []  # one mask per batch, used later to exclude padding tokens
try:
    with torch.no_grad():
        for batch in tqdm(train_dataloader):
            input_ids = batch['input_ids'].to("cpu")
            attention_mask = batch['attention_mask'].to("cpu")

            attention_mask_list.append(attention_mask.detach())

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
finally:
    # Clean the hooks even if the loop raises or is interrupted. Left registered,
    # they would keep appending to layernorm_inputs on every later forward pass.
    # RemovableHandle.remove() is safe to call more than once.
    for handle in (
        layer0_self_output_hook,
        layer0_output_hook,
        layer1_self_output_hook,
        layer1_output_hook,
    ):
        handle.remove()

100%|██████████| 2105/2105 [00:26<00:00, 80.72it/s]


In [15]:
import os
import numpy as np

# Output directory for the per-LayerNorm parameter files.
path = "./train-sst2"
os.makedirs(path, exist_ok=True)

# vy is written as a [hidden, VY_TOTAL_COLUMNS] matrix: the first
# VY_ACTIVE_COLUMNS columns each hold a copy of vy, the rest are zeros.
# NOTE(review): 55 presumably is the number of token slots the HE circuit
# uses and 128 the padded sequence length — confirm against the circuit.
VY_ACTIVE_COLUMNS = 55
VY_TOTAL_COLUMNS = 128

# LayerNorm module whose input was captured under each key of layernorm_inputs.
layernorm_modules = {
    'layer0_self_output': model.bert.encoder.layer[0].attention.output.LayerNorm,
    'layer0_output': model.bert.encoder.layer[0].output.LayerNorm,
    'layer1_self_output': model.bert.encoder.layer[1].attention.output.LayerNorm,
    'layer1_output': model.bert.encoder.layer[1].output.LayerNorm,
}


def masked_mean_var(input_list, mask_list):
    """Per-feature mean and population variance over non-padding tokens.

    Works batch by batch in float64 with two passes (mean, then squared
    deviations). The captured activations are never concatenated into one
    large tensor, and no float32 rounding enters the sums.

    Args:
        input_list: per-batch activations, each shaped (batch, seq, hidden).
        mask_list: matching per-batch attention masks, each shaped (batch, seq);
            nonzero entries mark real (non-padding) tokens.

    Returns:
        (mean, var): float64 tensors of shape (hidden,). ``var`` is the biased
        (population) variance, like ``torch.var(unbiased=False)``.

    Raises:
        ValueError: if the lists do not line up batch-for-batch, or if there is
            no non-padding token at all.
    """
    if len(input_list) != len(mask_list):
        raise ValueError(
            f"{len(input_list)} captured batches but {len(mask_list)} attention masks"
        )
    for inputs, mask in zip(input_list, mask_list):
        if inputs.shape[:2] != mask.shape:
            raise ValueError(
                f"activation shape {tuple(inputs.shape)} does not match mask shape {tuple(mask.shape)}"
            )

    def valid_rows():
        # (n_valid, hidden) float64 rows per batch, padding positions removed.
        for inputs, mask in zip(input_list, mask_list):
            yield inputs[mask.bool()].double()

    n_tokens = 0
    feature_sum = 0.0
    for rows in valid_rows():
        n_tokens += rows.shape[0]
        feature_sum = feature_sum + rows.sum(dim=0)
    if n_tokens == 0:
        raise ValueError("no non-padding tokens to compute statistics from")
    mean = feature_sum / n_tokens

    sq_dev_sum = 0.0
    for rows in valid_rows():
        sq_dev_sum = sq_dev_sum + (rows - mean).pow(2).sum(dim=0)
    var = sq_dev_sum / n_tokens
    return mean, var


def expand_vy_columns(vy, active_cols=VY_ACTIVE_COLUMNS, total_cols=VY_TOTAL_COLUMNS):
    """Return a (len(vy), total_cols) matrix whose first ``active_cols`` columns equal vy.

    Row i holds vy[i] repeated ``active_cols`` times followed by zeros.

    Raises:
        ValueError: if ``active_cols`` is not within [0, total_cols].
    """
    if not 0 <= active_cols <= total_cols:
        raise ValueError(f"active_cols={active_cols} must be within [0, {total_cols}]")
    expanded = torch.zeros(vy.shape[0], total_cols, dtype=vy.dtype)
    expanded[:, :active_cols] = vy.unsqueeze(1)
    return expanded


# Compute the mean & inverse sqrt variance for each LayerNorm and export them
# together with the folded scale (vy) and bias (normbias).
for layer, input_list in layernorm_inputs.items():
    ln = layernorm_modules[layer]

    mean, var = masked_mean_var(input_list, attention_mask_list)

    # Use the LayerNorm's own epsilon (the value it adds to the variance).
    inv_sqrt_var = 1.0 / torch.sqrt(var + ln.eps)

    gamma = ln.weight.detach().double()
    beta = ln.bias.detach().double()

    # Folded scale: gamma / sqrt(var + eps).
    vy = gamma * inv_sqrt_var
    # NOTE(review): the fully folded bias would be beta - gamma * mean * inv_sqrt_var.
    # Plain beta is kept, presumably because the mean is subtracted separately
    # downstream (it is exported to *_mean.txt) — confirm with the HE circuit.
    normbias = beta

    vy_expanded = expand_vy_columns(vy)

    # File prefix: 'self_output' -> 'selfoutput'.
    prefix = os.path.join(path, layer.replace('self_output', 'selfoutput'))

    np.savetxt(f"{prefix}_mean.txt", mean.numpy())
    np.savetxt(f"{prefix}_inv_sqrt_var.txt", inv_sqrt_var.numpy())
    np.savetxt(f"{prefix}_vy.txt", vy_expanded.numpy(), delimiter=',')
    np.savetxt(f"{prefix}_normbias.txt", normbias.numpy())

print("completed.")

completed.
