In [None]:
import torch
import random
from word_emb_evaluation import words, ABX
from utils import load_model, configure_environment


configure_environment(device="cuda")
bert, bert_tokenizer, device = load_model(model_name="allegro/herbert-base-cased")
papuga, papuga_tokenizer, device = load_model(model_name="flax-community/papuGaPT2", causal=True)

Seed set to 122348


Device set to cuda


Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.bias', 'cls.sso.sso_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def papuga_word_embedding(word):
	def papuga_tokenize(word):
		ids = papuga_tokenizer(word, return_tensors='pt')['input_ids'][0]
		return [papuga_tokenizer.decode(n) for n in ids]

	emb = papuga.transformer.wte.weight.detach().cpu().numpy()

	tokens = papuga_tokenize(word)
	return torch.tensor(emb[[papuga_tokenizer.encode(token) for token in tokens]], dtype=float).mean(dim=0)

In [18]:
def bert_word_embedding(word):
    input_ids = bert_tokenizer(word, return_tensors='pt')['input_ids'] #.to(device)
    output = bert(input_ids=input_ids)
    return output.last_hidden_state[0,0,:]

In [19]:
papuga_wembedds = {word: papuga_word_embedding(word) for word in words}
bert_wembedds = {word: bert_word_embedding(word) for word in words}

In [20]:
ABX(papuga_wembedds)

PROBLEMS: 0.0
Start
TOTAL SCORE: 0.58707


In [21]:
ABX(bert_wembedds)

PROBLEMS: 0.0
Start
TOTAL SCORE: 0.591602


In [None]:
def character_swap(word):
	idx1, idx2 = random.sample(range(len(word)), 2)
	word = list(word)
	word[idx1], word[idx2] = word[idx2], word[idx1]
	return "".join(word)

swapped_bert_wembedds = {word: bert_word_embedding(character_swap(word)) for word in words}

In [28]:
ABX(swapped_bert_wembedds)

PROBLEMS: 0.0
Start
TOTAL SCORE: 0.527092


In [29]:
qwerty_neighbors = {
    'q': ['w', 'a'], 'w': ['q', 'e', 'a', 's'], 'e': ['w', 'r', 's', 'd'],
    'r': ['e', 't', 'd', 'f'], 't': ['r', 'y', 'f', 'g'], 'y': ['t', 'u', 'g', 'h'],
    'u': ['y', 'i', 'h', 'j'], 'i': ['u', 'o', 'j', 'k'], 'o': ['i', 'p', 'k', 'l'],
    'p': ['o', 'l'], 'a': ['q', 'w', 's', 'z'], 's': ['a', 'w', 'e', 'z', 'x'],
    'd': ['s', 'e', 'r', 'x', 'c'], 'f': ['d', 'r', 't', 'c', 'v'], 'g': ['f', 't', 'y', 'v', 'b'],
    'h': ['g', 'y', 'u', 'b', 'n'], 'j': ['h', 'u', 'i', 'n', 'm'], 'k': ['j', 'i', 'o', 'm'],
    'l': ['k', 'o', 'p'], 'z': ['a', 's', 'x'], 'x': ['z', 's', 'd', 'c'], 'c': ['x', 'd', 'f', 'v'],
    'v': ['c', 'f', 'g', 'b'], 'b': ['v', 'g', 'h', 'n'], 'n': ['b', 'h', 'j', 'm'],
    'm': ['n', 'j', 'k']
}

def random_qwerty_swap(word, swap_count=1):
    # Convert the word to a list of characters for mutability
    word_chars = list(word)
    indices = list(range(len(word_chars)))
    random.shuffle(indices)  # Shuffle indices to choose random positions

    for _ in range(min(swap_count, len(word))):  # Ensure we don't exceed the word length
        idx = indices.pop()  # Take a random index
        char = word_chars[idx].lower()  # Get the character and make it lowercase
        if char in qwerty_neighbors:  # Check if the char is in our mapping
            replacement = random.choice(qwerty_neighbors[char])
            word_chars[idx] = replacement

    return ''.join(word_chars)


In [30]:
qwerty_bert_wembedds = {word: bert_word_embedding(random_qwerty_swap(word)) for word in words}

In [31]:
ABX(qwerty_bert_wembedds)

PROBLEMS: 0.0
Start
TOTAL SCORE: 0.528666
