In [1]:
from pathlib import Path
import pandas as pd

path_input_csv = Path("../../input/santa-2024/sample_submission.csv")
path_save = Path("./save")
path_save.mkdir(parents=True, exist_ok=True)
path_model = Path("../../input/gemma-2/")

df = pd.read_csv(path_input_csv)


In [2]:
from evaluation import PerplexityCalculator

calculator = PerplexityCalculator(model_path=str(path_model))

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:07<00:00,  1.05it/s]


In [3]:
import numpy as np
import tqdm

score_memo = {}


def get_perplexity(words):
    if words in score_memo:
        return score_memo[words]
    score = calculator.get_perplexity(words)
    score_memo[words] = score
    return score


def save_text(n_idx, text, verbose=0):
    path_save_idx = path_save / f"{n_idx:04d}"
    if not path_save_idx.exists():
        path_save_idx.mkdir()
    score = calculator.get_perplexity(text)
    if verbose:
        print(f"score:{score:.4f}")
    path_save_text = path_save_idx / f"{score:.4f}.txt"

    with path_save_text.open("w") as f:
        f.write(text)


def greedy(n_idx):
    text = df.iloc[n_idx, 1]
    words = text.split()

    n_words = len(words)
    print(f"number of words: {n_words}")

    state = [([word], get_perplexity(word)) for word in words]

    for _ in tqdm.trange(n_words - 1):
        assert 1 <= len(state) <= n_words
        gain_best = -np.inf
        size_state = len(state)
        idx_pair_best = None
        words_new = None
        score_new = None
        for i in range(size_state):
            for j in range(size_state):
                if i == j:
                    continue

                words1, score1 = state[i]
                words2, score2 = state[j]

                for k in range(len(words1)):
                    words12 = words1[:k] + words2 + words1[k:]
                    score12 = get_perplexity(" ".join(words12))

                    gain = score1 + score2 - score12

                    if gain > gain_best:
                        gain_best = gain
                        idx_pair_best = (i, j)
                        words_new = words12
                        score_new = score12

        assert idx_pair_best is not None

        state_nxt = []
        for i in range(size_state):
            if i in idx_pair_best:
                continue
            state_nxt.append(state[i])
        state_nxt.append((words_new, score_new))
        state = state_nxt

        score_total = sum(score for _, score in state)

        print(f"gain: {gain_best}")
        print(f"total score: {score_total:.4f}")
        print(f"{words1} + {words2}")

    assert len(state) == 1
    words, score = state[0]

    # for i in tqdm.trange(n_words):
    #     # best word and best place to insert
    #     score_best = np.inf
    #     state_best = None
    #     word_best = None

    #     for word in words_unused:
    #         for i in range(len(state) + 1):
    #             state_new = state[:i] + [word] + state[i:]
    #             score = calculator.get_perplexity(" ".join(state_new))

    #             if score < score_best:
    #                 score_best = score
    #                 state_best = state_new
    #                 word_best = word

    #     assert state_best is not None
    #     print(f"best score: {score_best}")
    #     print(f"added word: {word_best}")

    #     state = state_best
    #     words_unused.remove(word_best)
    #     words_used.add(word_best)

    #     print(state)

    # score = calculator.get_perplexity(" ".join(state))

    return words, score


for n_idx in range(len(df)):
    words, score = greedy(n_idx)
    save_text(n_idx, " ".join(words), verbose=1)

number of words: 10


 11%|█         | 1/9 [00:14<01:57, 14.70s/it]

gain: 9748.206139080066
total score: 282.9190
['scrooge'] + ['reindeer']


 22%|██▏       | 2/9 [00:18<00:58,  8.36s/it]

gain: 196.35273349210138
total score: 86.5662
['gingerbread', 'mistletoe'] + ['scrooge']


 33%|███▎      | 3/9 [00:23<00:39,  6.66s/it]

gain: 55.00916866027889
total score: 31.5571
['chimney', 'gingerbread', 'mistletoe'] + ['scrooge']


 44%|████▍     | 4/9 [00:28<00:30,  6.03s/it]

gain: 7.77959538873674
total score: 23.7775
['chimney', 'gingerbread', 'mistletoe', 'fireplace'] + ['scrooge']


 56%|█████▌    | 5/9 [00:31<00:19,  4.92s/it]

gain: 7.7497407373801614
total score: 16.0277
['ornament', 'reindeer'] + ['chimney', 'gingerbread', 'mistletoe', 'fireplace']


 67%|██████▋   | 6/9 [00:36<00:14,  4.85s/it]

gain: 3.27100360895854
total score: 12.7567
['chimney', 'gingerbread', 'mistletoe', 'ornament', 'reindeer', 'fireplace'] + ['scrooge']


 78%|███████▊  | 7/9 [00:38<00:08,  4.04s/it]

gain: 1.2128709348376043
total score: 11.5439
['scrooge', 'family'] + ['chimney', 'gingerbread', 'mistletoe', 'ornament', 'reindeer', 'fireplace']


 89%|████████▉ | 8/9 [00:40<00:03,  3.46s/it]

gain: 2.073484730251876
total score: 9.4704
['scrooge', 'advent', 'family'] + ['chimney', 'gingerbread', 'mistletoe', 'ornament', 'reindeer', 'fireplace']


100%|██████████| 9/9 [00:42<00:00,  4.70s/it]

gain: -26.04280308876031
total score: 35.5132
['scrooge', 'advent', 'elf', 'family'] + ['chimney', 'gingerbread', 'mistletoe', 'ornament', 'reindeer', 'fireplace']
score:35.5132
number of words: 20



  5%|▌         | 1/19 [00:46<14:01, 46.77s/it]

gain: 9873.899063430837
total score: 1425.5060
['and'] + ['laugh']


 11%|█         | 2/19 [00:55<06:55, 24.46s/it]

gain: 1199.893818598337
total score: 225.6122
['mistletoe', 'walk'] + ['and']


 16%|█▌        | 3/19 [01:06<04:54, 18.38s/it]

gain: 149.80945339656668
total score: 75.8028
['mistletoe', 'walk', 'gingerbread'] + ['and']


 21%|██        | 4/19 [01:19<04:04, 16.33s/it]

gain: 30.76688144286918
total score: 45.0359
['ornament', 'mistletoe', 'walk', 'gingerbread'] + ['and']


 26%|██▋       | 5/19 [01:34<03:40, 15.77s/it]

gain: 9.672552744468215
total score: 35.3633
['ornament', 'mistletoe', 'walk', 'gingerbread', 'fireplace'] + ['and']


 32%|███▏      | 6/19 [01:42<02:48, 12.96s/it]

gain: 6.944952848718014
total score: 28.4184
['drive', 'chimney'] + ['ornament', 'mistletoe', 'walk', 'gingerbread', 'fireplace']


 37%|███▋      | 7/19 [01:57<02:44, 13.68s/it]

gain: 3.511791396434983
total score: 24.9066
['ornament', 'mistletoe', 'reindeer', 'walk', 'gingerbread', 'fireplace'] + ['drive', 'chimney']


 42%|████▏     | 8/19 [02:15<02:47, 15.23s/it]

gain: 3.27100360895854
total score: 21.6356
['ornament', 'mistletoe', 'reindeer', 'walk', 'gingerbread', 'drive', 'chimney', 'fireplace'] + ['and']


 47%|████▋     | 9/19 [02:22<02:04, 12.44s/it]

gain: 2.56902179081825
total score: 19.0666
['scrooge', 'family'] + ['ornament', 'mistletoe', 'reindeer', 'walk', 'gingerbread', 'drive', 'chimney', 'fireplace']


 53%|█████▎    | 10/19 [02:40<02:07, 14.19s/it]

gain: 1.985804019979219
total score: 17.0808
['ornament', 'mistletoe', 'reindeer', 'walk', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf'] + ['scrooge', 'family']


 58%|█████▊    | 11/19 [02:58<02:02, 15.29s/it]

gain: 1.6266202846294455
total score: 15.4541
['ornament', 'mistletoe', 'reindeer', 'walk', 'and', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf'] + ['scrooge', 'family']


 63%|██████▎   | 12/19 [03:03<01:26, 12.42s/it]

gain: 1.3763702571682357
total score: 14.0778
['the', 'jump'] + ['ornament', 'mistletoe', 'reindeer', 'walk', 'and', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf']


 68%|██████▊   | 13/19 [03:10<01:04, 10.76s/it]

gain: 1.253568611489201
total score: 12.8242
['the', 'night', 'jump'] + ['ornament', 'mistletoe', 'reindeer', 'walk', 'and', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf']


 74%|███████▎  | 14/19 [03:24<00:57, 11.59s/it]

gain: 1.2612123861530362
total score: 11.5630
['ornament', 'mistletoe', 'reindeer', 'walk', 'give', 'and', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf'] + ['the', 'night', 'jump']


 79%|███████▉  | 15/19 [03:37<00:48, 12.08s/it]

gain: 1.0451404987059711
total score: 10.5178
['ornament', 'mistletoe', 'reindeer', 'walk', 'give', 'and', 'advent', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf'] + ['the', 'night', 'jump']


 84%|████████▍ | 16/19 [03:48<00:35, 11.75s/it]

gain: 0.8758411440416689
total score: 9.6420
['ornament', 'mistletoe', 'reindeer', 'walk', 'sleep', 'give', 'and', 'advent', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf'] + ['the', 'night', 'jump']


 89%|████████▉ | 17/19 [03:54<00:19,  9.85s/it]

gain: 0.6477063266914289
total score: 8.9943
['the', 'bake', 'night', 'jump'] + ['ornament', 'mistletoe', 'reindeer', 'walk', 'sleep', 'give', 'and', 'advent', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf']


 95%|█████████▍| 18/19 [04:01<00:09,  9.18s/it]

gain: -0.23295132471842095
total score: 9.2273
['ornament', 'the', 'bake', 'night', 'jump', 'mistletoe', 'reindeer', 'walk', 'sleep', 'give', 'and', 'advent', 'gingerbread', 'drive', 'chimney', 'fireplace', 'elf'] + ['scrooge', 'family']


100%|██████████| 19/19 [04:05<00:00, 12.93s/it]

gain: 0.47306274872681797
total score: 8.7542
['ornament', 'the', 'bake', 'night', 'jump', 'mistletoe', 'reindeer', 'walk', 'sleep', 'give', 'and', 'advent', 'gingerbread', 'drive', 'chimney', 'laugh', 'fireplace', 'elf'] + ['scrooge', 'family']
score:8.7542
number of words: 20



  0%|          | 0/19 [00:00<?, ?it/s]

In [None]:
print(words)