In [1]:
import torch
import random
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModel
from torch.nn import functional as F
from transformers import logging

logging.set_verbosity_error()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
def cos(a, b):
    return a.dot(b) / (a.dot(a) * b.dot(b)) ** 0.5

### a)

In [3]:
model_name = 'flax-community/papuGaPT2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

def tokenize(word):
    ids = tokenizer(word, return_tensors='pt')['input_ids'][0]
    return [tokenizer.decode(n) for n in ids]

emb = model.transformer.wte.weight.detach().cpu().numpy()

In [4]:
def get_context_free_token_embedding(word):
    token_ids = tokenizer.encode(' ' + word)
    whole_word_embbeding = np.sum(np.stack([emb[id] for id in token_ids]), axis=0) / len(token_ids)
    return whole_word_embbeding

### b)

In [5]:
model_name_encoder = "allegro/herbert-base-cased"
tokenizer_encoder = AutoTokenizer.from_pretrained(model_name_encoder)
model_encoder = AutoModel.from_pretrained(model_name_encoder).to(device)

In [6]:
def get_context_token_embedding(word):
    input_ids = tokenizer_encoder(word, return_tensors='pt')['input_ids'].to(device)
    output = model_encoder(input_ids=input_ids)
    return output.last_hidden_state.detach().cpu().numpy()[0,0,:]

### testing

In [7]:
male = 'męzczyzna'
female = 'kobieta'

uncle = 'wujek'
aunt = 'ciocia'
king = 'król'
queen = 'królowa'

In [16]:
dog = 'pies'
cat = 'kot'
babydog = 'szczeniak'
wolf = 'wilk'

### a)

feminine vector

In [14]:
male_emb = get_context_free_token_embedding(male)
female_emb = get_context_free_token_embedding(female)
feminity = female_emb - male_emb

res1 = get_context_free_token_embedding(aunt) - get_context_free_token_embedding(uncle)
res2 = get_context_free_token_embedding(queen) - get_context_free_token_embedding(king)

print(cos(feminity, res1), cos(feminity, res2))

-0.00035056360659630706 0.134997398412024


abx

In [17]:
dog_emb = get_context_free_token_embedding(dog)
cat_emb = get_context_free_token_embedding(cat)
babydog_emb = get_context_free_token_embedding(babydog)
wolf_emb = get_context_free_token_embedding(wolf)

print(cos(dog_emb, babydog_emb), cos(cat_emb, babydog_emb))
print(cos(dog_emb, wolf_emb), cos(cat_emb, wolf_emb))

0.6544517114198857 0.502280773741618
0.5424712195440118 0.49661500595562624


### b)

feminine vector

In [18]:
male_emb = get_context_token_embedding(male)
female_emb = get_context_token_embedding(female)
feminity = female_emb - male_emb

temp1 = get_context_token_embedding(aunt) - get_context_token_embedding(uncle)
temp2 = get_context_token_embedding(queen) - get_context_token_embedding(king)

print(cos(feminity, temp1), cos(feminity, temp2))

-0.11467732098042435 -0.14195259783483538


abx

In [19]:
dog_emb = get_context_token_embedding(dog)
cat_emb = get_context_token_embedding(cat)
babygod_emb = get_context_token_embedding(babydog)
wolf_emb = get_context_token_embedding(wolf)

print(cos(dog_emb, babygod_emb), cos(cat_emb, babygod_emb))
print(cos(dog_emb, wolf_emb), cos(cat_emb, wolf_emb))

0.8315782639165055 0.8348302009099379
0.9753929077482025 0.9832040527884389


In [None]:
...