# Initialization

preliminaries

In [1]:
import torch, os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset

In [40]:
from google.colab import drive
drive.mount('/content/drive')

ValueError: mount failed

download tokenizer and model from hf

note: do not download and run both models at the same time, colab has some limitation and it is not guaranteed to work

In [5]:
# login with hf
from huggingface_hub import login
token = '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'
login(token)

In [6]:
model_id_b=f'google/t5gemma-b-b-ul2'
tokenizer_b = AutoTokenizer.from_pretrained(model_id_b)
model_b = AutoModelForSeq2SeqLM.from_pretrained(
    model_id_b,
    device_map="auto",
    dtype=torch.bfloat16,
)

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/34.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

In [None]:
model_id_2b=f'google/t5gemma-2b-2b-ul2'
tokenizer_2b = AutoTokenizer.from_pretrained(model_id_2b)
model_2b = AutoModelForSeq2SeqLM.from_pretrained(
    model_id_2b,
    device_map="auto",
    dtype=torch.bfloat16,
)

## **extracting the activations from the model**

We use mean pooling to obtain vector representations of sentences because SentenceBERT has shown that it works better than the CLS token. In our case, there is no CLS token, so this was not even an option. SentenceT5 has confirmed that mean pooling is the strategy that yields the best results for T5-based models when it is necessary to extract the sentence representation.

So we use this strategy.

In [10]:
# non batched
model_b.eval()

text = 'tell me something about the human brain'

inputs = tokenizer_b(text, return_tensors="pt").to(model_b.device)

start_token_id = tokenizer_b.bos_token_id
decoder_input_ids = torch.tensor([[start_token_id]], device=model_b.device)

with torch.no_grad():
    outputs = model_b(
        **inputs,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=True,
    )

encoder_hidden_states = torch.stack([e.cpu().squeeze(0) for e in outputs.encoder_hidden_states])
decoder_hidden_states = torch.stack([o.cpu().view(-1) for o in outputs.decoder_hidden_states])

#print(encoder_hidden_states.shape)
#print(decoder_hidden_states.shape)

print(len(outputs.encoder_hidden_states), len(outputs.decoder_hidden_states))
print(outputs.encoder_hidden_states[0].shape, outputs.decoder_hidden_states[0].shape)

13 13
torch.Size([1, 7, 768]) torch.Size([1, 1, 768])


In [None]:
# batched function
def extract_activations_df(base_df, model, tokenizer, text_column, BATCH_SIZE=1):
  df = base_df.copy()
  enc_results = {}
  dec_results = {}

  # mean pooling considering padding and using attention mask to set to 0 pad token representations
  def masked_mean_pooling(hidden_states, attention_mask):
      mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
      masked_embeddings = hidden_states * mask_expanded
      summed = torch.sum(masked_embeddings, dim=1)
      count = torch.clamp(mask_expanded.sum(1), min=1e-9)
      return summed / count

  model.eval()

  total_rows = len(df)

  print(f"Start processing {total_rows} sentences...")

  for i in tqdm(range(0, total_rows, BATCH_SIZE)):
      batch_texts = df[text_column][i : i + BATCH_SIZE].tolist()
      inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(model.device)
      current_batch_len = inputs.input_ids.shape[0]
      start_token_id = tokenizer.bos_token_id
      decoder_input_ids = torch.full((current_batch_len, 1), start_token_id, device=model.device)

      with torch.no_grad():
          outputs = model(
              **inputs,
              decoder_input_ids=decoder_input_ids,
              output_hidden_states=True,
          )

      # encoder extraction: final shape[Batch, Num_Layers, Hidden_Dim]
      attention_mask = inputs.attention_mask.cpu()
      batch_encoder_states = torch.stack([
          masked_mean_pooling(e.cpu(), attention_mask)
          for e in outputs.encoder_hidden_states
      ], dim=1).cpu().to(torch.float32).numpy()

      # decoder extraction: final shape[Batch, Num_Layers, Hidden_Dim]
      batch_decoder_states = torch.stack([
          o.cpu().squeeze(1) for o in outputs.decoder_hidden_states
      ], dim=1).cpu().to(torch.float32).numpy()

      num_enc_layers = batch_encoder_states.shape[1]
      num_dec_layers = batch_decoder_states.shape[1]

      # saving the activation results into the dictionaries
      for layer_idx in range(num_enc_layers):
          col_name = f'encoder_layer_{layer_idx+1}'
          if col_name not in enc_results: enc_results[col_name] = []
          vectors = list(batch_encoder_states[:, layer_idx, :])
          enc_results[col_name].extend(vectors)

      for layer_idx in range(num_dec_layers):
          col_name = f'decoder_layer_{layer_idx+1}'
          if col_name not in dec_results: dec_results[col_name] = []

          vectors = list(batch_decoder_states[:, layer_idx, :])
          dec_results[col_name].extend(vectors)

  print("Saving in the DataFrame...")
  for col_name, vectors in enc_results.items():
      df[col_name] = vectors

  for col_name, vectors in dec_results.items():
      df[col_name] = vectors

  print("Done! Columns added")
  return df

In [None]:
def save_activations_df(df, dataset_name, model_id):
  path = f'/content/drive/MyDrive/DTCS_datasets/{dataset_name}_{model_id.split('/')[1]}'
  print(f'Saving {dataset_name}_{model_id.split('/')[1]} to GDrive...')
  df.to_pickle(path)
  print(f'Saved {dataset_name}_{model_id.split("/")[1]}')

# Datasets

## True/False

In [None]:
!curl azariaa.com/Content/Datasets/true-false-dataset.zip > true-false-dataset.zip
!unzip "true-false-dataset.zip" -d "true-false-dataset"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 69243  100 69243    0     0   210k      0 --:--:-- --:--:-- --:--:--  211k
Archive:  true-false-dataset.zip
  inflating: true-false-dataset/publicDataset/animals_true_false.csv  
  inflating: true-false-dataset/publicDataset/cities_true_false.csv  
  inflating: true-false-dataset/publicDataset/companies_true_false.csv  
  inflating: true-false-dataset/publicDataset/elements_true_false.csv  
  inflating: true-false-dataset/publicDataset/facts_true_false.csv  
  inflating: true-false-dataset/publicDataset/generated_true_false.csv  
  inflating: true-false-dataset/publicDataset/inventions_true_false.csv  


In [None]:
# create a dataframe from the csv files
dir_path = '/content/true-false-dataset/publicDataset'
datasets_names = os.listdir(dir_path)
dfs = []

for dataset_name in datasets_names:
  path = f'{dir_path}/{dataset_name}'
  df = pd.read_csv(path)
  df.insert(loc=2, column='area', value=dataset_name.replace('_true_false.csv',''), allow_duplicates=True)
  dfs.append(df)

tf_df = pd.concat(dfs, ignore_index=True)
tf_df

Unnamed: 0,statement,label,area
0,The planet Uranus is tilted on its side.,1,facts
1,Sharks are sea creatures that have a reputatio...,1,facts
2,An adult human has 32 teeth.,1,facts
3,The smallest continent in the world is Australia.,1,facts
4,The Amazon River is the largest river in the w...,1,facts
...,...,...,...
6325,The capital of South Suda is Juba.,0,generated
6326,JAUBA is a town in the Central Equatorial Stat...,0,generated
6327,Jauba is located at the junction of the Equato...,0,generated
6328,JUABA is an administrative unit in the Equator...,0,generated


model_b

In [None]:
BATCH_SIZE = 128
text_column = 'statement'

activation_tf_df = extract_activations_df(tf_df, model_b, tokenizer_b, text_column, BATCH_SIZE)
save_activations_df(activation_tf_df, 'true-false', model_id_b)

Start processing 6330 sentences...


100%|██████████| 50/50 [00:18<00:00,  2.77it/s]

Saving in the DataFrame...
Done! Columns added





Unnamed: 0,statement,label,area,encoder_layer_1,encoder_layer_2,encoder_layer_3,encoder_layer_4,encoder_layer_5,encoder_layer_6,encoder_layer_7,...,decoder_layer_4,decoder_layer_5,decoder_layer_6,decoder_layer_7,decoder_layer_8,decoder_layer_9,decoder_layer_10,decoder_layer_11,decoder_layer_12,decoder_layer_13
0,The planet Uranus is tilted on its side.,1,facts,"[1.8242188, 0.014973958, 0.30143228, 0.2826606...","[0.056857638, 0.4626736, 0.16932508, -0.357638...","[-0.19216579, 0.4826389, 0.1438802, -0.4792751...","[-0.28027344, 0.14887153, -0.34320748, -0.1949...","[-0.09988064, 0.17556423, -0.5488281, -0.06087...","[0.35253906, 0.00043402778, 0.0029296875, -0.1...","[0.7221137, -0.21473524, -0.08993869, -0.50846...",...,"[-0.4765625, 0.39453125, 0.30273438, 0.5273437...","[-0.9296875, 0.51953125, -0.072265625, 0.82421...","[-1.25, 0.015136719, 0.15625, 0.62890625, -1.6...","[-1.203125, 0.53125, -0.18652344, 0.82421875, ...","[0.041015625, 0.048828125, 0.015625, 0.8945312...","[-0.11035156, -0.15820312, -0.45507812, 1.1171...","[-0.43359375, -0.30859375, -1.0, 0.5390625, -1...","[-1.296875, -0.1640625, -1.171875, 0.37109375,...","[-1.2890625, 0.29882812, -0.3828125, 0.1367187...","[28.25, -4.53125, 0.34179688, -0.94140625, 3.4..."
1,Sharks are sea creatures that have a reputatio...,1,facts,"[1.6576773, -0.23715445, -0.41706732, 0.110314...","[0.095853366, 0.023212139, -0.30155122, -0.453...","[-0.048753005, 0.015211839, -0.088604264, -0.4...","[-0.30742937, -0.17337741, -0.37474647, -0.137...","[-0.50946516, -0.27554086, -0.48152044, -0.087...","[-0.22385818, -0.6236478, -0.68073916, -0.0413...","[0.029897837, -0.97273135, -0.77659255, -0.074...",...,"[-0.5, 0.43554688, 0.3203125, 0.53515625, -0.4...","[-0.953125, 0.5625, -0.040527344, 0.83203125, ...","[-1.2890625, 0.049316406, 0.19335938, 0.625, -...","[-1.234375, 0.55859375, -0.16992188, 0.8242187...","[0.025390625, 0.107421875, 0.056640625, 0.8867...","[-0.123535156, -0.10986328, -0.37890625, 1.117...","[-0.44140625, -0.23046875, -0.96875, 0.5703125...","[-1.2890625, -0.14453125, -1.1953125, 0.449218...","[-1.34375, 0.3515625, -0.37890625, 0.14941406,...","[26.75, -5.125, -0.8046875, -1.7734375, -0.671..."
2,An adult human has 32 teeth.,1,facts,"[1.4643012, -0.42925346, 0.23860677, 0.4300130...","[0.15288629, -0.11577691, -0.13226996, -0.4715...","[-0.0070529515, -0.13682726, 0.06939019, -0.25...","[-0.2250434, -0.19845921, -0.10394965, -0.2877...","[0.0907118, -0.3184679, -0.30837673, -0.218532...","[0.4171007, -0.1802029, -0.044704862, -0.04058...","[0.4921875, -0.17274305, -0.69259983, -0.27365...",...,"[-0.515625, 0.40625, 0.32421875, 0.515625, -0....","[-0.9765625, 0.5390625, -0.036132812, 0.796875...","[-1.3046875, 0.01928711, 0.18359375, 0.6015625...","[-1.2578125, 0.53515625, -0.16894531, 0.804687...","[-0.0234375, 0.061523438, 0.015625, 0.88671875...","[-0.1484375, -0.1484375, -0.42578125, 1.09375,...","[-0.4453125, -0.296875, -0.9765625, 0.5390625,...","[-1.3203125, -0.19238281, -1.2109375, 0.419921...","[-1.421875, 0.3984375, -0.3359375, 0.033203125...","[23.0, -1.3671875, -3.125, 3.8125, -0.76953125..."
3,The smallest continent in the world is Australia.,1,facts,"[1.4700521, 0.0687934, 0.18645562, 0.42199367,...","[0.234375, 0.08821615, -0.08292643, -0.1408420...","[0.0703125, 0.09830729, 0.32109916, 0.00260416...","[-0.06287977, -0.2046441, -0.1779514, 0.145073...","[0.092447914, -0.024956597, -0.429579, 0.26736...","[-0.0059136283, -0.42719185, -0.3060981, 0.041...","[0.3373481, -0.5115017, -0.48676217, -0.137641...",...,"[-0.50390625, 0.390625, 0.3125, 0.49804688, -0...","[-0.95703125, 0.5234375, -0.052734375, 0.79687...","[-1.28125, 0.017822266, 0.18066406, 0.58984375...","[-1.234375, 0.5390625, -0.16601562, 0.7890625,...","[0.025390625, 0.06640625, 0.04296875, 0.855468...","[-0.13476562, -0.16015625, -0.38476562, 1.0625...","[-0.4609375, -0.30664062, -0.9765625, 0.507812...","[-1.3125, -0.22265625, -1.2109375, 0.38085938,...","[-1.421875, 0.34765625, -0.41796875, 0.0625, -...","[23.875, -3.890625, 1.15625, 0.22753906, 3.828..."
4,The Amazon River is the largest river in the w...,1,facts,"[1.1790866, -0.024188701, -0.20665565, 0.20469...","[0.020695614, 0.22814003, -0.35821062, -0.4532...","[-0.21146335, 0.20551945, -0.05577674, -0.3368...","[-0.34795672, 0.057016227, -0.48948318, 0.0582...","[-0.063777044, 0.27208534, -0.6057692, 0.11204...","[-0.09878305, 0.07016226, -0.4611253, 0.136944...","[0.29454628, 0.104191706, -0.9107572, -0.13690...",...,"[-0.4765625, 0.40234375, 0.28515625, 0.5351562...","[-0.9296875, 0.53515625, -0.08251953, 0.832031...","[-1.2578125, 0.042236328, 0.16113281, 0.640625...","[-1.203125, 0.55859375, -0.19921875, 0.8398437...","[0.05078125, 0.091796875, 0.021484375, 0.91015...","[-0.09667969, -0.12109375, -0.42578125, 1.125,...","[-0.42578125, -0.25195312, -0.98046875, 0.5781...","[-1.265625, -0.14550781, -1.1796875, 0.453125,...","[-1.34375, 0.375, -0.3671875, 0.06347656, -1.0...","[29.875, -1.140625, 2.640625, -0.88671875, 1.9..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6325,The capital of South Suda is Juba.,0,generated,"[1.6899414, 0.090063475, -0.34589845, 0.427441...","[0.05234375, 0.39042968, -0.37111816, -0.52656...","[-0.42539063, 0.3416992, 0.17148438, -0.596093...","[-0.26484376, 0.24560547, -0.24570313, -0.2957...","[-0.18076172, 0.16953126, -0.13066407, -0.1854...","[-0.23359375, -0.26933593, -0.14326172, -0.254...","[-0.070336916, -0.16132812, -0.7640625, -0.084...",...,"[-0.4921875, 0.4140625, 0.2890625, 0.55078125,...","[-0.9453125, 0.5390625, -0.059814453, 0.835937...","[-1.265625, 0.05419922, 0.18554688, 0.6171875,...","[-1.2109375, 0.578125, -0.15625, 0.81640625, -...","[0.052734375, 0.103515625, 0.064453125, 0.8789...","[-0.09277344, -0.09277344, -0.3984375, 1.10156...","[-0.40039062, -0.23339844, -0.953125, 0.542968...","[-1.234375, -0.14355469, -1.1875, 0.37890625, ...","[-1.2421875, 0.36914062, -0.375, -0.0004882812...","[23.75, 2.4375, 2.625, -2.6875, 1.8515625, 0.0..."
6326,JAUBA is a town in the Central Equatorial Stat...,0,generated,"[2.1651042, -0.20144857, -0.13483073, 0.192757...","[0.11595052, 0.2218099, -0.22102864, -0.216666...","[0.1608724, 0.16777344, 0.119596355, -0.168229...","[0.20273438, -0.024186198, -0.4561198, -0.1901...","[0.32958984, 0.072338864, -0.3652995, -0.05286...","[0.25797525, -0.10279948, -0.06845703, -0.1782...","[0.5825521, -0.08932292, -0.7499349, -0.008658...",...,"[-0.515625, 0.421875, 0.265625, 0.57421875, -0...","[-0.97265625, 0.55859375, -0.08300781, 0.87109...","[-1.3046875, 0.07714844, 0.16503906, 0.65625, ...","[-1.25, 0.6015625, -0.19335938, 0.859375, -1.5...","[0.017578125, 0.115234375, 0.041015625, 0.9179...","[-0.12451172, -0.06738281, -0.45117188, 1.1406...","[-0.4140625, -0.21289062, -1.0078125, 0.59375,...","[-1.2265625, -0.10839844, -1.234375, 0.4277343...","[-1.15625, 0.3984375, -0.44140625, -0.04125976...","[18.625, 1.0859375, 1.3984375, -2.265625, 5.40..."
6327,Jauba is located at the junction of the Equato...,0,generated,"[1.718099, -0.55709636, -0.23574218, 0.1610188...","[-0.0020833334, 0.10859375, -0.06705729, -0.31...","[-0.076627605, 0.045052085, 0.48854166, -0.231...","[-0.23190103, 0.011067708, -0.020996094, -0.01...","[-0.08331706, 0.096158855, -0.15289713, 0.2718...","[-0.123046875, -0.3981771, 0.43190104, 0.19238...","[0.47154948, -0.23108724, -0.2914388, 0.312630...",...,"[-0.48046875, 0.40625, 0.29101562, 0.5390625, ...","[-0.94140625, 0.53125, -0.072265625, 0.8359375...","[-1.265625, 0.045410156, 0.1640625, 0.6328125,...","[-1.203125, 0.5625, -0.18164062, 0.828125, -1....","[0.048828125, 0.087890625, 0.0390625, 0.894531...","[-0.109375, -0.10839844, -0.41015625, 1.109375...","[-0.42382812, -0.2421875, -0.96875, 0.546875, ...","[-1.2421875, -0.16210938, -1.203125, 0.4472656...","[-1.1875, 0.43554688, -0.3984375, 0.033691406,...","[27.5, 0.29296875, 2.734375, -2.828125, 4.9062..."
6328,JUABA is an administrative unit in the Equator...,0,generated,"[1.9664885, 0.03098016, -0.1511102, -0.0199038...","[0.064170435, 0.39941406, -0.09530479, -0.4312...","[-0.041272614, 0.21643709, 0.0989926, -0.45826...","[0.018117804, -0.05448191, -0.2706106, -0.3657...","[0.17895508, 0.056409333, -0.473787, -0.188065...","[0.30062705, -0.25406045, -0.07139186, -0.6759...","[0.7155119, -0.018451892, -0.6802786, -0.43929...",...,"[-0.515625, 0.41992188, 0.25585938, 0.54296875...","[-0.97265625, 0.55078125, -0.09863281, 0.83984...","[-1.3046875, 0.064453125, 0.14746094, 0.625, -...","[-1.25, 0.59375, -0.20410156, 0.8203125, -1.51...","[0.0234375, 0.1171875, 0.033203125, 0.890625, ...","[-0.122558594, -0.072753906, -0.44140625, 1.10...","[-0.40429688, -0.20898438, -0.9921875, 0.56640...","[-1.21875, -0.08105469, -1.21875, 0.40039062, ...","[-1.109375, 0.46875, -0.3671875, -0.031982422,...","[20.0, 1.21875, 1.0546875, -1.96875, 3.984375,..."


In [None]:
save_activations_df

model_2b

In [None]:
BATCH_SIZE = 16
text_column = 'statement'

activation_tf_df_2b = extract_activations_df(tf_df, model_2b, tokenizer_2b, text_column, BATCH_SIZE)
save_activations_df(activation_tf_df_2b, 'true-false', model_id_2b)

Start processing 6330 sentences...


100%|██████████| 396/396 [05:27<00:00,  1.21it/s]


Saving in the DataFrame...
Done! Columns added
Saving true-false_t5gemma-2b-2b-ul2 to GDrive...
Saved true-false_t5gemma-2b-2b-ul2


In [None]:
activation_tf_df_2b

Unnamed: 0,statement,label,area,encoder_layer_1,encoder_layer_2,encoder_layer_3,encoder_layer_4,encoder_layer_5,encoder_layer_6,encoder_layer_7,...,decoder_layer_18,decoder_layer_19,decoder_layer_20,decoder_layer_21,decoder_layer_22,decoder_layer_23,decoder_layer_24,decoder_layer_25,decoder_layer_26,decoder_layer_27
0,The planet Uranus is tilted on its side.,1,facts,"[0.9109158, -0.86013454, -0.80803764, 0.016710...","[0.7178819, 0.14713542, -0.4470486, 0.07248264...","[0.5036892, -0.16075304, -0.5388455, -0.543402...","[0.20155165, -0.37565103, -0.090277776, -1.196...","[0.2595486, 0.46126303, -0.3028429, -1.1968316...","[-0.46918404, 0.80251735, -0.31271702, -0.1119...","[0.32443577, 0.4969618, -0.017795138, -1.05642...",...,"[0.28515625, -0.25195312, -0.36914062, -0.4257...","[0.04296875, -0.19335938, 0.049072266, -0.6796...","[0.41992188, 0.45898438, -0.36914062, 0.128906...","[0.18359375, -0.765625, -0.7890625, -0.6210937...","[-0.20703125, -0.41796875, -0.4609375, -0.6992...","[-2.34375, 0.029296875, 1.0859375, -1.7109375,...","[-2.921875, 1.21875, 1.4609375, -5.15625, 1.21...","[-2.3125, 1.46875, 1.4609375, -5.9375, 1.07031...","[3.859375, 1.78125, 2.9375, -5.53125, 1.960937...","[1.84375, -0.55859375, 0.73046875, -0.71484375..."
1,Sharks are sea creatures that have a reputatio...,1,facts,"[0.14000526, -0.9311899, -0.14929257, 0.061598...","[0.15978065, -0.07527043, 0.421875, 0.96484375...","[0.007512019, -0.76171875, 0.5688101, 1.244140...","[0.40414664, -1.0458233, 0.5972806, 0.29710037...","[0.1711238, -0.18389423, -0.13431491, 0.029897...","[0.07932692, 0.096905045, -0.45718148, 0.58263...","[0.4341947, -0.68847656, -0.67277646, -0.39475...",...,"[0.265625, -0.27929688, -0.4453125, -0.4902343...","[0.05859375, -0.2421875, 0.02734375, -0.695312...","[0.43359375, 0.390625, -0.46289062, 0.12890625...","[0.20703125, -0.82421875, -0.8671875, -0.60156...","[-0.20214844, -0.56640625, -0.51953125, -0.699...","[-2.5, 0.025390625, 0.81640625, -1.6875, 0.25,...","[-2.890625, 1.0859375, 1.203125, -5.1875, 1.23...","[-2.484375, 1.4921875, 1.375, -5.65625, 1.3203...","[4.78125, 1.6875, 1.5, -5.84375, 1.5546875, -0...","[2.5625, -0.45898438, 0.14550781, -1.015625, 1..."
2,An adult human has 32 teeth.,1,facts,"[0.6130642, -0.2621528, 0.11461046, 0.4279514,...","[0.37239584, -0.113715276, 0.43337673, 0.17187...","[-0.037109375, 0.09483507, 0.8364258, 0.393229...","[0.56000435, 0.6281467, 0.6768663, -0.4625651,...","[0.016276041, 0.74609375, 0.55533856, -0.18945...","[-0.3997396, 0.05859375, 0.16666667, 0.6640625...","[0.32074654, -0.44059244, -0.062147353, -0.587...",...,"[0.26171875, -0.296875, -0.49609375, -0.423828...","[0.03515625, -0.24804688, -0.019897461, -0.695...","[0.4140625, 0.40625, -0.48046875, 0.08984375, ...","[0.19140625, -0.82421875, -0.8671875, -0.71093...","[-0.25390625, -0.5234375, -0.515625, -0.839843...","[-2.40625, 0.24804688, 0.921875, -1.984375, 0....","[-2.796875, 1.4140625, 1.234375, -5.625, 1.554...","[-2.40625, 1.6484375, 1.46875, -6.25, 1.46875,...","[2.625, 2.625, 1.78125, -6.21875, 1.890625, -0...","[2.296875, 0.86328125, 1.1875, -0.85546875, 2...."
3,The smallest continent in the world is Australia.,1,facts,"[0.07595486, -0.17903645, -0.22743055, 0.13628...","[0.47092015, 0.75130206, -0.32834202, -0.41514...","[0.30447048, 0.83029515, -0.29644096, -0.57855...","[0.46679688, 0.25260416, 0.074652776, -1.26540...","[0.5512153, 0.59939235, -0.69140625, -0.457465...","[-0.5390625, 1.1796875, -1.2220052, -0.0125868...","[0.30555555, 1.3854166, -0.9279514, -0.4619140...",...,"[0.328125, -0.2109375, -0.53515625, -0.546875,...","[0.0859375, -0.1796875, -0.10888672, -0.800781...","[0.484375, 0.44921875, -0.5859375, 0.0, -0.296...","[0.24609375, -0.76953125, -0.9765625, -0.75781...","[-0.234375, -0.5078125, -0.59375, -0.8203125, ...","[-2.625, 0.2265625, 0.8671875, -1.7421875, 0.2...","[-3.140625, 1.359375, 1.234375, -5.25, 1.14843...","[-2.65625, 1.765625, 1.234375, -5.84375, 1.101...","[2.125, 3.453125, 1.90625, -4.84375, 2.65625, ...","[-0.26367188, 0.48242188, 0.81640625, -1.32031..."
4,The Amazon River is the largest river in the w...,1,facts,"[0.88957334, -0.4628155, 0.68028843, 1.2866586...","[0.68659854, 0.5551758, 0.4826097, 1.0474759, ...","[0.60772234, 0.10832332, 0.094839245, 1.115835...","[1.0803787, 0.2701322, 0.36989182, -0.05709134...","[0.509991, 0.8996394, -0.21048678, 0.5871394, ...","[0.1789363, 1.2509015, -0.58759016, 0.94771636...","[0.74038464, 0.42822266, -0.092998795, -0.0021...",...,"[0.23828125, -0.29492188, -0.52734375, -0.4218...","[-0.02734375, -0.23828125, -0.053955078, -0.67...","[0.37890625, 0.39648438, -0.50390625, 0.152343...","[0.1328125, -0.83203125, -0.90625, -0.61328125...","[-0.28320312, -0.5234375, -0.5703125, -0.66796...","[-2.59375, 0.09765625, 1.03125, -1.484375, 0.3...","[-3.125, 1.2890625, 1.3125, -5.0625, 1.375, -0...","[-2.4375, 1.5703125, 1.484375, -5.84375, 1.359...","[4.25, 3.21875, 2.96875, -4.78125, 2.3125, 1.0...","[1.6015625, 0.21484375, 2.0, -0.6640625, 2.437..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6325,The capital of South Suda is Juba.,0,generated,"[0.40429688, -0.35566407, -0.37021485, -0.1209...","[0.5800781, 0.09475098, -0.8785156, 0.28378907...","[0.45546874, -1.2558105, -0.946875, 0.33125, -...","[0.54892576, -1.3199219, -0.265625, -0.1833007...","[0.71845704, -0.67148435, -0.73105466, -0.7600...","[0.34023437, -0.7270508, -0.32539064, -0.17246...","[0.33398438, 0.12109375, 0.38447267, -1.092968...",...,"[0.234375, -0.35546875, -0.46875, -0.5703125, ...","[0.0, -0.2890625, -0.0034179688, -0.80078125, ...","[0.40234375, 0.36914062, -0.4921875, 0.0390625...","[0.14453125, -0.84375, -0.8828125, -0.7109375,...","[-0.3125, -0.5625, -0.5703125, -0.765625, -1.9...","[-2.5625, 0.06640625, 0.8671875, -1.734375, 0....","[-3.109375, 1.359375, 1.046875, -5.4375, 1.109...","[-2.5625, 1.6171875, 1.15625, -6.0625, 1.14062...","[2.421875, 3.1875, 1.765625, -4.25, 2.015625, ...","[0.53515625, 0.375, 0.6015625, -0.84375, 1.648..."
6326,JAUBA is a town in the Central Equatorial Stat...,0,generated,"[0.37447917, -1.3463541, -0.7164062, 0.5007812...","[0.575, -0.5473633, -0.6655599, 0.4046224, 0.4...","[0.19700521, -0.43645832, -0.33815104, 0.05989...","[0.043554686, -0.569401, 0.228125, -0.20735677...","[-0.17265625, -0.033203125, -0.05266927, -0.13...","[-0.16523437, 0.056184895, -0.13515624, 0.7149...","[-0.23776041, -0.38190106, 0.44335938, -0.7375...",...,"[0.1484375, -0.421875, -0.53515625, -0.4667968...","[-0.0625, -0.33398438, -0.03466797, -0.8046875...","[0.3515625, 0.3046875, -0.52734375, 0.03125, -...","[0.125, -0.9375, -0.9140625, -0.73046875, -1.7...","[-0.27929688, -0.671875, -0.5859375, -0.804687...","[-2.609375, -0.015625, 0.9921875, -1.65625, 0....","[-3.015625, 1.1484375, 1.03125, -5.28125, 1.15...","[-2.546875, 1.390625, 1.140625, -6.03125, 1.28...","[3.234375, 2.375, 2.6875, -4.5, 2.859375, 2.43...","[0.66015625, -0.6796875, 1.6171875, -0.2578125..."
6327,Jauba is located at the junction of the Equato...,0,generated,"[1.1541016, -1.3690104, -0.52441406, 0.2944010...","[1.4063314, -0.5888021, -0.7594401, 0.26158854...","[1.4248698, -0.5119141, -1.0794597, -0.2738932...","[0.6610677, -0.67815757, -0.3191406, -0.741666...","[0.39635417, -0.110416666, -0.72415364, -0.782...","[0.5410807, 0.1516927, -0.7972005, -0.04277343...","[0.5263021, 0.44427082, -0.90052086, -1.207812...",...,"[0.19140625, -0.36914062, -0.44921875, -0.3593...","[-0.01953125, -0.30273438, 0.037109375, -0.667...","[0.375, 0.328125, -0.42773438, 0.1640625, -0.3...","[0.11328125, -0.9140625, -0.83203125, -0.60156...","[-0.2578125, -0.6484375, -0.5625, -0.60546875,...","[-2.5, -0.001953125, 1.015625, -1.5390625, 0.1...","[-3.125, 1.2890625, 1.171875, -5.0625, 1.01562...","[-2.609375, 1.484375, 1.421875, -5.78125, 1.14...","[2.9375, 2.828125, 3.28125, -4.6875, 1.6875, 0...","[0.88671875, -0.41796875, 1.734375, -0.9414062..."
6328,JUABA is an administrative unit in the Equator...,0,generated,"[0.8184622, -1.1554276, -0.50945723, -0.191817...","[1.1214535, -0.7328331, -0.35541734, -0.061215...","[1.0067846, -0.55828536, -0.22973633, -0.07092...","[0.3385074, -0.71026933, 0.22805305, -0.584703...","[0.063733555, -0.18678042, 0.46720806, -0.5142...","[-0.17259458, 0.16324013, -0.3051501, 0.241570...","[-0.67516446, -0.27400288, 0.046772204, -0.785...",...,"[0.140625, -0.44726562, -0.5234375, -0.4667968...","[-0.046875, -0.375, -0.0107421875, -0.78515625...","[0.35546875, 0.26757812, -0.4921875, 0.0234375...","[0.15234375, -0.95703125, -0.89453125, -0.7734...","[-0.26367188, -0.69140625, -0.55078125, -0.835...","[-2.625, 0.01953125, 0.95703125, -1.7421875, 0...","[-3.125, 1.109375, 0.953125, -5.34375, 1.21875...","[-2.796875, 1.296875, 0.8515625, -6.0, 1.38281...","[3.96875, 3.28125, 1.90625, -5.84375, 2.78125,...","[0.8671875, -0.25585938, 0.51171875, -0.40625,..."


## CoLA

In [None]:
!wget https://nyu-mll.github.io/CoLA/cola_public_1.1.zip
!unzip cola_public_1.1.zip

--2025-11-22 17:44:18--  https://nyu-mll.github.io/CoLA/cola_public_1.1.zip
Resolving nyu-mll.github.io (nyu-mll.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to nyu-mll.github.io (nyu-mll.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 255330 (249K) [application/x-zip-compressed]
Saving to: ‘cola_public_1.1.zip’


2025-11-22 17:44:18 (13.9 MB/s) - ‘cola_public_1.1.zip’ saved [255330/255330]

Archive:  cola_public_1.1.zip
   creating: cola_public/
  inflating: cola_public/README      
   creating: cola_public/tokenized/
  inflating: cola_public/tokenized/in_domain_dev.tsv  
  inflating: cola_public/tokenized/in_domain_train.tsv  
  inflating: cola_public/tokenized/out_of_domain_dev.tsv  
   creating: cola_public/raw/
  inflating: cola_public/raw/in_domain_dev.tsv  
  inflating: cola_public/raw/in_domain_train.tsv  
  inflating: cola_public/raw/out_of_domain_dev.tsv  


In [None]:
path = '/content/cola_public/raw/'
cola_files = os.listdir(path) # contiene ['out_of_domain_dev.tsv', 'in_domain_train.tsv', 'in_domain_dev.tsv']
dfs = []

for cf in cola_files:
  df = pd.read_csv(f'{path}{cf}', delimiter='\t', header=None, names=['sentence_source', 'label', 'label_notes', 'sentence'])
  df.drop(columns=['sentence_source', 'label_notes'], inplace=True)
  df.insert(loc=0, column='source', value=cf.split('.')[0], allow_duplicates=True)
  dfs.append(df)

cola_df = pd.concat(dfs, ignore_index=True)
cola_df

Unnamed: 0,source,label,sentence
0,in_domain_dev,1,The sailors rode the breeze clear of the rocks.
1,in_domain_dev,1,The weights made the rope stretch over the pul...
2,in_domain_dev,1,The mechanical doll wriggled itself loose.
3,in_domain_dev,1,"If you had eaten more, you would want less."
4,in_domain_dev,0,"As you eat the most, you want the least."
...,...,...,...
9589,in_domain_train,0,Poseidon appears to own a dragon
9590,in_domain_train,0,Digitize is my happiest memory
9591,in_domain_train,1,It is easy to slay the Gorgon.
9592,in_domain_train,1,I had the strangest feeling that I knew you.


model_b

In [None]:
BATCH_SIZE = 64
text_column = 'source'

activation_cola_df = extract_activations_df(cola_df, model_b, tokenizer_b, text_column, BATCH_SIZE)
save_activations_df(activation_cola_df, 'cola', model_id_b)

Start processing 9594 sentences...


100%|██████████| 150/150 [00:42<00:00,  3.49it/s]


Saving in the DataFrame...
Done! Columns added


Unnamed: 0,source,label,sentence,encoder_layer_1,encoder_layer_2,encoder_layer_3,encoder_layer_4,encoder_layer_5,encoder_layer_6,encoder_layer_7,...,decoder_layer_4,decoder_layer_5,decoder_layer_6,decoder_layer_7,decoder_layer_8,decoder_layer_9,decoder_layer_10,decoder_layer_11,decoder_layer_12,decoder_layer_13
0,in_domain_dev,1,The sailors rode the breeze clear of the rocks.,"[2.278125, -0.20976563, -0.38027343, 0.4071289...","[0.09804688, 0.02421875, -0.35742188, -0.31718...","[0.03046875, 0.5109375, 0.40625, -0.49726564, ...","[-0.2484375, -0.31621093, -0.008203125, -0.635...","[-0.1421875, -0.0053710938, -0.2444336, -0.475...","[0.17460938, -0.7871094, -0.0076171877, -0.515...","[0.043164063, -0.6929687, 0.51225585, -0.64609...",...,"[-0.4921875, 0.3984375, 0.296875, 0.49609375, ...","[-0.95703125, 0.5078125, -0.04663086, 0.773437...","[-1.2890625, 0.0087890625, 0.17578125, 0.56640...","[-1.25, 0.53125, -0.15820312, 0.76171875, -1.5...","[-0.00390625, 0.072265625, 0.03125, 0.828125, ...","[-0.13476562, -0.12402344, -0.43945312, 1.0468...","[-0.47265625, -0.27148438, -1.0, 0.47265625, -...","[-1.359375, -0.15625, -1.203125, 0.33789062, -...","[-1.53125, 0.56640625, -0.5234375, 0.09667969,...","[-12.25, 0.47070312, 1.78125, 0.78515625, -1.7..."
1,in_domain_dev,1,The weights made the rope stretch over the pul...,"[2.278125, -0.20976563, -0.38027343, 0.4071289...","[0.09804688, 0.02421875, -0.35742188, -0.31718...","[0.03046875, 0.5109375, 0.40625, -0.49726564, ...","[-0.2484375, -0.31621093, -0.008203125, -0.635...","[-0.1421875, -0.0053710938, -0.2444336, -0.475...","[0.17460938, -0.7871094, -0.0076171877, -0.515...","[0.043164063, -0.6929687, 0.51225585, -0.64609...",...,"[-0.4921875, 0.3984375, 0.296875, 0.49609375, ...","[-0.95703125, 0.5078125, -0.04663086, 0.773437...","[-1.2890625, 0.0087890625, 0.17578125, 0.56640...","[-1.25, 0.53125, -0.15820312, 0.76171875, -1.5...","[-0.00390625, 0.072265625, 0.03125, 0.828125, ...","[-0.13476562, -0.12402344, -0.43945312, 1.0468...","[-0.47265625, -0.27148438, -1.0, 0.47265625, -...","[-1.359375, -0.15625, -1.203125, 0.33789062, -...","[-1.53125, 0.56640625, -0.5234375, 0.09667969,...","[-12.25, 0.47070312, 1.78125, 0.78515625, -1.7..."
2,in_domain_dev,1,The mechanical doll wriggled itself loose.,"[2.278125, -0.20976563, -0.38027343, 0.4071289...","[0.09804688, 0.02421875, -0.35742188, -0.31718...","[0.03046875, 0.5109375, 0.40625, -0.49726564, ...","[-0.2484375, -0.31621093, -0.008203125, -0.635...","[-0.1421875, -0.0053710938, -0.2444336, -0.475...","[0.17460938, -0.7871094, -0.0076171877, -0.515...","[0.043164063, -0.6929687, 0.51225585, -0.64609...",...,"[-0.4921875, 0.3984375, 0.296875, 0.49609375, ...","[-0.95703125, 0.5078125, -0.04663086, 0.773437...","[-1.2890625, 0.0087890625, 0.17578125, 0.56640...","[-1.25, 0.53125, -0.15820312, 0.76171875, -1.5...","[-0.00390625, 0.072265625, 0.03125, 0.828125, ...","[-0.13476562, -0.12402344, -0.43945312, 1.0468...","[-0.47265625, -0.27148438, -1.0, 0.47265625, -...","[-1.359375, -0.15625, -1.203125, 0.33789062, -...","[-1.53125, 0.56640625, -0.5234375, 0.09667969,...","[-12.25, 0.47070312, 1.78125, 0.78515625, -1.7..."
3,in_domain_dev,1,"If you had eaten more, you would want less.","[2.278125, -0.20976563, -0.38027343, 0.4071289...","[0.09804688, 0.02421875, -0.35742188, -0.31718...","[0.03046875, 0.5109375, 0.40625, -0.49726564, ...","[-0.2484375, -0.31621093, -0.008203125, -0.635...","[-0.1421875, -0.0053710938, -0.2444336, -0.475...","[0.17460938, -0.7871094, -0.0076171877, -0.515...","[0.043164063, -0.6929687, 0.51225585, -0.64609...",...,"[-0.4921875, 0.3984375, 0.296875, 0.49609375, ...","[-0.95703125, 0.5078125, -0.04663086, 0.773437...","[-1.2890625, 0.0087890625, 0.17578125, 0.56640...","[-1.25, 0.53125, -0.15820312, 0.76171875, -1.5...","[-0.00390625, 0.072265625, 0.03125, 0.828125, ...","[-0.13476562, -0.12402344, -0.43945312, 1.0468...","[-0.47265625, -0.27148438, -1.0, 0.47265625, -...","[-1.359375, -0.15625, -1.203125, 0.33789062, -...","[-1.53125, 0.56640625, -0.5234375, 0.09667969,...","[-12.25, 0.47070312, 1.78125, 0.78515625, -1.7..."
4,in_domain_dev,0,"As you eat the most, you want the least.","[2.278125, -0.20976563, -0.38027343, 0.4071289...","[0.09804688, 0.02421875, -0.35742188, -0.31718...","[0.03046875, 0.5109375, 0.40625, -0.49726564, ...","[-0.2484375, -0.31621093, -0.008203125, -0.635...","[-0.1421875, -0.0053710938, -0.2444336, -0.475...","[0.17460938, -0.7871094, -0.0076171877, -0.515...","[0.043164063, -0.6929687, 0.51225585, -0.64609...",...,"[-0.4921875, 0.3984375, 0.296875, 0.49609375, ...","[-0.95703125, 0.5078125, -0.04663086, 0.773437...","[-1.2890625, 0.0087890625, 0.17578125, 0.56640...","[-1.25, 0.53125, -0.15820312, 0.76171875, -1.5...","[-0.00390625, 0.072265625, 0.03125, 0.828125, ...","[-0.13476562, -0.12402344, -0.43945312, 1.0468...","[-0.47265625, -0.27148438, -1.0, 0.47265625, -...","[-1.359375, -0.15625, -1.203125, 0.33789062, -...","[-1.53125, 0.56640625, -0.5234375, 0.09667969,...","[-12.25, 0.47070312, 1.78125, 0.78515625, -1.7..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9589,in_domain_train,0,Poseidon appears to own a dragon,"[2.30625, -0.10625, -0.0109375, 0.54179686, 0....","[0.14335938, -0.01953125, -0.23359375, -0.0933...","[0.28828126, 0.33359376, 0.46914062, -0.246093...","[-0.16015625, -0.47460938, 0.33398438, -0.1605...","[-0.14414063, -0.23603515, 0.025341798, 0.0015...","[0.2859375, -0.63789064, 0.16025391, 0.0443237...","[0.14296874, -0.7, 0.39082032, -0.10507812, -0...",...,"[-0.3203125, 0.35546875, 0.26171875, 0.3339843...","[-0.79296875, 0.42578125, -0.0625, 0.59375, -1...","[-1.0859375, -0.16601562, 0.15722656, 0.396484...","[-1.125, 0.3828125, -0.13378906, 0.6328125, -1...","[0.061523438, -0.12695312, 0.060546875, 0.7851...","[0.0043945312, -0.24023438, -0.43554688, 0.988...","[-0.45507812, -0.53515625, -0.81640625, 0.4648...","[-1.46875, -0.48046875, -1.0390625, 0.26953125...","[-2.0625, 0.35546875, -0.5234375, 0.19921875, ...","[-15.375, 0.26953125, -3.09375, 0.7890625, -3...."
9590,in_domain_train,0,Digitize is my happiest memory,"[2.30625, -0.10625, -0.0109375, 0.54179686, 0....","[0.14335938, -0.01953125, -0.23359375, -0.0933...","[0.28828126, 0.33359376, 0.46914062, -0.246093...","[-0.16015625, -0.47460938, 0.33398438, -0.1605...","[-0.14414063, -0.23603515, 0.025341798, 0.0015...","[0.2859375, -0.63789064, 0.16025391, 0.0443237...","[0.14296874, -0.7, 0.39082032, -0.10507812, -0...",...,"[-0.3203125, 0.35546875, 0.26171875, 0.3339843...","[-0.79296875, 0.42578125, -0.0625, 0.59375, -1...","[-1.0859375, -0.16601562, 0.15722656, 0.396484...","[-1.125, 0.3828125, -0.13378906, 0.6328125, -1...","[0.061523438, -0.12695312, 0.060546875, 0.7851...","[0.0043945312, -0.24023438, -0.43554688, 0.988...","[-0.45507812, -0.53515625, -0.81640625, 0.4648...","[-1.46875, -0.48046875, -1.0390625, 0.26953125...","[-2.0625, 0.35546875, -0.5234375, 0.19921875, ...","[-15.375, 0.26953125, -3.09375, 0.7890625, -3...."
9591,in_domain_train,1,It is easy to slay the Gorgon.,"[2.30625, -0.10625, -0.0109375, 0.54179686, 0....","[0.14335938, -0.01953125, -0.23359375, -0.0933...","[0.28828126, 0.33359376, 0.46914062, -0.246093...","[-0.16015625, -0.47460938, 0.33398438, -0.1605...","[-0.14414063, -0.23603515, 0.025341798, 0.0015...","[0.2859375, -0.63789064, 0.16025391, 0.0443237...","[0.14296874, -0.7, 0.39082032, -0.10507812, -0...",...,"[-0.3203125, 0.35546875, 0.26171875, 0.3339843...","[-0.79296875, 0.42578125, -0.0625, 0.59375, -1...","[-1.0859375, -0.16601562, 0.15722656, 0.396484...","[-1.125, 0.3828125, -0.13378906, 0.6328125, -1...","[0.061523438, -0.12695312, 0.060546875, 0.7851...","[0.0043945312, -0.24023438, -0.43554688, 0.988...","[-0.45507812, -0.53515625, -0.81640625, 0.4648...","[-1.46875, -0.48046875, -1.0390625, 0.26953125...","[-2.0625, 0.35546875, -0.5234375, 0.19921875, ...","[-15.375, 0.26953125, -3.09375, 0.7890625, -3...."
9592,in_domain_train,1,I had the strangest feeling that I knew you.,"[2.30625, -0.10625, -0.0109375, 0.54179686, 0....","[0.14335938, -0.01953125, -0.23359375, -0.0933...","[0.28828126, 0.33359376, 0.46914062, -0.246093...","[-0.16015625, -0.47460938, 0.33398438, -0.1605...","[-0.14414063, -0.23603515, 0.025341798, 0.0015...","[0.2859375, -0.63789064, 0.16025391, 0.0443237...","[0.14296874, -0.7, 0.39082032, -0.10507812, -0...",...,"[-0.3203125, 0.35546875, 0.26171875, 0.3339843...","[-0.79296875, 0.42578125, -0.0625, 0.59375, -1...","[-1.0859375, -0.16601562, 0.15722656, 0.396484...","[-1.125, 0.3828125, -0.13378906, 0.6328125, -1...","[0.061523438, -0.12695312, 0.060546875, 0.7851...","[0.0043945312, -0.24023438, -0.43554688, 0.988...","[-0.45507812, -0.53515625, -0.81640625, 0.4648...","[-1.46875, -0.48046875, -1.0390625, 0.26953125...","[-2.0625, 0.35546875, -0.5234375, 0.19921875, ...","[-15.375, 0.26953125, -3.09375, 0.7890625, -3...."


In [None]:
activation_cola_df

model_2b

In [None]:
BATCH_SIZE = 64
text_column = 'source'

activation_cola_df_2b = extract_activations_df(cola_df, model_2b, tokenizer_2b, text_column, BATCH_SIZE)
save_activations_df(activation_cola_df_2b, 'cola', model_id_2b)

In [None]:
save_activations_df(activation_cola_df_2b, 'cola', model_id_2b)

Saving cola_t5gemma-2b-2b-ul2 to GDrive...
Saved cola_t5gemma-2b-2b-ul2


In [None]:
activation_cola_df_2b

Unnamed: 0,source,label,sentence,encoder_layer_1,encoder_layer_2,encoder_layer_3,encoder_layer_4,encoder_layer_5,encoder_layer_6,encoder_layer_7,...,decoder_layer_18,decoder_layer_19,decoder_layer_20,decoder_layer_21,decoder_layer_22,decoder_layer_23,decoder_layer_24,decoder_layer_25,decoder_layer_26,decoder_layer_27
0,in_domain_dev,1,The sailors rode the breeze clear of the rocks.,"[-0.20664063, -0.33125, -2.0703125, 0.18125, 1...","[-0.21699218, -0.6515625, -0.6003906, -0.14765...","[0.89140624, 0.5, 0.6953125, 0.39140624, -0.44...","[0.85, -0.6933594, 0.91796875, 0.70703125, -0....","[0.4421875, -0.13125, 0.6152344, -0.10585938, ...","[0.34765625, -0.840625, 0.6154297, -0.35234374...","[0.76894534, -0.503125, 0.0041015623, 0.997656...",...,"[0.0234375, -0.32421875, -0.10546875, -0.32421...","[-0.17578125, -0.12109375, 0.22851562, -0.5898...","[0.27734375, 0.53515625, -0.13085938, 0.25, -0...","[0.05859375, -0.515625, -0.796875, -0.515625, ...","[-0.44921875, -0.0703125, -0.30859375, -0.7812...","[-3.09375, 1.0546875, 1.6640625, -2.09375, 0.2...","[-3.65625, 2.609375, 1.6328125, -5.9375, 1.351...","[-2.9375, 2.71875, 1.9140625, -6.1875, 0.99218...","[2.015625, 1.765625, 2.953125, -6.65625, 0.298...","[0.25976562, -0.36328125, 1.0859375, -0.277343..."
1,in_domain_dev,1,The weights made the rope stretch over the pul...,"[-0.20664063, -0.33125, -2.0703125, 0.18125, 1...","[-0.21699218, -0.6515625, -0.6003906, -0.14765...","[0.89140624, 0.5, 0.6953125, 0.39140624, -0.44...","[0.85, -0.6933594, 0.91796875, 0.70703125, -0....","[0.4421875, -0.13125, 0.6152344, -0.10585938, ...","[0.34765625, -0.840625, 0.6154297, -0.35234374...","[0.76894534, -0.503125, 0.0041015623, 0.997656...",...,"[0.0234375, -0.32421875, -0.10546875, -0.32421...","[-0.17578125, -0.12109375, 0.22851562, -0.5898...","[0.27734375, 0.53515625, -0.13085938, 0.25, -0...","[0.05859375, -0.515625, -0.796875, -0.515625, ...","[-0.44921875, -0.0703125, -0.30859375, -0.7812...","[-3.09375, 1.0546875, 1.6640625, -2.09375, 0.2...","[-3.65625, 2.609375, 1.6328125, -5.9375, 1.351...","[-2.9375, 2.71875, 1.9140625, -6.1875, 0.99218...","[2.015625, 1.765625, 2.953125, -6.65625, 0.298...","[0.25976562, -0.36328125, 1.0859375, -0.277343..."
2,in_domain_dev,1,The mechanical doll wriggled itself loose.,"[-0.20664063, -0.33125, -2.0703125, 0.18125, 1...","[-0.21699218, -0.6515625, -0.6003906, -0.14765...","[0.89140624, 0.5, 0.6953125, 0.39140624, -0.44...","[0.85, -0.6933594, 0.91796875, 0.70703125, -0....","[0.4421875, -0.13125, 0.6152344, -0.10585938, ...","[0.34765625, -0.840625, 0.6154297, -0.35234374...","[0.76894534, -0.503125, 0.0041015623, 0.997656...",...,"[0.0234375, -0.32421875, -0.10546875, -0.32421...","[-0.17578125, -0.12109375, 0.22851562, -0.5898...","[0.27734375, 0.53515625, -0.13085938, 0.25, -0...","[0.05859375, -0.515625, -0.796875, -0.515625, ...","[-0.44921875, -0.0703125, -0.30859375, -0.7812...","[-3.09375, 1.0546875, 1.6640625, -2.09375, 0.2...","[-3.65625, 2.609375, 1.6328125, -5.9375, 1.351...","[-2.9375, 2.71875, 1.9140625, -6.1875, 0.99218...","[2.015625, 1.765625, 2.953125, -6.65625, 0.298...","[0.25976562, -0.36328125, 1.0859375, -0.277343..."
3,in_domain_dev,1,"If you had eaten more, you would want less.","[-0.20664063, -0.33125, -2.0703125, 0.18125, 1...","[-0.21699218, -0.6515625, -0.6003906, -0.14765...","[0.89140624, 0.5, 0.6953125, 0.39140624, -0.44...","[0.85, -0.6933594, 0.91796875, 0.70703125, -0....","[0.4421875, -0.13125, 0.6152344, -0.10585938, ...","[0.34765625, -0.840625, 0.6154297, -0.35234374...","[0.76894534, -0.503125, 0.0041015623, 0.997656...",...,"[0.0234375, -0.32421875, -0.10546875, -0.32421...","[-0.17578125, -0.12109375, 0.22851562, -0.5898...","[0.27734375, 0.53515625, -0.13085938, 0.25, -0...","[0.05859375, -0.515625, -0.796875, -0.515625, ...","[-0.44921875, -0.0703125, -0.30859375, -0.7812...","[-3.09375, 1.0546875, 1.6640625, -2.09375, 0.2...","[-3.65625, 2.609375, 1.6328125, -5.9375, 1.351...","[-2.9375, 2.71875, 1.9140625, -6.1875, 0.99218...","[2.015625, 1.765625, 2.953125, -6.65625, 0.298...","[0.25976562, -0.36328125, 1.0859375, -0.277343..."
4,in_domain_dev,0,"As you eat the most, you want the least.","[-0.20664063, -0.33125, -2.0703125, 0.18125, 1...","[-0.21699218, -0.6515625, -0.6003906, -0.14765...","[0.89140624, 0.5, 0.6953125, 0.39140624, -0.44...","[0.85, -0.6933594, 0.91796875, 0.70703125, -0....","[0.4421875, -0.13125, 0.6152344, -0.10585938, ...","[0.34765625, -0.840625, 0.6154297, -0.35234374...","[0.76894534, -0.503125, 0.0041015623, 0.997656...",...,"[0.0234375, -0.32421875, -0.10546875, -0.32421...","[-0.17578125, -0.12109375, 0.22851562, -0.5898...","[0.27734375, 0.53515625, -0.13085938, 0.25, -0...","[0.05859375, -0.515625, -0.796875, -0.515625, ...","[-0.44921875, -0.0703125, -0.30859375, -0.7812...","[-3.09375, 1.0546875, 1.6640625, -2.09375, 0.2...","[-3.65625, 2.609375, 1.6328125, -5.9375, 1.351...","[-2.9375, 2.71875, 1.9140625, -6.1875, 0.99218...","[2.015625, 1.765625, 2.953125, -6.65625, 0.298...","[0.25976562, -0.36328125, 1.0859375, -0.277343..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9589,in_domain_train,0,Poseidon appears to own a dragon,"[-0.2878906, -1.54375, -1.6765625, 0.3578125, ...","[-0.065820314, -1.0578125, -0.6777344, 0.11210...","[1.0828125, 0.27578124, 0.8769531, 1.171875, -...","[1.2980468, -0.1796875, 0.7171875, 1.3070313, ...","[0.6660156, 0.378125, 0.23671874, 0.98183596, ...","[1.1953125, -0.7453125, -0.042382814, 0.404687...","[1.3224609, -0.63671875, -0.34140626, 1.260937...",...,"[0.140625, -0.13476562, -0.29492188, -0.464843...","[-0.0859375, -0.12109375, 0.060302734, -0.7031...","[0.31445312, 0.50390625, -0.40625, 0.12109375,...","[0.0703125, -0.67578125, -0.80078125, -0.65234...","[-0.37109375, -0.33984375, -0.44140625, -0.796...","[-2.65625, 0.4453125, 1.03125, -1.703125, 0.26...","[-2.9375, 1.6484375, 1.0, -5.5, 0.984375, -0.4...","[-2.296875, 1.8125, 1.0546875, -6.03125, 0.593...","[2.46875, 1.5625, 1.8515625, -5.625, -0.162109...","[0.64453125, -1.2421875, 1.0078125, 0.9375, -0..."
9590,in_domain_train,0,Digitize is my happiest memory,"[-0.2878906, -1.54375, -1.6765625, 0.3578125, ...","[-0.065820314, -1.0578125, -0.6777344, 0.11210...","[1.0828125, 0.27578124, 0.8769531, 1.171875, -...","[1.2980468, -0.1796875, 0.7171875, 1.3070313, ...","[0.6660156, 0.378125, 0.23671874, 0.98183596, ...","[1.1953125, -0.7453125, -0.042382814, 0.404687...","[1.3224609, -0.63671875, -0.34140626, 1.260937...",...,"[0.140625, -0.13476562, -0.29492188, -0.464843...","[-0.0859375, -0.12109375, 0.060302734, -0.7031...","[0.31445312, 0.50390625, -0.40625, 0.12109375,...","[0.0703125, -0.67578125, -0.80078125, -0.65234...","[-0.37109375, -0.33984375, -0.44140625, -0.796...","[-2.65625, 0.4453125, 1.03125, -1.703125, 0.26...","[-2.9375, 1.6484375, 1.0, -5.5, 0.984375, -0.4...","[-2.296875, 1.8125, 1.0546875, -6.03125, 0.593...","[2.46875, 1.5625, 1.8515625, -5.625, -0.162109...","[0.64453125, -1.2421875, 1.0078125, 0.9375, -0..."
9591,in_domain_train,1,It is easy to slay the Gorgon.,"[-0.2878906, -1.54375, -1.6765625, 0.3578125, ...","[-0.065820314, -1.0578125, -0.6777344, 0.11210...","[1.0828125, 0.27578124, 0.8769531, 1.171875, -...","[1.2980468, -0.1796875, 0.7171875, 1.3070313, ...","[0.6660156, 0.378125, 0.23671874, 0.98183596, ...","[1.1953125, -0.7453125, -0.042382814, 0.404687...","[1.3224609, -0.63671875, -0.34140626, 1.260937...",...,"[0.140625, -0.13476562, -0.29492188, -0.464843...","[-0.0859375, -0.12109375, 0.060302734, -0.7031...","[0.31445312, 0.50390625, -0.40625, 0.12109375,...","[0.0703125, -0.67578125, -0.80078125, -0.65234...","[-0.37109375, -0.33984375, -0.44140625, -0.796...","[-2.65625, 0.4453125, 1.03125, -1.703125, 0.26...","[-2.9375, 1.6484375, 1.0, -5.5, 0.984375, -0.4...","[-2.296875, 1.8125, 1.0546875, -6.03125, 0.593...","[2.46875, 1.5625, 1.8515625, -5.625, -0.162109...","[0.64453125, -1.2421875, 1.0078125, 0.9375, -0..."
9592,in_domain_train,1,I had the strangest feeling that I knew you.,"[-0.2878906, -1.54375, -1.6765625, 0.3578125, ...","[-0.065820314, -1.0578125, -0.6777344, 0.11210...","[1.0828125, 0.27578124, 0.8769531, 1.171875, -...","[1.2980468, -0.1796875, 0.7171875, 1.3070313, ...","[0.6660156, 0.378125, 0.23671874, 0.98183596, ...","[1.1953125, -0.7453125, -0.042382814, 0.404687...","[1.3224609, -0.63671875, -0.34140626, 1.260937...",...,"[0.140625, -0.13476562, -0.29492188, -0.464843...","[-0.0859375, -0.12109375, 0.060302734, -0.7031...","[0.31445312, 0.50390625, -0.40625, 0.12109375,...","[0.0703125, -0.67578125, -0.80078125, -0.65234...","[-0.37109375, -0.33984375, -0.44140625, -0.796...","[-2.65625, 0.4453125, 1.03125, -1.703125, 0.26...","[-2.9375, 1.6484375, 1.0, -5.5, 0.984375, -0.4...","[-2.296875, 1.8125, 1.0546875, -6.03125, 0.593...","[2.46875, 1.5625, 1.8515625, -5.625, -0.162109...","[0.64453125, -1.2421875, 1.0078125, 0.9375, -0..."


## UD_English-EWT

In [11]:
!pip install conllu
!wget https://raw.githubusercontent.com/UniversalDependencies/UD_English-EWT/refs/heads/master/en_ewt-ud-train.conllu

Collecting conllu
  Downloading conllu-6.0.0-py3-none-any.whl.metadata (21 kB)
Downloading conllu-6.0.0-py3-none-any.whl (16 kB)
Installing collected packages: conllu
Successfully installed conllu-6.0.0
--2025-11-24 10:22:37--  https://raw.githubusercontent.com/UniversalDependencies/UD_English-EWT/refs/heads/master/en_ewt-ud-train.conllu
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 15029817 (14M) [text/plain]
Saving to: ‘en_ewt-ud-train.conllu’


2025-11-24 10:22:37 (173 MB/s) - ‘en_ewt-ud-train.conllu’ saved [15029817/15029817]



In [14]:
from conllu import parse_incr

def load_conllu(path):
    with open(path, encoding="utf-8") as f:
        for tokenlist in parse_incr(f):
            yield {
                "text": tokenlist.metadata.get("text", ""),
                "tokens": [t["form"] for t in tokenlist],
                "token_id": [t["id"] for t in tokenlist],
                "upos": [t["upostag"] for t in tokenlist],
                #"xpos": [t["xpostag"] for t in tokenlist],
            }

train = list(load_conllu("en_ewt-ud-train.conllu"))
#dev = list(load_conllu("UD_English-EWT/en_ewt-ud-dev.conllu"))
#test = list(load_conllu("UD_English-EWT/en_ewt-ud-test.conllu"))

In [15]:
items_to_df = {k:[] for k in train[0].keys()}

for item in train:
  for k, v in item.items():
    items_to_df[k].append(v)
  #items_to_df['split'].append('train')

ewt_df = pd.DataFrame(items_to_df)
ewt_df

Unnamed: 0,text,tokens,token_id,upos
0,Al-Zaman : American forces killed Shaikh Abdul...,"[Al, -, Zaman, :, American, forces, killed, Sh...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...","[PROPN, PUNCT, PROPN, PUNCT, ADJ, NOUN, VERB, ..."
1,[This killing of a respected cleric will be ca...,"[[, This, killing, of, a, respected, cleric, w...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...","[PUNCT, DET, NOUN, ADP, DET, ADJ, NOUN, AUX, A..."
2,DPA: Iraqi authorities announced that they had...,"[DPA, :, Iraqi, authorities, announced, that, ...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...","[PROPN, PUNCT, ADJ, NOUN, VERB, SCONJ, PRON, A..."
3,Two of them were being run by 2 officials of t...,"[Two, of, them, were, being, run, by, 2, offic...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...","[NUM, ADP, PRON, AUX, AUX, VERB, ADP, NUM, NOU..."
4,"The MoI in Iraq is equivalent to the US FBI, s...","[The, MoI, in, Iraq, is, equivalent, to, the, ...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...","[DET, PROPN, ADP, PROPN, AUX, ADJ, ADP, DET, P..."
...,...,...,...,...
12539,"Of course, they couldn't call him either to as...","[Of, course, ,, they, couldn't, could, n't, ca...","[1, 2, 3, 4, (5, -, 6), 5, 6, 7, 8, 9, 10, 11,...","[ADP, NOUN, PUNCT, PRON, _, AUX, PART, VERB, P..."
12540,On Monday I called and again it was a big to-d...,"[On, Monday, I, called, and, again, it, was, a...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...","[ADP, PROPN, PRON, VERB, CCONJ, ADV, PRON, AUX..."
12541,Supposedly they will be holding it for me this...,"[Supposedly, they, will, be, holding, it, for,...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, (13, -...","[ADV, PRON, AUX, AUX, VERB, PRON, ADP, PRON, D..."
12542,The employees at this Sear's are completely ap...,"[The, employees, at, this, Sear's, are, comple...","[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, (11, -, 12), 1...","[DET, NOUN, ADP, DET, PROPN, AUX, ADV, ADJ, CC..."


here we have 2 pos taggings upos (more general) and xpos (more specific).
we will consider the upos for simplicity

now we define a function to convert the ewt dataset to a token version where the upos and xpos are more clear

In [18]:
def convert_ewt_to_token(ewt_df):
  token_ewt_dict = {
      'tokens': [],
      'sentence_id': [],
      #'xpos': [],
      'upos': [],
      'token_id':[]
  }

  for row in ewt_df.iterrows():
    for token, upos, token_id in zip(row[1]['tokens'], row[1]['upos'], row[1]['token_id']):
      if isinstance(token_id, int):
        token_ewt_dict['tokens'].append(token)
        token_ewt_dict['sentence_id'].append(row[0])
        #token_ewt_dict['xpos'].append(xpos)
        token_ewt_dict['upos'].append(upos)
        token_ewt_dict['token_id'].append(token_id)

  return pd.DataFrame(token_ewt_dict)

In [19]:
token_ewt_df = convert_ewt_to_token(ewt_df)
token_ewt_df

Unnamed: 0,tokens,sentence_id,upos,token_id
0,Al,0,PROPN,1
1,-,0,PUNCT,2
2,Zaman,0,PROPN,3
3,:,0,PUNCT,4
4,American,0,ADJ,5
...,...,...,...,...
204572,on,12543,ADP,22
204573,my,12543,PRON,23
204574,car,12543,NOUN,24
204575,),12543,PUNCT,25


now let's define the function to get each token representation.

First let's understand how the tokenizer works and how to adapt the tokenizer tokens with the dataset tokens (token at word level)

In [None]:
sentence_id = 0
text = ewt_df['text'][sentence_id]
word_tokens = token_ewt_df[token_ewt_df['sentence_id']==sentence_id]['tokens'].to_list()

inputs = tokenizer_b(text, return_tensors="pt").to(model_b.device)
tokens = [t.replace('▁','') for t in tokenizer_b.convert_ids_to_tokens(inputs.input_ids[0])]

print('Original sentence tokens: ', word_tokens)
print('Tokens from the tokenizer: ', tokens)
print('Tokens different from the original sentence: ', [t for t in tokens if t not in word_tokens])

# this dictionary will contain an index and a list of sub tokens that compose the word
subtoken_dict = {i:[] for i in range(len(word_tokens))}
wt_count=0
subword=''
for i in range(len(tokens)):
  if subword+tokens[i] == word_tokens[wt_count]:
    subtoken_dict[wt_count].append(tokens[i])
    wt_count+=1
    subword=''
  else:
    subtoken_dict[wt_count].append(tokens[i])
    subword=subword+tokens[i]

print(subtoken_dict)

['Al', '-', 'Z', 'aman', '▁:', '▁American', '▁forces', '▁killed', '▁Shaikh', '▁Abdullah', '▁al', '-', 'Ani', ',', '▁the', '▁preacher', '▁at', '▁the', '▁mosque', '▁in', '▁the', '▁town', '▁of', '▁Q', 'aim', ',', '▁near', '▁the', '▁Syrian', '▁border', '.']
Original sentence tokens:  ['Al', '-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.']
Tokens from the tokenizer:  ['Al', '-', 'Z', 'aman', '▁:', '▁American', '▁forces', '▁killed', '▁Shaikh', '▁Abdullah', '▁al', '-', 'Ani', ',', '▁the', '▁preacher', '▁at', '▁the', '▁mosque', '▁in', '▁the', '▁town', '▁of', '▁Q', 'aim', ',', '▁near', '▁the', '▁Syrian', '▁border', '.']
Tokens different from the original sentence:  ['Z', 'aman', '▁:', '▁American', '▁forces', '▁killed', '▁Shaikh', '▁Abdullah', '▁al', '▁the', '▁preacher', '▁at', '▁the', '▁mosque', '▁in', '▁the', '▁town', '▁of', '▁

defining a function to handle this behaviour

In [24]:
def get_subtokenization(sentence_id, tokenizer):

    text = ewt_df['text'][sentence_id]
    word_tokens = token_ewt_df[token_ewt_df['sentence_id']==sentence_id]['tokens'].tolist()

    # tokenizzazione con SentencePiece
    inputs = tokenizer(text, return_tensors="pt").to(model_b.device)
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

    # rimuove "▁"
    tokens = [t.replace("▁", "") for t in tokens]

    # dizionario finale
    subtoken_dict = {i: [] for i in range(len(word_tokens))}

    wt_count = 0        # indice parola UD
    subword = ""        # buffer concatenazione

    for tok in tokens:

        # se abbiamo già assegnato tutte le parole UD → stop
        if wt_count >= len(word_tokens):
            break

        target = word_tokens[wt_count]

        # aggiungo comunque il subtoken alla parola corrente (come nel tuo codice)
        subtoken_dict[wt_count].append(tok)
        subword += tok

        # match perfetto → avanza alla prossima parola
        if subword == target:
            wt_count += 1
            subword = ""

        # altrimenti continua (stessa logica tua)
        # NON introduco controlli addizionali
        # NON modifico l’algoritmo
        # semplicemente lascio scorrere come voleva il tuo approccio originale

    return subtoken_dict


let's check if it works

In [27]:
num_sentences = len(ewt_df)
c = 0
problematic_indexes = []

for i in tqdm(range(num_sentences)):
  subtoken_dict = get_subtokenization(i, tokenizer_b)

  sentence_token_check = []

  for k,v in subtoken_dict.items():
    subtoken_list = ''.join(v)
    sentence_token_check.append(subtoken_list)
  word_tokens = token_ewt_df[token_ewt_df['sentence_id']==i]['tokens'].to_list()
  if sentence_token_check != word_tokens:
    c=c+1 # c is the number of sentences where the subtoken aggregation differs from the 'dataset' tokenization
    problematic_indexes.append(i) # sentence to be removed later

assert(len(problematic_indexes)==c)

print('\nproblematic sentences: ', c)
print(f'problematic sentences (%): {c/(num_sentences)*100:.2f}%')

100%|██████████| 12544/12544 [00:20<00:00, 622.49it/s]


problematic sentences:  1647
problematic sentences (%): 13.13%





apparentely 1647 out of the 12543 sentences in the dataset have problems with this subtoken aggregation operation.

This happens because the tokenizer does not divide some elements, for example ":]" is kept by the tokenizer where in the dataset these are two tokens "." and "]".

we can consider removing these sentence as the dataset is still big enough for out scope

In [28]:
# drop these sentences from the original dataset
ewt_df.drop(index=problematic_indexes, inplace=True)
ewt_df.reset_index(drop=True, inplace=True)

# convert the new dataset in the token version
token_ewt_df = convert_ewt_to_token(ewt_df)
token_ewt_df

Unnamed: 0,tokens,sentence_id,upos,token_id
0,Al,0,PROPN,1
1,-,0,PUNCT,2
2,Zaman,0,PROPN,3
3,:,0,PUNCT,4
4,American,0,ADJ,5
...,...,...,...,...
162434,be,10896,AUX,19
162435,a,10896,DET,20
162436,huge,10896,ADJ,21
162437,ordeal,10896,NOUN,22


finally it's time to get the token representations with the subtoken considerations defined above

In [30]:
index = 0

sentence = ewt_df['text'][index]
inputs = tokenizer_b(sentence, return_tensors="pt").to(model_b.device)

model_b.eval()
start_token_id = tokenizer_b.bos_token_id
decoder_input_ids = torch.tensor([[start_token_id]], device=model_b.device)

with torch.no_grad():
    outputs = model_b(
        **inputs,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=True,
    )

encoder_hidden_states = [o.cpu() for o in outputs.encoder_hidden_states]
decoder_hidden_states = [o.cpu() for o in outputs.decoder_hidden_states]

layer = 2
subtokens = get_subtokenization(index, tokenizer_b)
token_representation = []
encoder_hidden_states = encoder_hidden_states[layer].squeeze(0)

token_index = 0
for k,v in subtokens.items():
  n = len(v)
  if n>1:
    mean_tensors_list = []
    for i in range(n):
      mean_tensors_list.append(encoder_hidden_states[token_index+i])
    mean = torch.mean(torch.stack(mean_tensors_list), dim=0)
    token_representation.append(mean)
  else:
    token_representation.append(encoder_hidden_states[token_index])
  token_index+=n

len(token_representation), len(token_ewt_df[token_ewt_df['sentence_id'] == index])

(29, 29)

In [31]:
# non batched
def get_word_representation_df(model, tokenizer):
  sentences = ewt_df['text'].to_list()
  num_encoder_layers = model_b.config.encoder.num_hidden_layers+1 # considering also the embedding layer

  model.eval()

  word_representation_dict = {}
  for e in range(num_encoder_layers):
    word_representation_dict[f'encoder_layer_{e+1}'] = []

  print('Starting to process sentences ...')
  for sentence_idx, sentence in tqdm(enumerate(sentences), total=len(sentences), desc='Processing sentences to get word representation'):
    inputs = tokenizer(sentence, return_tensors="pt").to(model.device)
    start_token_id = tokenizer.bos_token_id
    decoder_input_ids = torch.tensor([[start_token_id]], device=model.device)
    with torch.no_grad():
      outputs = model_b(
          **inputs,
          decoder_input_ids=decoder_input_ids,
          output_hidden_states=True,
      )
    encoder_hidden_states = torch.stack([e.cpu().squeeze(0) for e in outputs.encoder_hidden_states])

    subtokens = get_subtokenization(sentence_idx, tokenizer)
    for e in range(num_encoder_layers):
        token_representation = word_representation_dict[f'encoder_layer_{e+1}']
        ehs = encoder_hidden_states[e]
        token_index = 0
        for k,v in subtokens.items():
          n = len(v)
          if n>1:
            mean_tensors_list = []
            for i in range(n):
              mean_tensors_list.append(ehs[token_index+i])
            mean = torch.mean(torch.stack(mean_tensors_list), dim=0).to(torch.float32).numpy()
            token_representation.append(mean)
          else:
            token_representation.append(ehs[token_index].to(torch.float32).numpy())
          token_index+=n

  token_representation_df =pd.DataFrame(word_representation_dict)
  print('Sentence processed')
  return token_representation_df


token_representation_ewt_df = get_word_representation_df(model_b, tokenizer_b)
token_representation_ewt_df

Starting to process sentences ...


Processing sentences to get word representation: 100%|██████████| 10897/10897 [15:39<00:00, 11.60it/s]


Sentence processed


Unnamed: 0,encoder_layer_1,encoder_layer_2,encoder_layer_3,encoder_layer_4,encoder_layer_5,encoder_layer_6,encoder_layer_7,encoder_layer_8,encoder_layer_9,encoder_layer_10,encoder_layer_11,encoder_layer_12,encoder_layer_13
0,"[1.953125, -0.12451172, -1.4453125, 0.21972656...","[-0.1171875, 0.40820312, -1.21875, -1.453125, ...","[0.0234375, 0.06640625, -0.63671875, -0.703125...","[-0.29101562, 0.11328125, -0.953125, -0.554687...","[-0.001953125, 0.390625, -0.91796875, -0.85937...","[-0.78515625, 0.46875, -0.35546875, -0.6367187...","[-0.86328125, 0.42578125, -0.4140625, -0.31054...","[-1.125, 0.953125, -0.9375, -0.0625, 0.6171875...","[-0.3359375, 0.625, -2.796875, -1.65625, 1.625...","[-2.765625, 0.65625, -3.265625, -1.8515625, 1....","[-2.375, 0.55859375, -2.59375, -1.734375, 1.85...","[-2.828125, 1.6171875, -2.28125, -0.7890625, 2...","[-6.6875, 2.765625, -6.4375, -6.46875, 7.65625..."
1,"[0.6640625, -0.10253906, 0.0016937256, 0.51562...","[-0.640625, 0.46875, -0.091796875, -0.69921875...","[-0.84765625, 0.390625, -0.36328125, -0.238281...","[-0.47265625, 0.072265625, -0.65625, -0.470703...","[-0.89453125, 0.8828125, 0.2421875, -1.546875,...","[-1.1953125, 0.5546875, 0.73828125, -1.4453125...","[-1.4609375, 0.51953125, 1.2578125, -1.890625,...","[-1.671875, 0.69921875, -0.6015625, -0.9296875...","[-1.8203125, 1.0234375, -1.546875, -0.76953125...","[-1.9765625, 0.15039062, -3.078125, -1.640625,...","[-2.1875, 1.546875, -2.203125, -1.3359375, 2.7...","[-3.8125, 2.375, -0.734375, -0.73828125, 5.312...","[-13.5625, 5.71875, -3.921875, -2.265625, 13.0..."
2,"[3.4375, 0.12451172, -1.8359375, 1.140625, 0.8...","[-0.140625, 0.43164062, -1.609375, 0.022460938...","[-0.3984375, -0.4453125, -0.95703125, 0.519531...","[-0.6953125, -0.099609375, -1.234375, 0.285156...","[-0.05029297, -0.0146484375, -1.0859375, 0.296...","[-0.13867188, -0.4140625, -0.28320312, 0.5, 0....","[-0.8984375, -0.22753906, -0.26953125, 0.49609...","[-1.515625, -0.3984375, -0.32226562, 0.1289062...","[-0.72265625, -0.2421875, -0.40234375, 0.05859...","[-2.375, -0.6015625, -1.453125, 0.3125, 0.1699...","[-1.859375, -0.546875, -1.515625, 0.5546875, 0...","[-2.15625, 0.69140625, -0.7734375, 1.8046875, ...","[-5.5, 1.9375, -3.78125, 1.6484375, 3.109375, ..."
3,"[0.51171875, -0.60546875, 0.94921875, 1.34375,...","[-0.60546875, -0.015625, 0.62109375, 0.1230468...","[0.29296875, -0.00390625, 1.203125, -0.921875,...","[0.91796875, 0.091308594, 0.55078125, -0.17187...","[0.76953125, 0.28515625, 1.015625, -0.04150390...","[-0.16210938, -0.47460938, 0.28125, -0.3476562...","[0.053222656, -0.62890625, 0.25, 0.09375, -1.9...","[-0.51953125, 0.091796875, -1.6875, -0.0078125...","[1.3125, -0.953125, -0.6953125, -0.087890625, ...","[1.3984375, -0.4609375, -0.31054688, -1.25, -1...","[2.375, -0.38671875, 0.546875, -1.390625, -2.2...","[3.078125, 1.109375, 1.578125, -1.921875, -2.4...","[7.15625, 1.3984375, 4.5, -7.125, -3.875, 4.09..."
4,"[1.65625, 1.375, -0.32421875, 1.5546875, 0.730...","[-0.19140625, 0.77734375, -0.81640625, 0.78125...","[-0.39453125, 0.14257812, -0.57421875, 0.77734...","[-0.640625, 0.765625, -1.171875, 0.51953125, 1...","[-0.0859375, 1.3203125, -2.171875, 0.6171875, ...","[0.3046875, 0.53125, -2.359375, 0.5546875, 2.5...","[1.953125, 0.6640625, -0.48632812, 0.40234375,...","[1.0703125, -0.04296875, 0.49023438, 0.265625,...","[1.3125, -1.2109375, 0.72265625, 0.8359375, 2....","[0.9921875, -0.84375, -0.36328125, 0.15136719,...","[2.578125, -1.109375, 0.020996094, 0.31054688,...","[1.3125, -1.1875, 1.015625, 2.21875, 2.796875,...","[4.65625, -3.703125, -0.02355957, 2.0625, 6.06..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
162434,"[0.27148438, -0.20605469, -0.033203125, -0.605...","[-0.15625, -0.49804688, -0.22070312, -0.339843...","[0.890625, 0.21875, 0.33984375, -0.515625, 0.1...","[0.73046875, -0.068359375, -0.21875, 0.0737304...","[0.390625, 0.1484375, 0.19335938, 0.123046875,...","[-0.53125, -0.29492188, -0.6484375, 0.04199218...","[-0.95703125, 0.14160156, 0.015625, 0.03125, 0...","[-0.79296875, -0.9296875, -0.7265625, -0.35351...","[0.3203125, 0.515625, -2.65625, -1.4375, 0.167...","[-0.4765625, 0.34960938, -2.234375, -2.15625, ...","[0.421875, 2.1875, -1.875, -2.359375, 0.016601...","[-0.18359375, 1.40625, -2.015625, -2.4375, 0.4...","[-1.6015625, -0.66015625, -4.71875, -6.8125, 1..."
162435,"[0.85546875, -0.33789062, -0.16894531, -0.1474...","[-0.48828125, 1.1484375, 0.0146484375, -0.5898...","[-0.41015625, 0.79296875, 0.13085938, -0.12890...","[0.42382812, 0.6328125, -0.15820312, -0.726562...","[-0.7265625, 1.3515625, 0.25585938, -0.2871093...","[-0.091796875, 0.75390625, 0.5703125, 0.097656...","[-0.6953125, -0.19921875, 0.46484375, -0.78906...","[-0.5625, -0.56640625, 1.796875, -0.5625, -0.1...","[0.0546875, 2.84375, -2.109375, -0.078125, -0....","[1.2890625, 1.578125, -0.03149414, -1.171875, ...","[2.84375, 3.71875, 1.984375, -1.453125, -2.812...","[6.25, 5.375, 4.8125, -0.4921875, -7.625, 1.93...","[5.125, 2.609375, 1.8125, -0.21289062, -4.8125..."
162436,"[0.8515625, -0.36914062, 0.8203125, 0.33789062...","[-0.3359375, -0.703125, -0.234375, -0.02148437...","[-0.49804688, -0.6796875, -0.040039062, 0.1406...","[0.34375, -1.3828125, -0.41796875, 0.41015625,...","[0.82421875, -0.12890625, -0.12988281, -0.0390...","[0.03125, -0.6875, -0.3359375, 0.21289062, 1.5...","[-1.28125, -1.3125, 0.28125, 0.3984375, 1.8125...","[-1.3046875, -1.4609375, -1.21875, 0.12890625,...","[-0.49609375, -1.25, -1.828125, 0.15820312, 1....","[-0.109375, -2.1875, -1.3671875, 1.15625, 1.96...","[0.58984375, -1.671875, -0.4140625, 1.21875, 2...","[-0.16796875, -2.125, -0.875, 1.5625, 3.59375,...","[0.17480469, -7.875, -3.5625, 1.546875, 7.5625..."
162437,"[2.765625, -0.484375, -0.62890625, -0.39453125...","[0.15625, -0.20605469, -0.96875, -0.24316406, ...","[0.98046875, -0.14648438, -0.26171875, -0.5234...","[0.88671875, 0.28515625, -0.037109375, -0.6562...","[1.1484375, -0.15625, -0.375, -0.6875, 0.71093...","[-0.078125, -0.26171875, -0.43164062, -0.76171...","[-1.1171875, -0.85546875, -0.2109375, -1.46093...","[-0.671875, -1.015625, -0.9453125, -1.2890625,...","[0.46289062, -0.01953125, -1.859375, -0.921875...","[0.076171875, 0.484375, -2.125, -1.578125, 1.6...","[1.1328125, 1.0078125, -1.8828125, -2.296875, ...","[0.99609375, 0.0703125, -1.9765625, -2.21875, ...","[0.85546875, -1.546875, -3.828125, -5.84375, 6..."


now let's consider the labels

we will consider the base label, with the xpos and upos tags and also the control task

In [34]:
# defining the POS tags
#xpos_labels=token_ewt_df['xpos'].unique()
upos_labels=token_ewt_df['upos'].unique()

#xpos_tags={x:i for i,x in enumerate(xpos_labels)}
upos_tags={u:i for i,u in enumerate(upos_labels)}

# inserting the tags in the dataset
#token_ewt_df['xpos_tag']=token_ewt_df['xpos'].map(lambda xpos: xpos_tags[xpos])
token_ewt_df['upos_tag']=token_ewt_df['upos'].map(lambda upos: upos_tags[upos])
token_ewt_df.drop(columns=['upos', 'token_id'], inplace=True)

In [35]:
unique_tokens = list(token_ewt_df['tokens'].unique())
np.random.shuffle(unique_tokens)

#num_xpos_tags = len(xpos_tags)
num_upos_tags = len(upos_tags)

#token_ct_map_xpos={x:i%num_xpos_tags for i,x in enumerate(unique_tokens)} # token control task map for xpos
token_ct_map_upos={x:i%num_upos_tags for i,x in enumerate(unique_tokens)}

# adding the control task tags to the dataframe
#token_ewt_df['ct_xpos_tag']=token_ewt_df['tokens'].map(lambda x: token_ct_map_xpos[x])
token_ewt_df['ct_upos_tag']=token_ewt_df['tokens'].map(lambda u: token_ct_map_upos[u])

In [37]:
token_ewt_df = pd.concat([token_ewt_df, token_representation_ewt_df], axis=1)
token_ewt_df

Unnamed: 0,tokens,sentence_id,upos_tag,ct_upos_tag,encoder_layer_1,encoder_layer_2,encoder_layer_3,encoder_layer_4,encoder_layer_5,encoder_layer_6,encoder_layer_7,encoder_layer_8,encoder_layer_9,encoder_layer_10,encoder_layer_11,encoder_layer_12,encoder_layer_13
0,Al,0,0,16,"[1.953125, -0.12451172, -1.4453125, 0.21972656...","[-0.1171875, 0.40820312, -1.21875, -1.453125, ...","[0.0234375, 0.06640625, -0.63671875, -0.703125...","[-0.29101562, 0.11328125, -0.953125, -0.554687...","[-0.001953125, 0.390625, -0.91796875, -0.85937...","[-0.78515625, 0.46875, -0.35546875, -0.6367187...","[-0.86328125, 0.42578125, -0.4140625, -0.31054...","[-1.125, 0.953125, -0.9375, -0.0625, 0.6171875...","[-0.3359375, 0.625, -2.796875, -1.65625, 1.625...","[-2.765625, 0.65625, -3.265625, -1.8515625, 1....","[-2.375, 0.55859375, -2.59375, -1.734375, 1.85...","[-2.828125, 1.6171875, -2.28125, -0.7890625, 2...","[-6.6875, 2.765625, -6.4375, -6.46875, 7.65625..."
1,-,0,1,11,"[0.6640625, -0.10253906, 0.0016937256, 0.51562...","[-0.640625, 0.46875, -0.091796875, -0.69921875...","[-0.84765625, 0.390625, -0.36328125, -0.238281...","[-0.47265625, 0.072265625, -0.65625, -0.470703...","[-0.89453125, 0.8828125, 0.2421875, -1.546875,...","[-1.1953125, 0.5546875, 0.73828125, -1.4453125...","[-1.4609375, 0.51953125, 1.2578125, -1.890625,...","[-1.671875, 0.69921875, -0.6015625, -0.9296875...","[-1.8203125, 1.0234375, -1.546875, -0.76953125...","[-1.9765625, 0.15039062, -3.078125, -1.640625,...","[-2.1875, 1.546875, -2.203125, -1.3359375, 2.7...","[-3.8125, 2.375, -0.734375, -0.73828125, 5.312...","[-13.5625, 5.71875, -3.921875, -2.265625, 13.0..."
2,Zaman,0,0,12,"[3.4375, 0.12451172, -1.8359375, 1.140625, 0.8...","[-0.140625, 0.43164062, -1.609375, 0.022460938...","[-0.3984375, -0.4453125, -0.95703125, 0.519531...","[-0.6953125, -0.099609375, -1.234375, 0.285156...","[-0.05029297, -0.0146484375, -1.0859375, 0.296...","[-0.13867188, -0.4140625, -0.28320312, 0.5, 0....","[-0.8984375, -0.22753906, -0.26953125, 0.49609...","[-1.515625, -0.3984375, -0.32226562, 0.1289062...","[-0.72265625, -0.2421875, -0.40234375, 0.05859...","[-2.375, -0.6015625, -1.453125, 0.3125, 0.1699...","[-1.859375, -0.546875, -1.515625, 0.5546875, 0...","[-2.15625, 0.69140625, -0.7734375, 1.8046875, ...","[-5.5, 1.9375, -3.78125, 1.6484375, 3.109375, ..."
3,:,0,1,8,"[0.51171875, -0.60546875, 0.94921875, 1.34375,...","[-0.60546875, -0.015625, 0.62109375, 0.1230468...","[0.29296875, -0.00390625, 1.203125, -0.921875,...","[0.91796875, 0.091308594, 0.55078125, -0.17187...","[0.76953125, 0.28515625, 1.015625, -0.04150390...","[-0.16210938, -0.47460938, 0.28125, -0.3476562...","[0.053222656, -0.62890625, 0.25, 0.09375, -1.9...","[-0.51953125, 0.091796875, -1.6875, -0.0078125...","[1.3125, -0.953125, -0.6953125, -0.087890625, ...","[1.3984375, -0.4609375, -0.31054688, -1.25, -1...","[2.375, -0.38671875, 0.546875, -1.390625, -2.2...","[3.078125, 1.109375, 1.578125, -1.921875, -2.4...","[7.15625, 1.3984375, 4.5, -7.125, -3.875, 4.09..."
4,American,0,2,7,"[1.65625, 1.375, -0.32421875, 1.5546875, 0.730...","[-0.19140625, 0.77734375, -0.81640625, 0.78125...","[-0.39453125, 0.14257812, -0.57421875, 0.77734...","[-0.640625, 0.765625, -1.171875, 0.51953125, 1...","[-0.0859375, 1.3203125, -2.171875, 0.6171875, ...","[0.3046875, 0.53125, -2.359375, 0.5546875, 2.5...","[1.953125, 0.6640625, -0.48632812, 0.40234375,...","[1.0703125, -0.04296875, 0.49023438, 0.265625,...","[1.3125, -1.2109375, 0.72265625, 0.8359375, 2....","[0.9921875, -0.84375, -0.36328125, 0.15136719,...","[2.578125, -1.109375, 0.020996094, 0.31054688,...","[1.3125, -1.1875, 1.015625, 2.21875, 2.796875,...","[4.65625, -3.703125, -0.02355957, 2.0625, 6.06..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
162434,be,10896,9,10,"[0.27148438, -0.20605469, -0.033203125, -0.605...","[-0.15625, -0.49804688, -0.22070312, -0.339843...","[0.890625, 0.21875, 0.33984375, -0.515625, 0.1...","[0.73046875, -0.068359375, -0.21875, 0.0737304...","[0.390625, 0.1484375, 0.19335938, 0.123046875,...","[-0.53125, -0.29492188, -0.6484375, 0.04199218...","[-0.95703125, 0.14160156, 0.015625, 0.03125, 0...","[-0.79296875, -0.9296875, -0.7265625, -0.35351...","[0.3203125, 0.515625, -2.65625, -1.4375, 0.167...","[-0.4765625, 0.34960938, -2.234375, -2.15625, ...","[0.421875, 2.1875, -1.875, -2.359375, 0.016601...","[-0.18359375, 1.40625, -2.015625, -2.4375, 0.4...","[-1.6015625, -0.66015625, -4.71875, -6.8125, 1..."
162435,a,10896,5,3,"[0.85546875, -0.33789062, -0.16894531, -0.1474...","[-0.48828125, 1.1484375, 0.0146484375, -0.5898...","[-0.41015625, 0.79296875, 0.13085938, -0.12890...","[0.42382812, 0.6328125, -0.15820312, -0.726562...","[-0.7265625, 1.3515625, 0.25585938, -0.2871093...","[-0.091796875, 0.75390625, 0.5703125, 0.097656...","[-0.6953125, -0.19921875, 0.46484375, -0.78906...","[-0.5625, -0.56640625, 1.796875, -0.5625, -0.1...","[0.0546875, 2.84375, -2.109375, -0.078125, -0....","[1.2890625, 1.578125, -0.03149414, -1.171875, ...","[2.84375, 3.71875, 1.984375, -1.453125, -2.812...","[6.25, 5.375, 4.8125, -0.4921875, -7.625, 1.93...","[5.125, 2.609375, 1.8125, -0.21289062, -4.8125..."
162436,huge,10896,2,8,"[0.8515625, -0.36914062, 0.8203125, 0.33789062...","[-0.3359375, -0.703125, -0.234375, -0.02148437...","[-0.49804688, -0.6796875, -0.040039062, 0.1406...","[0.34375, -1.3828125, -0.41796875, 0.41015625,...","[0.82421875, -0.12890625, -0.12988281, -0.0390...","[0.03125, -0.6875, -0.3359375, 0.21289062, 1.5...","[-1.28125, -1.3125, 0.28125, 0.3984375, 1.8125...","[-1.3046875, -1.4609375, -1.21875, 0.12890625,...","[-0.49609375, -1.25, -1.828125, 0.15820312, 1....","[-0.109375, -2.1875, -1.3671875, 1.15625, 1.96...","[0.58984375, -1.671875, -0.4140625, 1.21875, 2...","[-0.16796875, -2.125, -0.875, 1.5625, 3.59375,...","[0.17480469, -7.875, -3.5625, 1.546875, 7.5625..."
162437,ordeal,10896,3,15,"[2.765625, -0.484375, -0.62890625, -0.39453125...","[0.15625, -0.20605469, -0.96875, -0.24316406, ...","[0.98046875, -0.14648438, -0.26171875, -0.5234...","[0.88671875, 0.28515625, -0.037109375, -0.6562...","[1.1484375, -0.15625, -0.375, -0.6875, 0.71093...","[-0.078125, -0.26171875, -0.43164062, -0.76171...","[-1.1171875, -0.85546875, -0.2109375, -1.46093...","[-0.671875, -1.015625, -0.9453125, -1.2890625,...","[0.46289062, -0.01953125, -1.859375, -0.921875...","[0.076171875, 0.484375, -2.125, -1.578125, 1.6...","[1.1328125, 1.0078125, -1.8828125, -2.296875, ...","[0.99609375, 0.0703125, -1.9765625, -2.21875, ...","[0.85546875, -1.546875, -3.828125, -5.84375, 6..."


In [39]:
token_ewt_df['encoder_layer_7'][100].shape

(768,)

## MultiNLI

In [None]:
multinli_dataset = load_dataset("nyu-mll/multi_nli")

## ParaRel