In [1]:
DATA_PATH = "../../data"

In [2]:
import polars as pl

train = (
    pl.read_csv(f"{DATA_PATH}/train.csv")
    .with_columns(
        pl.col("prompt").str.json_decode(),
        pl.col("response_a").str.json_decode(),
        pl.col("response_b").str.json_decode(),
    )
    .with_columns(  # 長さの情報を追加する
        pl.col("prompt")
        .map_elements(lambda x: len(x), return_dtype=pl.Int64)
        .alias("len_prompt"),
        pl.col("response_a")
        .map_elements(lambda x: len(x), return_dtype=pl.Int64)
        .alias("len_response_a"),
        pl.col("response_b")
        .map_elements(lambda x: len(x), return_dtype=pl.Int64)
        .alias("len_response_b"),
    )
    .with_columns(  # 最後のレスポンスのみを取得する
        pl.col("prompt")
        .map_elements(lambda x: x[-1], return_dtype=pl.String)
        .alias("last_prompt"),
        pl.col("response_a")
        .map_elements(lambda x: x[-1], return_dtype=pl.String)
        .alias("last_response_a"),
        pl.col("response_b")
        .map_elements(lambda x: x[-1], return_dtype=pl.String)
        .alias("last_response_b"),
    )
    .with_columns(  # 最後のレスポンスがNoneの場合を空文字にする、約60件程度
        pl.col("last_response_a").fill_null(""),
        pl.col("last_response_b").fill_null(""),
    )
    .with_columns(  # labelを付与する
        pl.when(pl.col("winner_model_a") == 1)
        .then(0)
        .when(pl.col("winner_model_b") == 1)
        .then(1)
        .when(pl.col("winner_tie") == 1)
        .then(2)
        .alias("label"),
    )
    .select(  # 元のprompt, responseを削除する
        pl.exclude(["prompt", "response_a", "response_b"])
    )
)

In [3]:
import numpy as np
from sklearn.model_selection import StratifiedKFold

fold_arr = np.zeros(train.height)
sgkf = StratifiedKFold(n_splits=3, random_state=42, shuffle=True)

for idx, (_, val_idx) in enumerate(
    sgkf.split(train, train["label"])
):
    fold_arr[val_idx] = idx

train = train.with_columns(pl.Series(fold_arr).cast(pl.Int64).alias("fold"))

In [14]:
# idとlabelをdictにして保存する
id_fold_dict = dict(zip(train["id"], train["fold"]))

In [21]:
# jsonで保存する
import json

with open(f"{DATA_PATH}/label_stratified_fold.json", "w") as f:
    json.dump(id_fold_dict, f, indent=4)

In [22]:
train["label"].value_counts()

label,count
i32,u32
1,19652
2,17761
0,20064


In [23]:
train["fold"].value_counts()

fold,count
i64,u32
1,19159
0,19159
2,19159


In [24]:
train.filter(
    pl.col("fold") == 0
)["label"].value_counts()

label,count
i32,u32
2,5920
1,6551
0,6688


In [25]:
train.filter(
    pl.col("fold") == 1
)["label"].value_counts()

label,count
i32,u32
1,6551
2,5920
0,6688


In [26]:
train.filter(
    pl.col("fold") == 2
)["label"].value_counts()

label,count
i32,u32
0,6688
2,5921
1,6550


In [29]:
id_fold_dict

{30192: 2,
 53567: 2,
 65089: 2,
 96401: 0,
 198779: 2,
 292873: 2,
 313413: 1,
 370945: 0,
 441448: 0,
 481524: 1,
 497862: 2,
 587904: 0,
 604575: 0,
 738614: 2,
 862324: 1,
 863398: 0,
 887722: 0,
 914644: 1,
 933555: 0,
 1120158: 2,
 1256092: 2,
 1404102: 0,
 1440765: 1,
 1458108: 2,
 1491225: 2,
 1594211: 1,
 1639617: 2,
 1744093: 1,
 1813737: 2,
 1827787: 0,
 1842252: 2,
 2051408: 1,
 2154496: 0,
 2298796: 1,
 2388511: 2,
 2802516: 0,
 2857714: 0,
 2912862: 0,
 2944182: 2,
 3254113: 0,
 3258431: 2,
 3259481: 0,
 3373963: 1,
 3445782: 0,
 3475655: 0,
 3499263: 0,
 3503031: 1,
 3504181: 2,
 3519254: 0,
 3567106: 2,
 3578663: 1,
 3590999: 1,
 3622781: 2,
 3643104: 2,
 3710170: 2,
 3760933: 2,
 3773792: 2,
 3777134: 0,
 3994811: 1,
 3995635: 0,
 4186011: 0,
 4349090: 2,
 4356730: 0,
 4486480: 2,
 4510489: 0,
 4587071: 2,
 4615863: 2,
 4683272: 0,
 4790276: 0,
 4961077: 2,
 4970917: 2,
 4990514: 2,
 5061737: 0,
 5069186: 0,
 5166668: 1,
 5187535: 1,
 5188727: 0,
 5378146: 2,
 5498037:

In [28]:
with open(f"{DATA_PATH}/label_stratified_fold.json", "r") as f:
    tmp = json.load(f)

In [30]:
tmp

{'30192': 2,
 '53567': 2,
 '65089': 2,
 '96401': 0,
 '198779': 2,
 '292873': 2,
 '313413': 1,
 '370945': 0,
 '441448': 0,
 '481524': 1,
 '497862': 2,
 '587904': 0,
 '604575': 0,
 '738614': 2,
 '862324': 1,
 '863398': 0,
 '887722': 0,
 '914644': 1,
 '933555': 0,
 '1120158': 2,
 '1256092': 2,
 '1404102': 0,
 '1440765': 1,
 '1458108': 2,
 '1491225': 2,
 '1594211': 1,
 '1639617': 2,
 '1744093': 1,
 '1813737': 2,
 '1827787': 0,
 '1842252': 2,
 '2051408': 1,
 '2154496': 0,
 '2298796': 1,
 '2388511': 2,
 '2802516': 0,
 '2857714': 0,
 '2912862': 0,
 '2944182': 2,
 '3254113': 0,
 '3258431': 2,
 '3259481': 0,
 '3373963': 1,
 '3445782': 0,
 '3475655': 0,
 '3499263': 0,
 '3503031': 1,
 '3504181': 2,
 '3519254': 0,
 '3567106': 2,
 '3578663': 1,
 '3590999': 1,
 '3622781': 2,
 '3643104': 2,
 '3710170': 2,
 '3760933': 2,
 '3773792': 2,
 '3777134': 0,
 '3994811': 1,
 '3995635': 0,
 '4186011': 0,
 '4349090': 2,
 '4356730': 0,
 '4486480': 2,
 '4510489': 0,
 '4587071': 2,
 '4615863': 2,
 '4683272': 0,
 '4