In [1]:
import polars as pl

# Directory holding the input CSVs. The trailing slash matters because file
# names are appended by plain string concatenation.
DATA_PATH = "../../data/"

# Read the essay table and the prompt table from DATA_PATH.
train, predicted_prompt = (
    pl.read_csv(DATA_PATH + csv_name)
    for csv_name in ("train.csv", "predicted_prompt.csv")
)

In [2]:
# Show the loaded training frame as the cell output.
train

essay_id,full_text,score
str,str,i64
"""000d118""","""Many people have car where the…",3
"""000fe60""","""I am a scientist at NASA that …",3
"""001ab80""","""People always wish they had th…",4
"""001bdc0""","""We all heard about Venus, the …",4
"""002ba53""","""Dear, State Senator This is a…",3
…,…,…
"""ffd378d""","""the story "" The Challenge of E…",2
"""ffddf1f""","""Technology has changed a lot o…",4
"""fff016d""","""If you don't like sitting arou…",2
"""fffb49b""","""In ""The Challenge of Exporing …",1


In [3]:
# Attach prompt_name to every essay via a left join on essay_id.
train = train.join(
    predicted_prompt.select("essay_id", "prompt_name"),
    on="essay_id",
    how="left",
)

In [4]:
from datasets import load_dataset

# Additional data: rows of the PERSUADE CSV whose is_train_contains flag is
# falsy, reduced to the four columns that train also has.
persuade_dataset = (
    load_dataset(
        "csv",
        # DATA_PATH already ends with "/", so the file name is appended
        # directly (same convention as the read_csv calls in the first cell).
        data_files={"train": f"{DATA_PATH}persuade_w_is_tr_con_as_num.csv"},
        split="train",
    )
    .filter(lambda x: not x["is_train_contains"])
    .select_columns(
        ["essay_id_comp", "full_text", "holistic_essay_score", "prompt_name"]
    )
    # Rename to the column names used by train so the frames can be stacked.
    .rename_columns({"essay_id_comp": "essay_id", "holistic_essay_score": "score"})
)

# Convert to a polars DataFrame by way of pandas.
persuade_df = pl.DataFrame(persuade_dataset.to_pandas())

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Stack the additional rows underneath the existing train rows.
train = pl.concat((train, persuade_df), how="vertical")

In [6]:
# Label-encode prompt_name: cast to Categorical and keep the physical integer
# codes as prompt_id.
prompt_id_expr = pl.col("prompt_name").cast(pl.Categorical).to_physical()
train = train.with_columns(prompt_id_expr.alias("prompt_id"))

In [7]:
def _count_words(text):
    """Return the number of whitespace-separated tokens in ``text``."""
    return len(text.split())


# Per-essay word count, evaluated in Python for each full_text value.
train = train.with_columns(
    word_length=pl.col("full_text").map_elements(
        _count_words, return_dtype=pl.Int64
    )
)

In [8]:
def _word_length_bucket():
    """Build the expression that labels word_length with its length bucket."""
    length = pl.col("word_length")
    # (inclusive upper bound, label) pairs, tested in ascending order so the
    # first bound that fits decides the label.
    bounded = [
        (200, "x<=200"),
        (400, "200<x<=400"),
        (600, "400<x<=600"),
        (800, "600<x<=800"),
        (1000, "800<x<=1000"),
        (1200, "1000<x<=1200"),
    ]
    first_bound, first_label = bounded[0]
    expr = pl.when(length <= first_bound).then(pl.lit(first_label))
    for bound, label in bounded[1:]:
        expr = expr.when(length <= bound).then(pl.lit(label))
    # The last branch is an explicit condition rather than an otherwise().
    return expr.when(length > 1200).then(pl.lit("x>1200"))


train = train.with_columns(_word_length_bucket().alias("word_length_cat"))

In [9]:
# Stratification keys: "<score>_<length bucket>", and that same key with
# "_<prompt_id>" appended.
score_length_key = pl.concat_str(
    [pl.col("score"), pl.col("word_length_cat")], separator="_"
)
train = train.with_columns(
    score_length_key.alias("score_word_length_cat"),
    pl.concat_str([score_length_key, pl.col("prompt_id")], separator="_").alias(
        "concat_score_word_length_prompt_id"
    ),
)

In [10]:
# Preview the first ten rows with the derived columns.
train.head(10)

essay_id,full_text,score,prompt_name,prompt_id,word_length,word_length_cat,score_word_length_cat,concat_score_word_length_prompt_id
str,str,i64,str,u32,i64,str,str,str
"""000d118""","""Many people have car where the…",3,"""Car-free cities""",0,498,"""400<x<=600""","""3_400<x<=600""","""3_400<x<=600_0"""
"""000fe60""","""I am a scientist at NASA that …",3,"""The Face on Mars""",1,332,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_1"""
"""001ab80""","""People always wish they had th…",4,"""Driverless cars""",2,550,"""400<x<=600""","""4_400<x<=600""","""4_400<x<=600_2"""
"""001bdc0""","""We all heard about Venus, the …",4,"""Exploring Venus""",3,451,"""400<x<=600""","""4_400<x<=600""","""4_400<x<=600_3"""
"""002ba53""","""Dear, State Senator This is a…",3,"""Does the electoral college wor…",4,373,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_4"""
"""0030e86""","""If I were to choose between ke…",4,"""Does the electoral college wor…",4,400,"""200<x<=400""","""4_200<x<=400""","""4_200<x<=400_4"""
"""0033037""","""The posibilty of a face reconi…",2,"""Facial action coding system""",5,179,"""x<=200""","""2_x<=200""","""2_x<=200_5"""
"""0033bf4""","""What is the Seagoing Cowboys p…",3,"""""A Cowboy Who Rode the Waves""""",6,353,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_6"""
"""0036253""","""The challenge of exploring Ven…",2,"""Exploring Venus""",3,310,"""200<x<=400""","""2_200<x<=400""","""2_200<x<=400_3"""
"""0040e27""","""There are many reasons why you…",3,"""""A Cowboy Who Rode the Waves""""",6,280,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_6"""


In [11]:
# (
#     train.select(
#         pl.col("concat_score_word_length_prompt_id").value_counts(),
#     )
#     .unnest("concat_score_word_length_prompt_id")
#     .with_columns((pl.col("count") / train.height).alias("percentage"))
#     .sort("count", descending=True)
# )

In [12]:
# Split the data in two, stratifying by prompt.
from sklearn.model_selection import train_test_split

SPLIT_SIZE = 0.5

# Cut the data in half, stratifying on prompt_id.
prompt_labels = train["prompt_id"]
train_1, train_2 = train_test_split(
    train, stratify=prompt_labels, test_size=SPLIT_SIZE, random_state=42
)

In [13]:
def get_st_slp(data, n_splits=3, random_state=42):
    """Assign a stratified K-fold index to every row of ``data``.

    Rows are stratified on the ``concat_score_word_length_prompt_id`` column
    with a shuffled ``StratifiedKFold``.

    Args:
        data: polars DataFrame containing a
            ``concat_score_word_length_prompt_id`` column.
        n_splits: Number of folds. Defaults to 3, the value that used to be
            hard-coded.
        random_state: Seed of the shuffled splitter. Defaults to 42, the value
            that used to be hard-coded.

    Returns:
        ``data`` with an Int64 ``fold`` column that holds, for each row, the
        index of the split in which that row belongs to the validation part.
    """
    import numpy as np
    from sklearn.model_selection import StratifiedKFold

    # Integer scratch array: fold ids are never fractional.
    fold_arr = np.zeros(data.height, dtype=np.int64)
    skf = StratifiedKFold(n_splits=n_splits, random_state=random_state, shuffle=True)

    splits = skf.split(data, data["concat_score_word_length_prompt_id"])
    for fold_idx, (_, val_idx) in enumerate(splits):
        fold_arr[val_idx] = fold_idx

    return data.with_columns(pl.Series(fold_arr).cast(pl.Int64).alias("fold"))


# Assign fold ids to the first half.
train_1 = get_st_slp(train_1)



In [14]:
def get_st_sl_g_p(data, n_splits=3, random_state=42):
    """Assign a stratified, prompt-grouped K-fold index to every row of ``data``.

    Uses a shuffled ``StratifiedGroupKFold`` with ``score_word_length_cat`` as
    the stratification target and ``prompt_id`` as the group label.

    Args:
        data: polars DataFrame containing ``score_word_length_cat`` and
            ``prompt_id`` columns.
        n_splits: Number of folds. Defaults to 3, the value that used to be
            hard-coded.
        random_state: Seed of the shuffled splitter. Defaults to 42, the value
            that used to be hard-coded.

    Returns:
        ``data`` with an Int64 ``fold`` column that holds, for each row, the
        index of the split in which that row belongs to the validation part.
    """
    import numpy as np
    from sklearn.model_selection import StratifiedGroupKFold

    # Integer scratch array: fold ids are never fractional.
    fold_arr = np.zeros(data.height, dtype=np.int64)
    sgkf = StratifiedGroupKFold(
        n_splits=n_splits, random_state=random_state, shuffle=True
    )

    splits = sgkf.split(data, data["score_word_length_cat"], data["prompt_id"])
    for fold_idx, (_, val_idx) in enumerate(splits):
        fold_arr[val_idx] = fold_idx

    return data.with_columns(pl.Series(fold_arr).cast(pl.Int64).alias("fold"))


# Assign fold ids to the second half.
train_2 = get_st_sl_g_p(train_2)



In [15]:
# Recombine the two halves; each carries the fold column from its own splitter.
train = pl.concat([train_1, train_2])

In [16]:
# Preview the recombined frame, now including the fold column.
train.head(10)

essay_id,full_text,score,prompt_name,prompt_id,word_length,word_length_cat,score_word_length_cat,concat_score_word_length_prompt_id,fold
str,str,i64,str,u32,i64,str,str,str,i64
"""CB8B86AD7823""","""Online learning is a great opp…",5,"""Distance learning""",13,634,"""600<x<=800""","""5_600<x<=800""","""5_600<x<=800_13""",1
"""45B098D3A19B""","""We all know how important safe…",3,"""Phones and driving""",7,277,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_7""",0
"""15f01fe""","""Many people think that the fac…",3,"""The Face on Mars""",1,241,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_1""",1
"""250aece""","""Does your expression in the mi…",1,"""Facial action coding system""",5,256,"""200<x<=400""","""1_200<x<=400""","""1_200<x<=400_5""",0
"""2E9E0F46CD6A""","""Dear TEACHER_NAME, Cell phone…",5,"""Cell phones at school""",12,590,"""400<x<=600""","""5_400<x<=600""","""5_400<x<=600_12""",2
"""09d065f""","""How Seagoing Cowboys is a very…",3,"""""A Cowboy Who Rode the Waves""""",6,317,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_6""",2
"""CC067F0DC876""","""To whom ever this may concern …",3,"""Community service""",10,213,"""200<x<=400""","""3_200<x<=400""","""3_200<x<=400_10""",2
"""BE1E07640466""","""Dear Ms. principal My name is…",2,"""Grades for extracurricular act…",11,191,"""x<=200""","""2_x<=200""","""2_x<=200_11""",1
"""2916425""","""Dear state senator, After res…",5,"""Does the electoral college wor…",4,433,"""400<x<=600""","""5_400<x<=600""","""5_400<x<=600_4""",0
"""1ed0f53""","""Have you ever been sitting in …",2,"""Facial action coding system""",5,214,"""200<x<=400""","""2_200<x<=400""","""2_200<x<=400_5""",0


In [17]:
# Map each essay_id to its fold index.
essay_id_fold_dict = {
    essay_id: fold for essay_id, fold in zip(train["essay_id"], train["fold"])
}

import json

# Write the mapping as JSON into the current working directory.
with open("essay_id_fold_by_half_st_slp_and_st_sl_g_p.json", "w") as f:
    f.write(json.dumps(essay_id_fold_dict))

# Check

In [18]:
# Ten most frequent score/length buckets in fold 0 of train_1.
# NOTE(review): percentage divides by train.height (both halves combined),
# while the prompt_id checks later in this notebook divide by the half's own
# height — confirm which denominator is intended.
fold_rows = train_1.filter(pl.col("fold") == 0)
bucket_counts = fold_rows.select(
    pl.col("score_word_length_cat").value_counts()
).unnest("score_word_length_cat")
(
    bucket_counts.with_columns(percentage=pl.col("count") / train.height)
    .sort("count", descending=True)
    .head(10)
)

score_word_length_cat,count,percentage
str,u32,f64
"""3_200<x<=400""",1242,0.040812
"""4_400<x<=600""",818,0.02688
"""2_200<x<=400""",686,0.022542
"""4_200<x<=400""",382,0.012553
"""2_x<=200""",363,0.011928
"""3_400<x<=600""",350,0.011501
"""5_600<x<=800""",239,0.007854
"""5_400<x<=600""",236,0.007755
"""1_200<x<=400""",133,0.00437
"""4_600<x<=800""",120,0.003943


In [19]:
# Ten most frequent score/length buckets in fold 1 of train_1.
# NOTE(review): percentage divides by train.height (both halves combined),
# while the prompt_id checks later in this notebook divide by the half's own
# height — confirm which denominator is intended.
fold_rows = train_1.filter(pl.col("fold") == 1)
bucket_counts = fold_rows.select(
    pl.col("score_word_length_cat").value_counts()
).unnest("score_word_length_cat")
(
    bucket_counts.with_columns(percentage=pl.col("count") / train.height)
    .sort("count", descending=True)
    .head(10)
)

score_word_length_cat,count,percentage
str,u32,f64
"""3_200<x<=400""",1243,0.040845
"""4_400<x<=600""",820,0.026945
"""2_200<x<=400""",688,0.022608
"""4_200<x<=400""",381,0.01252
"""2_x<=200""",362,0.011895
"""3_400<x<=600""",346,0.01137
"""5_600<x<=800""",239,0.007854
"""5_400<x<=600""",234,0.007689
"""1_200<x<=400""",133,0.00437
"""4_600<x<=800""",121,0.003976


In [20]:
# Prompt distribution in fold 0 of train_1, as a share of all train_1 rows.
fold_rows = train_1.filter(pl.col("fold") == 0)
prompt_counts = fold_rows.select(pl.col("prompt_id").value_counts()).unnest(
    "prompt_id"
)
(
    prompt_counts.with_columns(percentage=pl.col("count") / train_1.height)
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
2,583,0.038315
5,507,0.03332
3,502,0.032992
13,358,0.023528
1,349,0.022936
…,…,…
12,274,0.018007
11,270,0.017744
14,259,0.017022
10,259,0.017022


In [21]:
# Prompt distribution in fold 1 of train_1, as a share of all train_1 rows.
fold_rows = train_1.filter(pl.col("fold") == 1)
prompt_counts = fold_rows.select(pl.col("prompt_id").value_counts()).unnest(
    "prompt_id"
)
(
    prompt_counts.with_columns(percentage=pl.col("count") / train_1.height)
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
2,583,0.038315
5,504,0.033123
3,502,0.032992
13,359,0.023594
1,349,0.022936
…,…,…
11,274,0.018007
6,274,0.018007
14,258,0.016956
10,257,0.01689


In [22]:
# Prompt distribution in fold 0 of train_2, as a share of all train_2 rows.
fold_rows = train_2.filter(pl.col("fold") == 0)
prompt_counts = fold_rows.select(pl.col("prompt_id").value_counts()).unnest(
    "prompt_id"
)
(
    prompt_counts.with_columns(percentage=pl.col("count") / train_2.height)
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
3,1508,0.099106
1,1047,0.068809
6,825,0.054219
10,771,0.05067
7,584,0.038381


In [23]:
# Prompt distribution in fold 1 of train_2, as a share of all train_2 rows.
fold_rows = train_2.filter(pl.col("fold") == 1)
prompt_counts = fold_rows.select(pl.col("prompt_id").value_counts()).unnest(
    "prompt_id"
)
(
    prompt_counts.with_columns(percentage=pl.col("count") / train_2.height)
    .sort("count", descending=True)
)

prompt_id,count,percentage
u32,u32,f64
2,1748,0.114879
4,1023,0.067232
8,875,0.057505
11,813,0.053431
14,776,0.050999
