## Generate Probibilities for bigrams

In [7]:
import json
import pickle

In [8]:
with open("tinystories_ngrams/bigram_counts.json", "r") as f:
    bigram_counts = json.load(f)

total_count = sum(bigram_counts.values())

for bigram in bigram_counts.keys():
    bigram_counts[bigram] = bigram_counts[bigram] / total_count

bigram_probs = bigram_counts

In [None]:
with open("tinystories_ngrams/bigram_probs.pkl", "wb") as f:
    pickle.dump(bigram_probs, f)

## Group bigram counts by 1st token

In [10]:
with open("tinystories_ngrams/bigram_counts.json", "r") as f:
    bigram_counts = json.load(f)

with open("tinystories_ngrams/unigram_counts.json", "r") as f:
    unigram_counts = json.load(f)

with open("tinystories_ngrams/vocab.json", "r") as f:
    vocab = json.load(f)

In [13]:
from tqdm import tqdm

In [14]:
bigram_counts_by_unigram = {}
for unigram in unigram_counts.keys():
    bigram_counts_by_unigram[unigram] = {}

for unigram in tqdm(unigram_counts.keys()):
    for token in vocab.keys():
        bigram = f"{unigram},{token}"
        if bigram in bigram_counts:
            bigram_counts_by_unigram[unigram][token] = bigram_counts[bigram]

  0%|          | 0/24346 [00:00<?, ?it/s]

100%|██████████| 24346/24346 [11:02<00:00, 36.74it/s]


In [None]:
with open("tinystories_ngrams/bigram_counts_by_unigram.json", "w") as f:
    json.dump(bigram_counts_by_unigram, f, indent=2, ensure_ascii=False)

## Generate Expected Rewards for Next Token (using 1st token and 1st two tokens)

In [19]:
import pickle
import json

with open("tinystories_ngrams/trigram_probs.pkl", "rb") as f:
    trigram_probs = pickle.load(f)["trigram_probs"]
with open("tinystories_ngrams/bigram_probs.pkl", "rb") as f:
    bigram_probs = pickle.load(f)

with open("tinystories_ngrams/trigram_counts.json", "rb") as f:
    trigram_counts = json.load(f)
with open("tinystories_ngrams/bigram_counts_by_unigram.json", "rb") as f:
    bigram_counts_by_unigram = json.load(f)

In [20]:
import math
from tqdm import tqdm

In [None]:
expected_rewards = {}

for bigram in tqdm(trigram_counts.keys()):
    for token in trigram_counts[bigram].keys():
        prob = trigram_probs.get(f"{bigram},{token}", 0)
        if prob == 0:
            continue
        reward = prob * (-math.log(prob))
        expected_rewards[bigram] = expected_rewards.get(bigram, 0) + reward

for unigram in tqdm(bigram_counts_by_unigram.keys()):
    for token in bigram_counts_by_unigram[unigram].keys():
        prob = bigram_probs.get(f"{unigram},{token}", 0)
        if prob == 0:
            continue
        reward = prob * (-math.log(prob))
        expected_rewards[unigram] = expected_rewards.get(unigram, 0) + reward

100%|██████████| 24346/24346 [00:04<00:00, 5895.54it/s] 


In [23]:
expected_rewards

{'One,Ġday': 0.6957748205788136,
 'Ġday,,': 3.6434930226621867,
 ',,Ġa': 4.5354780287545875,
 'Ġa,Ġlittle': 2.416975918355384,
 'Ġlittle,Ġgirl': 3.0042260911417022,
 'Ġgirl,Ġnamed': 2.101139163022167,
 'Ġnamed,ĠLily': 0.6342225621690725,
 'ĠLily,Ġfound': 1.162952574520215,
 'Ġfound,Ġa': 4.930036360330063,
 'Ġa,Ġneedle': 2.3289587361121065,
 'Ġneedle,Ġin': 2.1602553840450347,
 'Ġin,Ġher': 4.210069106505545,
 'Ġher,Ġroom': 2.1615689191751573,
 'Ġroom,.': 2.865417727849886,
 '.,ĠShe': 4.526950467023217,
 'ĠShe,Ġknew': 2.2792399136879427,
 'Ġknew,Ġit': 1.5141753640393256,
 'Ġit,Ġwas': 4.366212918758365,
 'Ġwas,Ġdifficult': 1.9951212188868483,
 'Ġdifficult,Ġto': 4.430455508433934,
 'Ġto,Ġplay': 2.3758468989622785,
 'Ġplay,Ġwith': 3.2386583785820755,
 'Ġwith,Ġit': 2.755509383690754,
 'Ġit,Ġbecause': 1.577947680104791,
 'Ġbecause,Ġit': 2.4116845298271805,
 'Ġwas,Ġsharp': 1.3608371614836838,
 'Ġsharp,.': 3.053336878257181,
 '.,ĠLily': 3.580670512256623,
 'ĠLily,Ġwanted': 0.39568924092670477,
 

In [25]:
with open("tinystories_ngrams/expected_rewards.pkl", "wb") as f:
    pickle.dump(expected_rewards, f)