In [1]:
import polars as pl

DATA_PATH = "../../data/"

train = pl.read_csv(DATA_PATH + "train.csv")
predicted_prompt = pl.read_csv(DATA_PATH + "predicted_prompt.csv")

In [2]:
train

essay_id,full_text,score
str,str,i64
"""000d118""","""Many people have car where the…",3
"""000fe60""","""I am a scientist at NASA that …",3
"""001ab80""","""People always wish they had th…",4
"""001bdc0""","""We all heard about Venus, the …",4
"""002ba53""","""Dear, State Senator This is a…",3
…,…,…
"""ffd378d""","""the story "" The Challenge of E…",2
"""ffddf1f""","""Technology has changed a lot o…",4
"""fff016d""","""If you don't like sitting arou…",2
"""fffb49b""","""In ""The Challenge of Exporing …",1


In [3]:
train = train.join(  # prompt_nameを付与する
    predicted_prompt.select(pl.col(["essay_id", "prompt_name"])),
    how="left",
    on="essay_id",
)

In [4]:
from datasets import load_dataset

# 追加データ
persuade_dataset = (
    load_dataset(
        "csv",
        data_files={"train": f"{DATA_PATH}/persuade_w_is_tr_con_as_num.csv"},
        split="train",
    )
    .filter(lambda x: not x["is_train_contains"])
    .select_columns(
        ["essay_id_comp", "full_text", "holistic_essay_score", "prompt_name"]
    )
    .rename_columns({"essay_id_comp": "essay_id", "holistic_essay_score": "score"})
)

persuade_df = pl.DataFrame(persuade_dataset.to_pandas())

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# train = pl.concat([train, persuade_df])

In [6]:
train = train.with_columns(  # prompt_nameをlabel encodingする
    pl.col("prompt_name").cast(pl.Categorical).to_physical().alias("prompt_id")
)

persuade_df = persuade_df.with_columns(  # prompt_nameをlabel encodingする
    pl.col("prompt_name").cast(pl.Categorical).to_physical().alias("prompt_id")
)

In [7]:
train = train.with_columns(
    pl.col("full_text")
    .map_elements(lambda x: len(x.split()), return_dtype=pl.Int64)
    .alias("word_length"),
)

persuade_df = persuade_df.with_columns(
    pl.col("full_text")
    .map_elements(lambda x: len(x.split()), return_dtype=pl.Int64)
    .alias("word_length"),
)

In [8]:
train = train.with_columns(
    pl.when(pl.col("word_length") <= 200)
    .then(pl.lit("x<=200"))
    .when(pl.col("word_length") <= 400)
    .then(pl.lit("200<x<=400"))
    .when(pl.col("word_length") <= 600)
    .then(pl.lit("400<x<=600"))
    .when(pl.col("word_length") <= 800)
    .then(pl.lit("600<x<=800"))
    .when(pl.col("word_length") <= 1000)
    .then(pl.lit("800<x<=1000"))
    .when(pl.col("word_length") <= 1200)
    .then(pl.lit("1000<x<=1200"))
    .when(pl.col("word_length") > 1200)
    .then(pl.lit("x>1200"))
    .alias("word_length_cat"),
)

persuade_df = persuade_df.with_columns(
    pl.when(pl.col("word_length") <= 200)
    .then(pl.lit("x<=200"))
    .when(pl.col("word_length") <= 400)
    .then(pl.lit("200<x<=400"))
    .when(pl.col("word_length") <= 600)
    .then(pl.lit("400<x<=600"))
    .when(pl.col("word_length") <= 800)
    .then(pl.lit("600<x<=800"))
    .when(pl.col("word_length") <= 1000)
    .then(pl.lit("800<x<=1000"))
    .when(pl.col("word_length") <= 1200)
    .then(pl.lit("1000<x<=1200"))
    .when(pl.col("word_length") > 1200)
    .then(pl.lit("x>1200"))
    .alias("word_length_cat"),
)

In [9]:
train = train.with_columns(
    pl.concat_str(
        pl.col("score"),
        pl.col("word_length_cat"),
        separator="_",
    ).alias("score_word_length_cat")
).with_columns(
    pl.concat_str(
        pl.col("score_word_length_cat"),
        pl.col("prompt_id"),
        separator="_",
    ).alias("concat_score_word_length_prompt_id")
)

persuade_df = persuade_df.with_columns(
    pl.concat_str(
        pl.col("score"),
        pl.col("word_length_cat"),
        separator="_",
    ).alias("score_word_length_cat")
).with_columns(
    pl.concat_str(
        pl.col("score_word_length_cat"),
        pl.col("prompt_id"),
        separator="_",
    ).alias("concat_score_word_length_prompt_id")
)

In [11]:
train.head(3)

essay_id,full_text,score,prompt_name,prompt_id,word_length,word_length_cat,score_word_length_cat,concat_score_word_length_prompt_id
str,str,i64,str,u32,i64,str,str,str
"""000d118""","""Many people have car where the…",3,"""Car-free cities""",0,498,"""400<x<=600""","""3_400<x<=600""","""3_400<x<=600_0"""
"""000fe60""","""I am a scientist at NASA that …",3,"""The Face on Mars""",1,332,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_1"""
"""001ab80""","""People always wish they had th…",4,"""Driverless cars""",2,550,"""400<x<=600""","""4_400<x<=600""","""4_400<x<=600_2"""


In [11]:
# (
#     train.select(
#         pl.col("concat_score_word_length_prompt_id").value_counts(),
#     )
#     .unnest("concat_score_word_length_prompt_id")
#     .with_columns((pl.col("count") / train.height).alias("percentage"))
#     .sort("count", descending=True)
# )

In [13]:
# promptを用いてstratifyを行なって、分割する
from sklearn.model_selection import train_test_split

SPLIT_SIZE = 0.5

train_1, train_2 = train_test_split(
    train, test_size=SPLIT_SIZE, random_state=42, stratify=train["prompt_id"]
)

persuade_1, persuade_2 = train_test_split(
    persuade_df,
    test_size=SPLIT_SIZE,
    random_state=42,
    stratify=persuade_df["prompt_id"],
)

In [21]:
def get_st_slp(data):
    import numpy as np
    from sklearn.model_selection import StratifiedKFold

    fold_arr = np.zeros(data.height)
    sgkf = StratifiedKFold(n_splits=3, random_state=42, shuffle=True)

    for idx, (_, val_idx) in enumerate(
        sgkf.split(data, data["concat_score_word_length_prompt_id"])
    ):
        fold_arr[val_idx] = idx

    return data.with_columns(pl.Series(fold_arr).cast(pl.Int64).alias("fold"))


train_1 = get_st_slp(train_1)
persuade_1 = get_st_slp(persuade_1).with_columns(pl.col("fold") + 3)



In [22]:
def get_st_sl_g_p(data):
    import numpy as np
    from sklearn.model_selection import StratifiedGroupKFold

    fold_arr = np.zeros(data.height)
    sgkf = StratifiedGroupKFold(n_splits=3, random_state=42, shuffle=True)

    for idx, (_, val_idx) in enumerate(
        sgkf.split(data, data["score_word_length_cat"], data["prompt_id"])
    ):
        fold_arr[val_idx] = idx

    return data.with_columns(pl.Series(fold_arr).cast(pl.Int64).alias("fold"))


train_2 = get_st_sl_g_p(train_2)
persuade_2 = get_st_sl_g_p(persuade_2).with_columns(pl.col("fold") + 3)



In [23]:
train = pl.concat([train_1, train_2])
persuade_df = pl.concat([persuade_1, persuade_2])

In [24]:
train.head(3)

essay_id,full_text,score,prompt_name,prompt_id,word_length,word_length_cat,score_word_length_cat,concat_score_word_length_prompt_id,fold
str,str,i64,str,u32,i64,str,str,str,i64
"""d4a8b29""","""I am for driverless cars, but …",3,"""Driverless cars""",2,324,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_2""",2
"""f2e66b1""","""I think the author did a good …",3,"""Exploring Venus""",3,271,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_3""",1
"""7e0cc6b""","""Can you detect a happy person …",3,"""Facial action coding system""",5,319,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_5""",0


In [25]:
persuade_df.head(3)

essay_id,full_text,score,prompt_name,prompt_id,word_length,word_length_cat,score_word_length_cat,concat_score_word_length_prompt_id,fold
str,str,i64,str,u32,i64,str,str,str,i64
"""2D23892ACDA0""","""Dear principal, I think that …",4,"""Community service""",4,463,"""400<x<=600""","""4_400<x<=600""","""4_400<x<=600_4""",3
"""5E8B50F6DF84""","""Going to school everyday can b…",5,"""Distance learning""",7,844,"""800<x<=1000""","""5_800<x<=1000""","""5_800<x<=1000_7""",4
"""D9C9F849DC3E""","""Dear principle, I think that …",3,"""Grades for extracurricular act…",5,532,"""400<x<=600""","""3_400<x<=600""","""3_400<x<=600_5""",4


In [26]:
essay_id_fold_dict = dict(zip(train["essay_id"], train["fold"]))

import json

with open("essay_id_fold_by_half_st_slp_and_st_sl_g_p_only_train.json", "w") as f:
    json.dump(essay_id_fold_dict, f)

In [27]:
essay_id_fold_dict = dict(zip(persuade_df["essay_id"], persuade_df["fold"]))

import json

with open("essay_id_fold_by_half_st_slp_and_st_sl_g_p_only_persuade.json", "w") as f:
    json.dump(essay_id_fold_dict, f)

# Check

In [28]:
(
    train_1.filter(pl.col("fold") == 0)
    .select(
        pl.col("score_word_length_cat").value_counts(),
    )
    .unnest("score_word_length_cat")
    .with_columns((pl.col("count") / train.height).alias("percentage"))
    .sort("count", descending=True)
    .head(10)
)

score_word_length_cat,count,percentage
str,u32,f64
"""3_200<x<=400""",750,0.043335
"""2_200<x<=400""",502,0.029006
"""4_400<x<=600""",420,0.024268
"""3_400<x<=600""",258,0.014907
"""2_x<=200""",225,0.013001
"""4_200<x<=400""",178,0.010285
"""1_200<x<=400""",124,0.007165
"""5_400<x<=600""",74,0.004276
"""1_x<=200""",67,0.003871
"""5_600<x<=800""",66,0.003813


In [29]:
(
    train_1.filter(pl.col("fold") == 1)
    .select(
        pl.col("score_word_length_cat").value_counts(),
    )
    .unnest("score_word_length_cat")
    .with_columns((pl.col("count") / train.height).alias("percentage"))
    .sort("count", descending=True)
    .head(10)
)

score_word_length_cat,count,percentage
str,u32,f64
"""3_200<x<=400""",749,0.043277
"""2_200<x<=400""",499,0.028832
"""4_400<x<=600""",421,0.024325
"""3_400<x<=600""",258,0.014907
"""2_x<=200""",229,0.013232
"""4_200<x<=400""",179,0.010343
"""1_200<x<=400""",124,0.007165
"""5_400<x<=600""",73,0.004218
"""5_600<x<=800""",68,0.003929
"""1_x<=200""",66,0.003813


In [30]:
(
    train_1.filter(pl.col("fold") == 0)
    .select(
        pl.col("prompt_id").value_counts(),
    )
    .unnest("prompt_id")
    .with_columns((pl.col("count") / train_1.height).alias("percentage"))
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
2,583,0.067375
5,506,0.058477
3,506,0.058477
1,347,0.040102
4,340,0.039293
0,327,0.03779
6,276,0.031896


In [31]:
(
    train_1.filter(pl.col("fold") == 1)
    .select(
        pl.col("prompt_id").value_counts(),
    )
    .unnest("prompt_id")
    .with_columns((pl.col("count") / train_1.height).alias("percentage"))
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
2,581,0.067144
5,507,0.058592
3,502,0.058015
1,352,0.04068
4,340,0.039293
0,326,0.037675
6,276,0.031896


In [32]:
(
    train_2.filter(pl.col("fold") == 0)
    .select(
        pl.col("prompt_id").value_counts(),
    )
    .unnest("prompt_id")
    .with_columns((pl.col("count") / train_2.height).alias("percentage"))
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
3,1508,0.174255
1,1046,0.120869
0,981,0.113358


In [33]:
(
    train_2.filter(pl.col("fold") == 1)
    .select(
        pl.col("prompt_id").value_counts(),
    )
    .unnest("prompt_id")
    .with_columns((pl.col("count") / train_2.height).alias("percentage"))
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
2,1749,0.202103
6,825,0.095332


In [35]:
(
    persuade_1.filter(pl.col("fold") == 3)
    .select(
        pl.col("prompt_id").value_counts(),
    )
    .unnest("prompt_id")
    .with_columns((pl.col("count") / persuade_1.height).alias("percentage"))
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
7,358,0.054557
2,294,0.044803
3,280,0.04267
6,274,0.041756
5,270,0.041146
4,257,0.039165
8,257,0.039165
0,197,0.030021
1,1,0.000152


In [36]:
(
    persuade_1.filter(pl.col("fold") == 4)
    .select(
        pl.col("prompt_id").value_counts(),
    )
    .unnest("prompt_id")
    .with_columns((pl.col("count") / persuade_1.height).alias("percentage"))
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
7,360,0.054861
2,289,0.044041
3,277,0.042213
6,277,0.042213
5,271,0.041298
8,263,0.040079
4,257,0.039165
0,193,0.029412


In [37]:
(
    persuade_2.filter(pl.col("fold") == 3)
    .select(
        pl.col("prompt_id").value_counts(),
    )
    .unnest("prompt_id")
    .with_columns((pl.col("count") / persuade_1.height).alias("percentage"))
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
8,776,0.118257
0,584,0.088997


In [38]:
(
    persuade_2.filter(pl.col("fold") == 4)
    .select(
        pl.col("prompt_id").value_counts(),
    )
    .unnest("prompt_id")
    .with_columns((pl.col("count") / persuade_1.height).alias("percentage"))
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
7,1079,0.164432
2,875,0.133343
5,813,0.123895
