In [6]:
import os
import pickle

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [3]:
# The tokenizer and the classifier are loaded from the same checkpoint so the
# vocabulary matches the embedding table.
# NOTE(review): the load warning printed below this cell says some
# BertForSequenceClassification weights were not initialized from the
# checkpoint -- presumably the classification head, so the logits are
# untrained; confirm before interpreting them.
CHECKPOINT_NAME = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_NAME)
model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT_NAME)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [None]:
class LatentBert(torch.nn.Module):
    """Wrap a BERT sequence classifier and expose its per-layer first-token states.

    The wrapped ``base_model`` must provide ``bert.embeddings``,
    ``bert.encoder.layer`` (an iterable of layers whose call returns a tuple
    with the hidden states at index 0), ``bert.pooler`` and ``classifier``.
    """

    def __init__(self, base_model):
        """Store ``base_model`` and record how many encoder layers it has."""
        super().__init__()
        self.base_model = base_model
        # Was misspelled ``n_lauyers`` while ``forward`` read ``n_layers``,
        # which raised AttributeError on the first forward pass.
        self.n_layers = len(self.base_model.bert.encoder.layer)

    def to(self, device):
        """Move the wrapped model to ``device`` and return ``self`` for chaining."""
        self.base_model.to(device)
        return self

    def forward(self, x, aggregate=True, attention_mask=None):
        """Run the encoder layer by layer, collecting the first-token hidden state.

        Args:
            x: Token ids, indexed as (batch, sequence).
            aggregate: If True, average the collected states over the layer
                axis; otherwise return one state per layer.
            attention_mask: Optional (batch, sequence) mask with 1 for real
                tokens and 0 for padding. When omitted the layers are called
                without a mask (the original behaviour), so padded positions
                are attended to.

        Returns:
            dict with
              * ``"embeddings"``: (batch, hidden) if ``aggregate`` else
                (batch, hidden, n_layers),
              * ``"logits"``: classifier output computed from the pooler output,
              * ``"attention"``: the pooler output itself (the key name is
                historical; these are not attention weights).
        """
        x = self.base_model.bert.embeddings(x)

        layer_kwargs = {}
        if attention_mask is not None:
            # Additive mask broadcast over heads and query positions: 0 keeps a
            # key position, a very large negative value removes it.
            # NOTE(review): mirrors the extended-mask convention of Hugging Face
            # BERT layers -- confirm against the installed transformers version.
            keep = attention_mask[:, None, None, :].to(device=x.device, dtype=x.dtype)
            layer_kwargs["attention_mask"] = (1.0 - keep) * torch.finfo(x.dtype).min

        # Allocate on the same device/dtype as the hidden states so the
        # per-layer writes below never bounce through host memory.
        latent_reps = x.new_zeros((x.shape[0], x.shape[2], self.n_layers))

        for i, layer in enumerate(self.base_model.bert.encoder.layer):
            x = layer(x, **layer_kwargs)[0]
            # Take the hidden state of the CLS token for sentence-level classification
            latent_reps[:, :, i] = x[:, 0, :]

        pooled = self.base_model.bert.pooler(x)
        logits = self.base_model.classifier(pooled)

        embeddings = torch.mean(latent_reps, dim=-1) if aggregate else latent_reps
        return {"embeddings": embeddings, "logits": logits, "attention": pooled}

In [8]:
# Short dataset key -> identifier handed to `load_dataset`.
dataset_names = {
    'imdb': 'imdb',
    'sst2': 'sst2',
    'trec': 'trec',
    '20ng': 'SetFit/20_newsgroups',
}


def _field_collator(field):
    """Build a collate function that pulls `field` out of every example."""
    def collate(batch):
        return [example[field] for example in batch]
    return collate


def _collate_20ng(batch):
    """Extract 'text' from each example, turning newlines into spaces and dropping backslashes."""
    cleaned = []
    for example in batch:
        single_line = example['text'].replace('\n', ' ')
        cleaned.append(single_line.replace("\\", ''))
    return cleaned


# Short dataset key -> collate function turning a batch of examples into a list of strings.
collate_fns = {
    'imdb': _field_collator('text'),
    'sst2': _field_collator('sentence'),
    'trec': _field_collator('text'),
    '20ng': _collate_20ng,
}

In [None]:
# Run the test split of every configured dataset through LatentBert and pickle,
# per dataset, the list of per-batch outputs (embeddings, logits, pooler output).
OUTPUT_DIR = './pickle_files'
BATCH_SIZE = 16

# The original sent inputs to a hard-coded 'cuda' while the model stayed on the
# CPU; pick one device and put both the model and the inputs on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

latent_bert = LatentBert(model).to(device)
latent_bert.eval()  # inference only: make sure dropout is disabled

os.makedirs(OUTPUT_DIR, exist_ok=True)

# `.items()` is required: iterating the dict directly yields only its keys.
for ds_name, ds in dataset_names.items():
    dataset = load_dataset(ds)
    data_loader = DataLoader(
        dataset=dataset['test'],
        batch_size=BATCH_SIZE,
        collate_fn=collate_fns[ds_name],
    )
    embeddings = []
    logits = []
    attentions = []

    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        for batch in tqdm(data_loader, desc=ds_name):
            # NOTE(review): only input_ids are forwarded, so the padding added
            # by `padding=True` is not masked inside the encoder layers.
            encoded = tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
            batch_encoded_input = encoded['input_ids'].to(device)
            outputs = latent_bert(batch_encoded_input, aggregate=False)
            embeddings.append(outputs['embeddings'].cpu())
            logits.append(outputs['logits'].cpu())
            attentions.append(outputs['attention'].cpu())

    # Write once per dataset. The original re-pickled the whole growing lists
    # after every batch, which is quadratic in the number of batches.
    results = (
        ('embeddings', embeddings),
        ('logits', logits),
        ('attentions', attentions),
    )
    for kind, payload in results:
        out_path = os.path.join(OUTPUT_DIR, '{}_ood_test_{}.pkl'.format(kind, ds_name))
        with open(out_path, 'wb') as f:
            pickle.dump(payload, f)
        
        
        
    
    