In [1]:
import torch
from pathlib import Path
from torchvision import datasets

from einops import rearrange

import pickle

from tqdm import tqdm
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from einops import rearrange
import os
import numpy as np

from importlib import reload
from tqdm import tqdm

In [2]:
import sys

PATH = sys.path
newPATH = ['/rcfs/projects/task0_pmml/TRAKfork/trak',] + PATH
sys.path = newPATH

In [3]:
#We are using 1 model because we have only 1 model?
ckpts = [torch.load('/rcfs/projects/task0_pmml/BERT/model_frozen.pt'),
        torch.load('/rcfs/projects/task0_pmml/BERT/one_gpu_development/MANY_BERT_MODELS/BERT-base_SEED1.pt'),
        torch.load('/rcfs/projects/task0_pmml/BERT/one_gpu_development/MANY_BERT_MODELS/BERT-base_SEED2.pt'),
        torch.load('/rcfs/projects/task0_pmml/BERT/one_gpu_development/MANY_BERT_MODELS/BERT-base_SEED3.pt')]

for ckpt in ckpts:
    ckpt.pop('bert.embeddings.position_ids')
    
#I think in earlier versions this weight exists; in the current version it is unused. It
#seems to just be arange...

# Define models and dataset setups

In [4]:
from transformers import BertForSequenceClassification, AdamW, BertConfig


model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels = 2,   
    output_attentions = False,
    output_hidden_states = False,
)

model.to('cuda').eval()

2023-09-19 09:06:44.341837: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [5]:
import pandas as pd
# because the dataset is int tsv format we have to use delimeter.
df = pd.read_csv("../cola_public/raw/in_domain_train.tsv", delimiter='\t', header=None, names=['sentence_sources', 'label', 'label_note', 'sentence'])

# creating a copy so we don't messed up our original dataset.
data=df.copy()

data.drop(['sentence_sources','label_note'],axis=1,inplace=True)
sentences=data.sentence.values
labels = data.label.values

from transformers import BertTokenizer
# using the low level BERT for our task.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

input_ids = []
for sent in sentences:
    # so basically encode tokenizing , mapping sentences to thier token ids after adding special tokens.
    encoded_sent = tokenizer.encode(
                        sent,                      # Sentence which are encoding.
                        add_special_tokens = True, # Adding special tokens '[CLS]' and '[SEP]'

                         )
    
 
    input_ids.append(encoded_sent)
    
from tensorflow.keras.preprocessing.sequence import pad_sequences

MAX_LEN = 128

input_ids = pad_sequences(input_ids, maxlen=MAX_LEN , truncating="post", padding="post")

attention_masks = []

for sent in input_ids:
    
    # Generating attention mask for sentences.
    #   - when there is 0 present as token id we are going to set mask as 0.
    #   - we are going to set mask 1 for all non-zero positive input id.
    att_mask = [int(token_id > 0) for token_id in sent]
    
   
    attention_masks.append(att_mask)
    
from sklearn.model_selection import train_test_split

train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, labels, test_size=0.2, random_state=0)
train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels,test_size=0.2, random_state=0)

#changing the numpy arrays into tensors for working on GPU. 
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)

train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)

train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Deciding the batch size for training.

batch_size = 32

#DataLoader for our training set.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = SequentialSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size, shuffle=False)

N_test_DATAPOINTS = 8192 - len(train_inputs)

# DataLoader for our validation(test) set.
validation_data = TensorDataset(validation_inputs[0:N_test_DATAPOINTS], validation_masks[0:N_test_DATAPOINTS], validation_labels[0:N_test_DATAPOINTS])
validation_labels = validation_labels[0:N_test_DATAPOINTS]
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size, shuffle=False)

In [6]:
combined_data = TensorDataset(torch.cat([train_inputs,validation_inputs[0:N_test_DATAPOINTS]]),
                             torch.cat([train_masks,validation_masks[0:N_test_DATAPOINTS]]),
                             torch.cat([train_labels,validation_labels[0:N_test_DATAPOINTS]])
                            )

combined_sampler = SequentialSampler(combined_data)
combined_loader = DataLoader(combined_data, sampler=combined_sampler, batch_size=8, shuffle=False)

In [7]:
model.load_state_dict(ckpts[0])

<All keys matched successfully>

In [8]:
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [9]:
for batch in combined_loader:
    batch = [x.cuda() for x in batch]
    input_shape = batch[0].shape
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device='cuda')
    batch = [batch[0], token_type_ids, batch[1], batch[2]]
    batch = {'input_ids':batch[0],
            'attention_mask':batch[2],
            'token_type_ids':batch[1],}
    output = model(**batch)
    break

In [10]:
output

SequenceClassifierOutput(loss=None, logits=tensor([[-3.9025,  4.0273],
        [ 3.2860, -3.0037],
        [ 3.4921, -3.5253],
        [-3.8537,  3.6846],
        [-4.2034,  3.9416],
        [-3.3182,  2.9934],
        [-4.1019,  3.8999],
        [ 3.7243, -3.7942]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

# Start TRAK for trNTK computation.

In [11]:
import trak
TRAKer = trak.TRAKer

In [12]:
traker = TRAKer(model=model,
                task='text_pNTK',
                save_dir = '/rcfs/projects/task0_pmml/proj_trNTK/BERT/BERT_pNTK_full/',
                train_set_size=len(train_inputs)+len(validation_inputs[0:N_test_DATAPOINTS]),
                num_classes=None, #probably fails for number of classes = 1; should be number of neurons!
                proj_dim=0, #note, needed to use less memory. 
                use_half_precision=False,
                proj_max_batch_size=8,
                projector=trak.projectors.NoOpProjector())

                             Report any issues at https://github.com/MadryLab/trak/issues
INFO:STORE:Existing model IDs in /rcfs/projects/task0_pmml/proj_trNTK/BERT/BERT_pNTK_full: [0, 1, 2]
INFO:STORE:No model IDs in /rcfs/projects/task0_pmml/proj_trNTK/BERT/BERT_pNTK_full have been finalized.
INFO:STORE:No existing TRAK scores in /rcfs/projects/task0_pmml/proj_trNTK/BERT/BERT_pNTK_full.


# Seem to have to use BasicProjector because the kernel is not built to handle the integer valued inputs to text models> should raise a bug on TRAK?

In [13]:
#traker.projector

In [14]:
#batch

In [15]:
#func_weights = dict(model.named_parameters())
#func_buffers = dict(model.named_buffers())

#torch.func.functional_call(model,(func_weights, func_buffers),args=None,kwargs=batch)

In [16]:
for model_id, ckpt in enumerate(ckpts):
    # TRAKer loads the provided checkpoint and also associates
    # the provided (unique) model_id with the checkpoint.
    traker.load_checkpoint(ckpt, model_id=model_id)

    for batch in tqdm(combined_loader):
        batch = [x.cuda() for x in batch]
        input_shape = batch[0].shape
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device='cuda')
        batch = [batch[0], token_type_ids, batch[1], batch[2]]
        # TRAKer computes features corresponding to the batch of examples,
        # using the checkpoint loaded above.
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])

# Tells TRAKer that we've given it all the information, at which point
# TRAKer does some post-processing to get ready for the next step
# (scoring target examples).
#traker.finalize_features()

100%|█████████████████████████████████████████████████████████████████████████████| 1024/1024 [00:00<00:00, 4268.07it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 1024/1024 [00:00<00:00, 4282.28it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1024/1024 [33:18<00:00,  1.95s/it]
100%|█████████████████████████████████████████████████████████████████████████████| 1024/1024 [3:33:31<00:00, 12.51s/it]


### full pNTK calculation = approx 3:21:39

# Embedding Bert-- only run if you didnt run Trak, Trak eats too much memory

In [5]:
batch_size=32

all_inputs = torch.cat([train_inputs,validation_inputs[0:N_test_DATAPOINTS]]).cuda()
all_masks = torch.cat([train_masks,validation_masks[0:N_test_DATAPOINTS]]).cuda()
all_labels = torch.cat([train_labels,validation_labels[0:N_test_DATAPOINTS]]).cuda()

all_data = TensorDataset(all_inputs, all_masks, all_labels)
all_data_sampler = SequentialSampler(all_data)
all_dataloader = DataLoader(all_data, sampler=all_data_sampler, batch_size=batch_size, shuffle=False)

In [6]:
ALL_NAMES = []
for name,module in model.named_modules():
    print(name)
    if 'activation' in name:
        continue
    if 'dropout' in name:
        continue
    if 'bert'==name:
        continue
    if 'bert.encoder'==name:
        continue
    if 'bert.encoder.layer'==name:
        continue
    if 'relu' in name:
        continue
    if '' == name:
        continue
    #if len(name.split('.')) < 5 and 'bert.encoder.layer' in name:
    #    continue
        
    ALL_NAMES.append(name)
    


bert
bert.embeddings
bert.embeddings.word_embeddings
bert.embeddings.position_embeddings
bert.embeddings.token_type_embeddings
bert.embeddings.LayerNorm
bert.embeddings.dropout
bert.encoder
bert.encoder.layer
bert.encoder.layer.0
bert.encoder.layer.0.attention
bert.encoder.layer.0.attention.self
bert.encoder.layer.0.attention.self.query
bert.encoder.layer.0.attention.self.key
bert.encoder.layer.0.attention.self.value
bert.encoder.layer.0.attention.self.dropout
bert.encoder.layer.0.attention.output
bert.encoder.layer.0.attention.output.dense
bert.encoder.layer.0.attention.output.LayerNorm
bert.encoder.layer.0.attention.output.dropout
bert.encoder.layer.0.intermediate
bert.encoder.layer.0.intermediate.dense
bert.encoder.layer.0.intermediate.intermediate_act_fn
bert.encoder.layer.0.output
bert.encoder.layer.0.output.dense
bert.encoder.layer.0.output.LayerNorm
bert.encoder.layer.0.output.dropout
bert.encoder.layer.1
bert.encoder.layer.1.attention
bert.encoder.layer.1.attention.self
bert.e

In [7]:
for name in ALL_NAMES:
    print(name)

bert.embeddings
bert.embeddings.word_embeddings
bert.embeddings.position_embeddings
bert.embeddings.token_type_embeddings
bert.embeddings.LayerNorm
bert.encoder.layer.0
bert.encoder.layer.0.attention
bert.encoder.layer.0.attention.self
bert.encoder.layer.0.attention.self.query
bert.encoder.layer.0.attention.self.key
bert.encoder.layer.0.attention.self.value
bert.encoder.layer.0.attention.output
bert.encoder.layer.0.attention.output.dense
bert.encoder.layer.0.attention.output.LayerNorm
bert.encoder.layer.0.intermediate
bert.encoder.layer.0.intermediate.dense
bert.encoder.layer.0.intermediate.intermediate_act_fn
bert.encoder.layer.0.output
bert.encoder.layer.0.output.dense
bert.encoder.layer.0.output.LayerNorm
bert.encoder.layer.1
bert.encoder.layer.1.attention
bert.encoder.layer.1.attention.self
bert.encoder.layer.1.attention.self.query
bert.encoder.layer.1.attention.self.key
bert.encoder.layer.1.attention.self.value
bert.encoder.layer.1.attention.output
bert.encoder.layer.1.attention.o

In [23]:
model.hooks = {}

In [66]:
for key in list(model.hooks.keys()):
    model.hooks[key].remove()

In [67]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        if type(output) is tuple:
            if len(output)==1:
                output = output[0]
            else:
                print(output[0].shape)
                print(output[1].shape)
                output = output[0]
        activation[name] = output.detach()
    return hook

In [70]:
all_inputs

tensor([[  101,  5926,  2033,  ...,     0,     0,     0],
        [  101,  1996,  2026,  ...,     0,     0,     0],
        [  101,  2057, 10116,  ...,     0,     0,     0],
        ...,
        [  101,  2065,  2198,  ...,     0,     0,     0],
        [  101,  1996,  2837,  ...,     0,     0,     0],
        [  101,  5863,  3369,  ...,     0,     0,     0]], device='cuda:0',
       dtype=torch.int32)

In [72]:
for modelnum in range(4):
    outer_Em_Kernel = 0
    print('starting: ',modelnum)
    model.load_state_dict(ckpts[modelnum])
    model.eval()
    for k,NAME in tqdm(enumerate(ALL_NAMES)):
        if os.path.exists(f'/rcfs/projects/task0_pmml/BERT/Em_kernel_components/{modelnum}/{NAME}-{k}.pt'):
            continue
        EM_Component = torch.zeros((len(all_masks),len(all_masks)),device='cpu')
        
        for name, module in model.named_modules():
            if name == NAME:
                model.hooks[NAME] = module.register_forward_hook(get_activation(NAME))
        
        with torch.no_grad():
            for i in range(8):
                activation = {}
                outputs = model(all_inputs[i*1024:(i+1)*1024], 
                                token_type_ids=None, 
                                attention_mask=all_masks[i*1024:(i+1)*1024], 
                                labels=all_labels[i*1024:(i+1)*1024],
                                output_hidden_states=False)
                X1_activation = activation[NAME].reshape(1024,-1)
                for j in range(8):
                    activation = {}
                    outputs = model(all_inputs[j*1024:(j+1)*1024], 
                                    token_type_ids=None, 
                                    attention_mask=all_masks[j*1024:(j+1)*1024], 
                                    labels=all_labels[j*1024:(j+1)*1024],
                                    output_hidden_states=False)
                    X2_activation = activation[NAME].reshape(1024,-1)



                    component = torch.matmul(X1_activation,X2_activation.T).cpu()
                    EM_Component[i*1024:(i+1)*1024,j*1024:(j+1)*1024] = component
            outer_Em_Kernel+= EM_Component
            torch.save(EM_Component,f'/rcfs/projects/task0_pmml/BERT/Em_kernel_components/{modelnum}/{NAME}-{k}.pt')
            model.hooks[NAME].remove()
    #torch.save(outer_Em_Kernel,f'/rcfs/projects/task0_pmml/BERT/Em_kernels/seed{modelnum}.pt')


starting:  0


19it [02:38,  8.35s/it]


KeyboardInterrupt: 