In [1]:
import datasets
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "roneneldan/TinyStories-28M"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer.pad_token = tokenizer.eos_token


In [125]:
def get_sentence_ppl(sentence, bos=False):
    # Tokenize the sentence
    loss_fct = CrossEntropyLoss(reduction="none")

    input = tokenizer(
            sentence,
            add_special_tokens=False,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)
    input_ids = input["input_ids"].to(device)
    attn_mask = input["attention_mask"].to(device)
    if bos:
        bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]]).to(device)
        input_ids = torch.cat([bos_tokens_tensor, input_ids], dim=1).to(device)
        attn_mask = torch.cat(
                        [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
                    )
    
    # Get log probabilities from the model
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attn_mask,  labels=input_ids).logits
        
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous()
    # shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
    ce_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # ce_loss = loss_fct(shift_logits.transpose(1, 2), shift_labels)
    perplexities = torch.exp(ce_loss)
    return shift_labels, perplexities



In [127]:
sentence = "Once upon a time, there was a friendly bird named Bob. Bob lived near a big cliff. Every day, Bob"
token_ids, ppl = get_sentence_ppl(sentence, bos=True)
# print(result, np.mean(list(result.values())[1:]))
# print tokenwise_ppl
token_ids = token_ids.cpu().numpy().tolist()[0]
ppl = ppl.cpu().numpy().tolist()
for token_id, ppl in zip(token_ids, ppl):
    print(tokenizer.decode(token_id), ppl)
print(np.mean(ppl))
    


Once 3786911.5
 upon 1.0253034830093384
 a 1.0000678300857544
 time 1.0001612901687622
, 1.0017625093460083
 there 1.0782173871994019
 was 1.0065215826034546
 a 1.0115015506744385
 friendly 478.8094787597656
 bird 105.92422485351562
 named 1.4727593660354614
 Bob 5.523642539978027
. 1.6974177360534668
 Bob 1.0217628479003906
 lived 2.947664737701416
 near 20.135814666748047
 a 1.1660058498382568
 big 1.3090853691101074
 cliff 6.568779945373535
. 1.0390673875808716
 Every 5.105368614196777
 day 1.0388258695602417
, 1.0028589963912964
 Bob 2.163477897644043
2.163477897644043


In [130]:
import evaluate
perplexity = evaluate.load("perplexity", module_type="metric")
results = perplexity.compute(model_id='roneneldan/TinyStories-28M',
                             add_start_token=False,
                             predictions=[sentence])
print(results)

Using pad_token, but it is not set yet.
100%|██████████| 1/1 [00:00<00:00,  7.36it/s]

{'perplexities': [2.667530059814453], 'mean_perplexity': 2.667530059814453}





In [131]:
# load tinystories
path = f'../tinystories_words/tinystories_rows_gpt4.txt'
tinystories = []
with open(path, 'r') as f:
    lines = f.readlines()
    for line in lines:
        tinystories.append(line.strip())
path = f'../tinystories_words/tinystories_rows_gpt4.txt'
tinystories_gpt4 = []
with open(path, 'r') as f:
    lines = f.readlines()
    for line in lines:
        tinystories_gpt4.append(line.strip())

In [132]:
# get ppl of a random story
import random
random_story = random.choice(tinystories)
print(random_story)
results = perplexity.compute(model_id='roneneldan/TinyStories-28M',
                                add_start_token=False,
                                predictions=[random_story])
print(results)

Once upon a time, there was a big pumpkin. It was hot outside, and the pumpkin was sad. It wanted to be cool like the other pumpkins in the garden. The pumpkin saw a bird sitting on a tree branch and asked, "Can you teach me how to be cool like the other pumpkins?"The bird said, "I can teach you how to make a big shade. Then you will be cool like the other pumpkins." The pumpkin was happy to learn from the bird. So, they worked together to make a big shade using leaves and sticks.But then, a strong wind came and blew the shade away. The pumpkin was sad again. The bird told the pumpkin, "It's okay. Sometimes things don't work out the way we want them to. The important thing is to keep trying and never give up."The pumpkin learned to be strong and keep trying, even when things were hard. And soon, the weather changed, and the pumpkin was cool and happy like the other pumpkins. The moral of the story is to never give up and keep trying, even when things are tough.


Using pad_token, but it is not set yet.
100%|██████████| 1/1 [00:01<00:00,  1.31s/it]

{'perplexities': [2.9053306579589844], 'mean_perplexity': 2.9053306579589844}





In [133]:
# get token-level ppl of a random story
token_ids, ppl = get_sentence_ppl(random_story, bos=True)
token_ids = token_ids.cpu().numpy().tolist()[0]
ppl = ppl.cpu().numpy().tolist()
for token_id, ppl in zip(token_ids, ppl):
    print(tokenizer.decode(token_id), ppl)
print(np.mean(ppl))

Once 3786918.75
 upon 1.0253037214279175
 a 1.0000678300857544
 time 1.0001612901687622
, 1.0017625093460083
 there 1.0782172679901123
 was 1.0065215826034546
 a 1.0115015506744385
 big 10.045510292053223
 pumpkin 543.5086669921875
. 1.3449971675872803
 It 1.2260874509811401
 was 1.0084749460220337
 hot 2340.401123046875
 outside 1.0120939016342163
, 1.9742428064346313
 and 1.8300126791000366
 the 1.0309021472930908
 pumpkin 1.0310642719268799
 was 1.666839838027954
 sad 11.699323654174805
. 1.2397280931472778
 It 1.191320776939392
 wanted 1.0967988967895508
 to 1.0277469158172607
 be 2.52840518951416
 cool 4.993949890136719
 like 7.950096130371094
 the 1.0992697477340698
 other 1.086223840713501
 pump 1.0618982315063477
kins 1.250278353691101
 in 14.077720642089844
 the 1.0086685419082642
 garden 4.384374141693115
. 1.0216915607452393
 The 5.038606643676758
 pumpkin 5.195713043212891
 saw 16.96390151977539
 a 1.251333475112915
 bird 54.263065338134766
 sitting 32.592655181884766
 on 1

In [134]:
# load the tb and fb stories
tb_stories = []
fb_stories = []
pos_tb_stories = []
pos_fb_stories = []
neg_tb_stories = []
neg_fb_stories = []

tb_cond_file  = f'../../data/conditions/tinytom-v3/0_forward_belief_true_belief/corrected.txt'
fb_cond_file  = f'../../data/conditions/tinytom-v3/0_forward_belief_false_belief/corrected.txt'
with open(tb_cond_file, 'r') as f:
    lines = f.readlines()
    for l, line in enumerate(lines):
        tb_stories.append(line.strip())

with open(fb_cond_file, 'r') as f:
    lines = f.readlines()
    for l, line in enumerate(lines):        
        fb_stories.append(line.strip())


In [135]:
random_story = random.choice(tb_stories)
token_ids, ppl = get_sentence_ppl(random_story, bos=True)
token_ids = token_ids.cpu().numpy().tolist()[0]
ppl = ppl.cpu().numpy().tolist()
for token_id, ppl in zip(token_ids, ppl):
    print(tokenizer.decode(token_id), ppl)
print(np.mean(ppl))

Once 3786918.75
 upon 1.0253037214279175
 a 1.0000678300857544
 time 1.0001612901687622
, 1.0017625093460083
 in 26.137815475463867
 a 1.0432683229446411
 lovely 255.73524475097656
 park 147.6114044189453
 full 1578.48095703125
 of 1.004035472869873
 pretty 59.270225524902344
 flowers 1.1265602111816406
, 6.635289192199707
 there 1.136163353919983
 was 1.290573239326477
 a 1.004797101020813
 little 1.7714539766311646
 girl 1.9726015329360962
 named 1.0066407918930054
 Queen 16705913.0
ie 1.6073189973831177
. 1.0317797660827637
 She 6.450061798095703
 needed 1701482.75
 a 11.8333740234375
 green 9422.484375
 leaf 5.010408401489258
 for 4.829154014587402
 her 1.0386905670166016
 art 878.48779296875
 project 2.2610509395599365
. 1.047688364982605
 She 4.079145908355713
 spotted 51498.6015625
 a 1.6199395656585693
 leaf 25.733139038085938
 that 9.338224411010742
 was 1.553296446800232
 very 19.732328414916992
 green 13.203532218933105
 and 3.0927369594573975
 fresh 5605.99462890625
. 1.009