# 最初と最後を固定するSA

### パッケージ

In [1]:
ex_num = "002"

In [2]:
import os
# 環境によって変更
# os.chdir('/home/jovyan/work/notebook')

path = "../../output"
# path = "../model"

In [3]:
import numpy as np
import pandas as pd
import os
import time
import csv

from tqdm.notebook import tqdm

import random
from typing import List, Callable
import itertools, math
from typing import List, Tuple

from utils import PerplexityCalculator

### データ読み込み

In [4]:
df = pd.read_csv("../input/sample_submission.csv")

In [5]:
sub_df = df.copy()
sub_df["best_value"] = 0

In [6]:
df

Unnamed: 0,id,text
0,0,advent chimney elf family fireplace gingerbrea...
1,1,advent chimney elf family fireplace gingerbrea...
2,2,yuletide decorations gifts cheer holiday carol...
3,3,yuletide decorations gifts cheer holiday carol...
4,4,hohoho candle poinsettia snowglobe peppermint ...
5,5,advent chimney elf family fireplace gingerbrea...


### 関数

In [7]:
def simulated_annealing(
    scorer,
    pattern_df,
    max_iterations: int = 1000,
    BATCH_SIZE: int = 32,
    initial_temperature: float = 100.0,
    cooling_rate: float = 0.99
) -> Tuple[List[str], float]:
    """
    焼きなまし法で最適な順序を探索する。

    Args:
        scorer: 評価関数を提供するオブジェクト。
        pattern_df: パターンのデータフレーム。
        max_iterations (int): 最大反復回数。
        BATCH_SIZE (int): バッチサイズ。
        initial_temperature (float): 初期温度。
        cooling_rate (float): 冷却率。

    Returns:
        solution_df: 最適な順序のデータフレーム。
    """
    
    solution_df = pattern_df.copy()
    solution_df["score"] = solution_df["score"].astype(float)
    
    words_len = len(solution_df.iloc[0]["text"].split())
    pattern_len = len(solution_df)
    
    temperature = initial_temperature
    
    perms = []

    for iteration in tqdm(range(max_iterations)):
        # 1行ずつ処理する
        for idx in range(pattern_len):
            current_sequence = solution_df.iloc[idx]["text"].split()
            
            neighbor_sequence = current_sequence[:]
            # 最初と最後の単語を固定し、ランダムに交換する
            i, j = random.sample(range(1, words_len - 1), 2)
            neighbor_sequence[i], neighbor_sequence[j] = neighbor_sequence[j], neighbor_sequence[i]
            
            perms.append(" ".join(neighbor_sequence))
        
        # 評価値をまとめて計算
        neighbor_values = scorer.get_perplexity(perms, batch_size=min(pattern_len, BATCH_SIZE))

        # print(f"neighbor_values = {neighbor_values}")
        # 焼きなましによる評価値更新
        for idx in range(pattern_len):
            perm = perms[idx]
            neighbor_value = neighbor_values[idx]
            current_value = solution_df.iloc[idx]["score"]
            # print(f"current_value = {current_value}, neighbor_value = {neighbor_value}")
            
            # 確率的に次の状態を受け入れるかを決定
            delta = neighbor_value - current_value
            acceptance_probability = math.exp(-delta / temperature) if delta > 0 else 1.0

            # ベスト解を更新
            if (random.random() < acceptance_probability) or (neighbor_value < current_value):
                current_sequence = perm.split(" ")
                current_value = neighbor_value
                
                solution_df.loc[solution_df.index[idx], "text"] = perm
                solution_df.loc[solution_df.index[idx], "score"] = current_value
                start_word = solution_df.iloc[idx]["start"]
                end_word = solution_df.iloc[idx]["end"]
                # print(f"start_word = {start_word}, end_word = {end_word}, New best = {current_value} with '{current_sequence}'")

        # バッチリセット
        perms = []

        # 温度を更新
        if iteration % 10 == 0:
            temperature *= cooling_rate
                    
        # 温度が非常に低くなったら終了
        if temperature < 1e-5:
            break

        # iteration100回ごとに現在の温度を表示
        if iteration % 100 == 0:
            print(f"Iteration {iteration}, Temperature = {temperature}")
            
    return solution_df
     

In [8]:
def search_optimal_permutation(id: int, scorer, max_iterations: int = 1000, BATCH_SIZE=64):
    
    df = pd.read_csv("../input/sample_submission.csv")
    
    # 文字列を単語に分割
    words = df.loc[id,"text"].split()
    
    # データをタプルのリストとして準備
    pattern_data = []
    
    words_len = len(words)
    for i in range(0, words_len):
        for j in range(0, words_len):
            if i == j:
                continue
            new_words = words.copy()
            new_words[i], new_words[0] = new_words[0], new_words[i]
            
            if j == 0:
                new_words[i], new_words[words_len-1] = new_words[words_len-1], new_words[i]
            else:
                new_words[j], new_words[words_len-1] = new_words[words_len-1], new_words[j]
                
            # new_wordsのs番目からt番目までの要素をランダムな順番にする
            new_words[1:words_len-1] = random.sample(new_words[1:words_len-1], len(new_words[1:words_len-1]))
    
            # データを追加
            pattern_data.append((new_words[0], new_words[words_len-1], " ".join(new_words), 10000))
            
    pattern_df = pd.DataFrame(pattern_data, columns=["start", "end", "text", "score"])        

    start = time.time()
    
    # SAで最適な順序を探索
    solutions_df = simulated_annealing(
        scorer,
        pattern_df,
        max_iterations=max_iterations,
        BATCH_SIZE=BATCH_SIZE,
        initial_temperature=100.0,
        cooling_rate=0.99
    )
    
    # optimized_dfのscore列の値が最小の行を取得
    best_sequence = solutions_df.loc[solutions_df["score"].idxmin(), "text"]
    best_value = solutions_df["score"].min()
    
    print(f"{id}th sample: {best_sequence}")
    print(f"Elapsed time: {time.time() - start:.2f} sec")
    print(f"Best value: {best_value}")
    
    return best_sequence, best_value, solutions_df


In [9]:
# 文字列型にのみダブルクォーテーションを付ける関数
def add_quotes_to_strings(value):
    if isinstance(value, str):  # 文字列型の場合
        return f'"{value}"'
    return value  # それ以外はそのまま

In [10]:
# LOAD GEMMA SCORER
scorer = PerplexityCalculator(f'{path}/gemma_2_9b')

cuda


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [11]:
# scorer.clear_gpu_memory()

In [12]:
for i in [1]:
    
    print(f"Processing {i}th sample")
    best_sequence, best_value, solutions_df = search_optimal_permutation(
        id=i,
        scorer=scorer,
        max_iterations=2**15,
        BATCH_SIZE=512
    )
    
    print(best_sequence)
    sub_df.loc[i, "best_value"] = best_value
    sub_df.loc[i, "text"] = best_sequence
    
    display(solutions_df)
    solutions_df.to_csv(f"{path}/out/solutions_{i}_{ex_num}.csv", index=False, header=True)



Processing 1th sample


  0%|          | 0/8192 [00:00<?, ?it/s]

Iteration 0, Temperature = 99.0
Iteration 100, Temperature = 36.23720178604971
Iteration 200, Temperature = 13.263987810938206
Iteration 300, Temperature = 4.8550485130572865
Iteration 400, Temperature = 1.7771047742294686
Iteration 500, Temperature = 0.6504778211990457
Iteration 600, Temperature = 0.23809591983979556
Iteration 700, Temperature = 0.08715080698656348
Iteration 800, Temperature = 0.031900013925143154
Iteration 900, Temperature = 0.011676436783668767
Iteration 1000, Temperature = 0.00427395349365513
Iteration 1100, Temperature = 0.0015644052037754848
Iteration 1200, Temperature = 0.000572622899437964
Iteration 1300, Temperature = 0.00020959850054794293
Iteration 1400, Temperature = 7.671982988292172e-05
Iteration 1500, Temperature = 2.8081938954129706e-05
Iteration 1600, Temperature = 1.0278897862872002e-05
1th sample: reindeer mistletoe elf gingerbread ornament advent fireplace chimney sleep drive walk jump bake give laugh family night and the scrooge
Elapsed time: 1868.

  sub_df.loc[i, "best_value"] = best_value


Unnamed: 0,start,end,text,score
0,advent,chimney,advent mistletoe jump walk the night drive giv...,1307.381094
1,advent,elf,advent scrooge walk gingerbread bake ornament ...,1190.384032
2,advent,family,advent the elf scrooge walk gingerbread drive ...,1595.591830
3,advent,fireplace,advent jump scrooge walk gingerbread drive sle...,1343.623033
4,advent,gingerbread,advent elf sleep walk the family reindeer jump...,1327.969350
...,...,...,...,...
375,and,bake,and the family sleep scrooge walk night mistle...,1010.265807
376,and,the,and laugh reindeer jump elf walk gingerbread b...,1836.514940
377,and,sleep,and mistletoe scrooge family walk laugh jump t...,1528.483353
378,and,night,and laugh walk drive jump sleep mistletoe give...,1338.384743


Processing 2th sample


  0%|          | 0/8192 [00:00<?, ?it/s]

Iteration 0, Temperature = 99.0
Iteration 100, Temperature = 36.23720178604971
Iteration 200, Temperature = 13.263987810938206
Iteration 300, Temperature = 4.8550485130572865
Iteration 400, Temperature = 1.7771047742294686
Iteration 500, Temperature = 0.6504778211990457
Iteration 600, Temperature = 0.23809591983979556
Iteration 700, Temperature = 0.08715080698656348
Iteration 800, Temperature = 0.031900013925143154
Iteration 900, Temperature = 0.011676436783668767
Iteration 1000, Temperature = 0.00427395349365513
Iteration 1100, Temperature = 0.0015644052037754848
Iteration 1200, Temperature = 0.000572622899437964
Iteration 1300, Temperature = 0.00020959850054794293
Iteration 1400, Temperature = 7.671982988292172e-05
Iteration 1500, Temperature = 2.8081938954129706e-05
Iteration 1600, Temperature = 1.0278897862872002e-05
2th sample: jingle yuletide carol holiday cheer gifts stocking magi naughty nice grinch chimney sleigh polar beard workshop nutcracker ornament decorations holly
Elaps

Unnamed: 0,start,end,text,score
0,yuletide,decorations,yuletide magi holiday grinch carol cheer sleig...,605.618349
1,yuletide,gifts,yuletide holiday decorations holly jingle work...,649.733644
2,yuletide,cheer,yuletide workshop grinch magi holly jingle sle...,557.921767
3,yuletide,holiday,yuletide cheer jingle magi polar beard carol h...,665.141633
4,yuletide,carol,yuletide magi grinch holly jingle holiday chee...,564.498372
...,...,...,...,...
375,nice,ornament,nice nutcracker beard holly yuletide holiday c...,1408.104848
376,nice,holly,nice nutcracker magi carol yuletide grinch sle...,1458.489004
377,nice,jingle,nice beard yuletide cheer grinch magi carol sl...,1199.720330
378,nice,beard,nice nutcracker yuletide cheer grinch holly ji...,1118.262309


Processing 3th sample


  0%|          | 0/8192 [00:00<?, ?it/s]

Iteration 0, Temperature = 99.0
Iteration 100, Temperature = 36.23720178604971
Iteration 200, Temperature = 13.263987810938206
Iteration 300, Temperature = 4.8550485130572865
Iteration 400, Temperature = 1.7771047742294686
Iteration 500, Temperature = 0.6504778211990457
Iteration 600, Temperature = 0.23809591983979556
Iteration 700, Temperature = 0.08715080698656348
Iteration 800, Temperature = 0.031900013925143154
Iteration 900, Temperature = 0.011676436783668767
Iteration 1000, Temperature = 0.00427395349365513
Iteration 1100, Temperature = 0.0015644052037754848
Iteration 1200, Temperature = 0.000572622899437964
Iteration 1300, Temperature = 0.00020959850054794293
Iteration 1400, Temperature = 7.671982988292172e-05
Iteration 1500, Temperature = 2.8081938954129706e-05
Iteration 1600, Temperature = 1.0278897862872002e-05
3th sample: of yuletide cheer unwrap gifts eat holiday cheer relax sing carol the magi visit workshop naughty nice is and grinch chimney stocking decorations holly jin

Unnamed: 0,start,end,text,score
0,yuletide,decorations,yuletide cheer cheer unwrap gifts jingle and s...,384.954444
1,yuletide,gifts,yuletide of holiday is holly unwrap jingle bea...,540.756319
2,yuletide,cheer,yuletide cheer unwrap gifts relax eat and visi...,409.781873
3,yuletide,holiday,yuletide cheer unwrap the magi jingle sleigh h...,530.297145
4,yuletide,carol,yuletide cheer holiday cheer polar beard jingl...,479.083546
...,...,...,...,...
865,unwrap,the,unwrap visit yuletide eat cheer jingle relax g...,768.569696
866,unwrap,is,unwrap the holiday cheer gifts magi carol holl...,680.915012
867,unwrap,eat,unwrap grinch and yuletide the holly magi chee...,884.618298
868,unwrap,visit,unwrap ornament yuletide carol cheer holly jin...,705.279198


In [13]:
sub_df

Unnamed: 0,id,text,best_value
0,0,advent chimney elf family fireplace gingerbrea...,0.0
1,1,reindeer mistletoe elf gingerbread ornament ad...,600.905389
2,2,jingle yuletide carol holiday cheer gifts stoc...,346.421503
3,3,of yuletide cheer unwrap gifts eat holiday che...,296.310014
4,4,hohoho candle poinsettia snowglobe peppermint ...,0.0
5,5,advent chimney elf family fireplace gingerbrea...,0.0


In [14]:
# 各セルに関数を適用
sub_df["text"] = sub_df["text"].astype(str)
sub_df = sub_df.applymap(add_quotes_to_strings)

  sub_df = sub_df.applymap(add_quotes_to_strings)


In [15]:
sub_df[["id", "text"]].to_csv(f"{path}/out/submission_{ex_num}.csv", index=False, header=True, quoting=csv.QUOTE_NONE)
sub_df.to_csv(f"{path}/out/score_{ex_num}.csv", index=False, header=True, quoting=csv.QUOTE_NONE)

In [16]:
np.mean(sub_df["best_value"])

np.float64(207.27281765114648)