In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import functional as F
import random

model_name = 'flax-community/papuGaPT2'
device = 'cuda'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

  return self.fget.__get__(instance, owner)()


In [2]:
print(model.num_parameters() / 1024 / 1024)

118.675048828125


In [4]:
text = "Ta wiewiórka jest sprytna. Ale ze mną nie wygra!"
ids = tokenizer(text, return_tensors='pt')['input_ids'][0]
tokens = [tokenizer.decode(n) for n in ids]
print(tokens, len(tokens))
print(*tokens, sep='')

['Ta', ' wie', 'wi', 'órka', ' jest', ' spry', 'tna', '.', ' Ale', ' ze', ' mną', ' nie', ' wygra', '!'] 14
Ta wiewiórka jest sprytna. Ale ze mną nie wygra!


In [5]:
input_ids = tokenizer(text, return_tensors='pt')['input_ids'].to(device)
with torch.no_grad():
    output = model(input_ids=input_ids)
print (output.logits.shape)    

torch.Size([1, 14, 50257])


In [8]:
#out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
#out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
#out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 0, torch.tensor([[0, 0], [1, 0]]))

tensor([[1, 2],
        [3, 2]])

In [12]:
a = torch.tensor( [[1,2,3], [4,5,6]])
for n in range(3):
    print (a.unsqueeze(n))

tensor([[[1, 2, 3],
         [4, 5, 6]]])
tensor([[[1, 2, 3]],

        [[4, 5, 6]]])
tensor([[[1],
         [2],
         [3]],

        [[4],
         [5],
         [6]]])


In [18]:
but_last_logits = output.logits[:, :-1, :]
but_first_labels = input_ids[:, 1:]

logp = F.log_softmax(but_last_logits, dim=-1)
labels = but_first_labels.unsqueeze(2)
print (logp.shape, labels.shape)

gathered = torch.gather(logp, 2, labels)
print (gathered.shape)

    

torch.Size([1, 13, 50257]) torch.Size([1, 13, 1])
torch.Size([1, 13, 1])


In [21]:
def log_probs_from_logits(logits, labels):
    logp = F.log_softmax(logits, dim=-1)
    
    logp_label = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logp_label
    
            
def sentence_prob(sentence_txt):
    input_ids = tokenizer(sentence_txt, return_tensors='pt')['input_ids'].to(device)
    with torch.no_grad():
        output = model(input_ids=input_ids)
        log_probs = log_probs_from_logits(output.logits[:, :-1, :], input_ids[:, 1:])
        seq_log_probs = torch.sum(log_probs)
    return seq_log_probs.cpu().numpy()    


In [22]:
words = 'Ala ma dwa tłuste koty i ślicznego kanarka'.split()

for i in range(1,len(words)+1):
    txt = ' '.join(words[:i])
    print (txt, sentence_prob(txt))

Ala -7.5044913
Ala ma -11.766838
Ala ma dwa -17.383055
Ala ma dwa tłuste -28.23169
Ala ma dwa tłuste koty -33.956753
Ala ma dwa tłuste koty i -35.879913
Ala ma dwa tłuste koty i ślicznego -45.49717
Ala ma dwa tłuste koty i ślicznego kanarka -52.10604


In [23]:
options = ['kota', 'psa', 'słonia', 'nosorożca', 'krowę', 'konia', 'pereturbację', 'perturbację']

for opt in options:
    print (opt, sentence_prob('Ala ma ' + opt))

kota -18.377638
psa -20.369291
słonia -25.758102
nosorożca -27.199951
krowę -22.857933
konia -23.28639
pereturbację -43.974747
perturbację -32.565228


In [24]:
def normalized_sentence_prob(txt):
    length = len(tokenizer(text, return_tensors='pt')['input_ids'][0])
    return sentence_prob(txt) / length

for opt in options:
    print (opt, normalized_sentence_prob('Ala ma ' + opt))

kota -1.3126884187970842
psa -1.4549493789672852
słonia -1.8398644583565849
nosorożca -1.9428536551339286
krowę -1.6327095031738281
konia -1.663313593183245
pereturbację -3.1410533360072543
perturbację -2.32608767918178


In [25]:
for opt in options:
    print (opt, normalized_sentence_prob('Ala nie ma ' + opt))

kota -1.779935019356864
psa -1.8222836085728236
słonia -2.0925118582589284
nosorożca -2.3730441502162387
krowę -2.592665263584682
konia -1.8791062491280692
pereturbację -4.007100786481585
perturbację -3.0986230032784596
