In [36]:
from transformers import RobertaModel, RobertaTokenizer, RobertaConfig
from torch.nn import Module
from torch.utils.data import DataLoader
import datasets
from icecream import ic
from tqdm import tqdm
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier

In [28]:
class RobertaWrapper(Module):

    """
    Wrapper on roberta that gives mean-pooled representations of each layer in a list
    """

    def __init__(self, device="cpu"):
        super().__init__()

        self.device = device
        self.model_obj = RobertaModel.from_pretrained(
            "roberta-base").eval()
        self.model_obj.eval()
        self.tokenizer_obj = RobertaTokenizer.from_pretrained("roberta-base")
        self.config_obj = RobertaConfig.from_pretrained("roberta-base")

    def forward(self, input_text):

        encoder_ret = self.tokenizer_obj(
            input_text, truncation=True, return_tensors="pt", padding=True)

        encoder_text_ids = encoder_ret.input_ids.to(self.device)
        attention_mask = encoder_ret.attention_mask.to(self.device) # 1 for not pad

        ic(encoder_text_ids.device)
        encoder_states = self.model_obj(
            encoder_text_ids, output_hidden_states=True, attention_mask=attention_mask)

        ic(self.model_obj.device)
        hs_tuple = encoder_states["hidden_states"]

        mean_pooled_all_layers = []

        for layer, hs in enumerate(hs_tuple):
            ic(hs_tuple[layer].size())
            # hs = hs_tuple[layer] # (batch_size x sequence_length x dimension)
            hs_masked = hs * attention_mask[:, :, None] # ideally zeros out the pad associated representations
            ic(hs_masked.size())
            seq_lengths = attention_mask.sum(dim=1) # each line here represents sequence length

            hs_masked_sum = hs_masked.sum(dim=1)
            hs_avg = hs_masked_sum / seq_lengths[:, None]
            mean_pooled_all_layers.append(hs_avg)

        return mean_pooled_all_layers



In [55]:
# test
model_wrapped = RobertaWrapper()
test_dataset_xor = datasets.load_dataset("data_scripts/data_xor.py", add_sep=False)["train"]
test_dataloader = DataLoader(test_dataset_xor, batch_size=3)

print(next(iter(test_dataloader)))
output = model_wrapped(next(iter(test_dataloader))["content"])
print(len(output))
print(output[0].size())


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading and preparing dataset data_xor/default to /home/mrcreator/research/main_thread/fresh_repo/{'cache_dir': None, 'config_name': None, 'data_dir': None, 'data_files': None, 'hash': '70230897f97c0425f1c23dd623529f9f6e057014c5f1753a92cf17966f2c89f0', 'features': None, 'use_auth_token': None, 'base_path': 'data_scripts', 'add_sep': False}/data_xor/default/0.0.0...


                                                                        

Dataset data_xor downloaded and prepared to /home/mrcreator/research/main_thread/fresh_repo/{'cache_dir': None, 'config_name': None, 'data_dir': None, 'data_files': None, 'hash': '70230897f97c0425f1c23dd623529f9f6e057014c5f1753a92cf17966f2c89f0', 'features': None, 'use_auth_token': None, 'base_path': 'data_scripts', 'add_sep': False}/data_xor/default/0.0.0. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 203.91it/s]


{'id_': tensor([1, 2, 3]), 'content': ['Library of Alexandria is located in Alexandria. Library of Alexandria is located in Alexandria.', 'Library of Alexandria is located in Alexandria. Library of Alexandria is not located in Alexandria.', 'Library of Alexandria is not located in Alexandria. Library of Alexandria is located in Alexandria.'], 'label': tensor([0, 1, 1])}
13
torch.Size([3, 768])


In [53]:
def get_hidden_states_many_examples(model, data, n=100, layer=-1):
    """
    Takes a bunch of sequences and runs them through RoBERTa to generate the mean-pooled hidden states.

    This is unbatched and kept inefficient for simplicity
    """
    # setup
    model.eval()
    all_hidden_states, all_labels = [], []
    # all_hidden_states: will have elements for each RoBERTa layer, each element represents the mean-pooled representations for the whole data at that layer


    # loop
    for idx in tqdm(range(n)):

        text, true_label = data[idx]["content"], data[idx]["label"]
        print(data[idx]["content"])
        print(data[idx]["label"])

        # get hidden states
        with torch.no_grad():
            outs = model(text)
        # outs: [hidden states]

        # initialize if empty
        if len(all_hidden_states) == 0:
            for i in range(len(outs)):
                all_hidden_states.append([])


        # collect
        for i, hidden_state in enumerate(outs):
            all_hidden_states[i].append(hidden_state)

        all_labels.append(true_label)

    ic(len(all_hidden_states))
    ic(len(all_hidden_states[0]))
    ic(all_hidden_states[0][0].size())
    ic(torch.cat(all_hidden_states[0], dim=0).size())

    all_hidden_states = [torch.cat(all_hidden_states[i], dim=0) for i in range(len(all_hidden_states))]


    return all_hidden_states, all_labels

In [58]:
ic.disable()
outs = get_hidden_states_many_examples(model_wrapped, test_dataset_xor, n=4)
print(len(outs[0]))
print(outs[0][0].size())

 50%|█████     | 2/4 [00:00<00:00, 19.53it/s]

Library of Alexandria is located in Alexandria. Library of Alexandria is located in Alexandria.
0
Library of Alexandria is located in Alexandria. Library of Alexandria is not located in Alexandria.
1
Library of Alexandria is not located in Alexandria. Library of Alexandria is located in Alexandria.
1
Library of Alexandria is not located in Alexandria. Library of Alexandria is not located in Alexandria.
0


100%|██████████| 4/4 [00:00<00:00, 17.41it/s]

13
torch.Size([4, 768])





In [1]:
def run_experiment_across_layers(experiment, train_input, train_labels, test_input, test_labels):
    """
    Runs a probing experiment over representations from all layers of the model.
    The whole thing works on cached embeddings

    experiment: method (train: Tensor, test: Tensor, label_train: Tensor, label_test: Tensor) -> (fit_model, metrics). Each experiment will fit _some_ model on the data and return the model and the results
    train_input: list of 13 elements, each of which is a tensor of size (num_datapoints, embedding_dim)
    train_labels: tensor (num_datapoints, )
    test_input: same format as train_input
    test_labels: same format as train_labels
    """

    list_of_results = []
    list_of_probing_models = []

    for i in range(len(train_input)):
        train_current_layer = train_input[i]
        test_current_layer = test_input[i]

        model, metrics = experiment(train_current_layer, test_current_layer, train_labels, test_labels)

        list_of_results.append(metrics)
        list_of_probing_models.append(model)

    return list_of_probing_models, list_of_results


In [3]:
def probe_experiment(train_input, test_input, train_labels, test_labels, probe_model):
    """
    Gets an initialized probe model and fits it on data and runs some experiments
    expected to be curried and sent as a callback to run_experiment_across_layers
    """

    train_input_numpy = train_input.detach().numpy()
    test_input_numpy = test_input.detach().numpy()
    train_labels_numpy = train_labels.detach().numpy()
    test_labels_numpy = test_labels.detach().numpy()

    model.fit(train_input_numpy, train_labels_numpy)

    accuracy = model.score(test_input_numpy, test_labels_numpy)

    return {"accuracy": accuracy}


def linear_probe_experiment(train_input, test_input, train_labels, test_labels):
    # initialize linear probe and run probe experiment
    lr = LogisticRegression(class_weight="balanced", verbose=1, max_iter=1000)
    return probe_experiment(train_input, test_input, train_labels, test_labels, lr)


def mlp_probe_experiment(train_input, test_input, train_labels, test_labels):
    # initialize an mlp probe and run probe experiment
    mlp = MLPClassifier(random_state=1, max_iter=1000, verbose=True)
    return probe_experiment(train_input, test_input, train_labels, test_labels, mlp)

