In [1]:
import os
import json
import gc
from datetime import datetime
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
from tqdm import tqdm
from typing import Dict, List, Set, Tuple, NamedTuple, Callable
import scml
from scml import pandasx as pdx

In [2]:
# Version tag; it is embedded in the names of the parquet files written at the
# end of the notebook.
version = "05"

# Parquet inputs to combine. Commented-out entries are currently disabled.
files = [
    # Path("input/comp.parquet"),
    Path("input/persuade_02.parquet"),
]

# When True, the split is stratified on topic (topic-classification setup);
# otherwise topics are used as groups and the split is stratified on score.
persuade_topic_classification = True

# Default of 15 folds: 15 topics exist in corpus. Topic classification uses
# 25 folds instead.
n_splits = 25 if persuade_topic_classification else 15

In [3]:
# Start the wall-clock timer that is reported in the last cell.
tim = scml.Timer()
tim.start()

# NOTE(review): presumably read by the Hugging Face tokenizers library to
# disable its internal parallelism -- nothing in this notebook reads it.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

percentiles = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]

# Raise every pandas display limit to the same large value so frames and
# .info() listings are not truncated.
for option_name in (
    "max_info_columns",
    "display.max_columns",
    "display.max_rows",
    "max_colwidth",
):
    pd.set_option(option_name, 9999)

# Register tqdm's progress_apply() helpers on pandas objects.
tqdm.pandas()

# NOTE(review): called without an explicit seed; the seed actually used is
# whatever scml.seed_everything() defaults to -- confirm.
scml.seed_everything()

# Print the int16 value range for reference.
info = np.iinfo(np.int16)
print(f"int16, min={info.min}, max={info.max}")

int16, min=-32768, max=32767


In [4]:
# Columns kept from every input file; all other columns are dropped.
cols = [
    "essay_id",
    "score",
    "ctq_Qwen2-1.5B-Instruct",
    "ctq_3_Qwen2-1.5B-Instruct",
    "topic",
    "full_text",
]

# Fail early with a clear message instead of crashing later on `None.info()`.
if not files:
    raise ValueError("`files` is empty: there is nothing to load")

frames = []
for filepath in files:
    # .assign() returns a new frame, so tagging each row with its origin file
    # never writes into a column-subset of the frame read from disk (which is
    # what triggers pandas' SettingWithCopyWarning).
    df = pd.read_parquet(filepath)[cols].assign(source=filepath.stem)
    frames.append(df)

# Concatenate once at the end rather than re-copying the accumulated frame for
# every additional file. ignore_index=True yields a fresh 0..n-1 RangeIndex.
cmb = pd.concat(frames, ignore_index=True)
cmb.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25996 entries, 0 to 25995
Data columns (total 7 columns):
 #   Column                     Non-Null Count  Dtype 
---  ------                     --------------  ----- 
 0   essay_id                   25996 non-null  object
 1   score                      25996 non-null  int8  
 2   ctq_Qwen2-1.5B-Instruct    25996 non-null  object
 3   ctq_3_Qwen2-1.5B-Instruct  25996 non-null  object
 4   topic                      25996 non-null  object
 5   full_text                  25996 non-null  object
 6   source                     25996 non-null  object
dtypes: int8(1), object(6)
memory usage: 1.2+ MB


In [5]:
# Summarise the topic distribution of the combined frame -- the one that is
# actually split below -- rather than `df`, which only holds the last file read.
pdx.value_counts(cmb["topic"])

Unnamed: 0_level_0,count,percent
topic,Unnamed: 1_level_1,Unnamed: 2_level_1
facial action coding system,2167,0.083359
distance learning,2157,0.082974
electoral college work,2046,0.078704
car free cities,1959,0.075358
driverless cars,1886,0.07255
exploring venus,1862,0.071626
summer projects,1750,0.067318
mandatory extracurricular activities,1670,0.064241
cell phones school,1656,0.063702
grades extracurricular activities,1626,0.062548


# Train/Validation Split

In [6]:
# The splitters derive folds from the labels (and groups); X is only needed
# for its length, so a zero vector serves as a placeholder.
dummy = np.zeros(len(cmb))

# No random_state is passed, so shuffling draws from NumPy's global RNG.
# NOTE(review): presumably seeded by scml.seed_everything() earlier -- confirm.
if persuade_topic_classification:
    # Stratify on topic so each fold keeps the overall topic distribution.
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True)
    split_kwargs = {"y": cmb["topic"]}
else:
    # Keep each topic entirely on one side of the split, stratifying on score.
    splitter = StratifiedGroupKFold(n_splits=n_splits, shuffle=True)
    split_kwargs = {"y": cmb["score"], "groups": cmb["topic"]}

# Only the first fold is used: its test indices become the validation set and
# the remaining rows become the training set.
ti, vi = next(splitter.split(dummy, **split_kwargs))
tra, val = cmb.iloc[ti], cmb.iloc[vi]
tra.info()

<class 'pandas.core.frame.DataFrame'>
Index: 24956 entries, 0 to 25995
Data columns (total 7 columns):
 #   Column                     Non-Null Count  Dtype 
---  ------                     --------------  ----- 
 0   essay_id                   24956 non-null  object
 1   score                      24956 non-null  int8  
 2   ctq_Qwen2-1.5B-Instruct    24956 non-null  object
 3   ctq_3_Qwen2-1.5B-Instruct  24956 non-null  object
 4   topic                      24956 non-null  object
 5   full_text                  24956 non-null  object
 6   source                     24956 non-null  object
dtypes: int8(1), object(6)
memory usage: 1.4+ MB


In [7]:
val.info()

<class 'pandas.core.frame.DataFrame'>
Index: 1040 entries, 69 to 25897
Data columns (total 7 columns):
 #   Column                     Non-Null Count  Dtype 
---  ------                     --------------  ----- 
 0   essay_id                   1040 non-null   object
 1   score                      1040 non-null   int8  
 2   ctq_Qwen2-1.5B-Instruct    1040 non-null   object
 3   ctq_3_Qwen2-1.5B-Instruct  1040 non-null   object
 4   topic                      1040 non-null   object
 5   full_text                  1040 non-null   object
 6   source                     1040 non-null   object
dtypes: int8(1), object(6)
memory usage: 57.9+ KB


In [8]:
# Validate before writing so a split containing missing values never reaches
# disk (previously the files were written first and the checks ran afterwards).
assert tra.notna().all(axis=None), "training split contains missing values"
assert val.notna().all(axis=None), "validation split contains missing values"

# Make sure the destination directory exists on a fresh checkout.
Path("output").mkdir(parents=True, exist_ok=True)

# index=False: the row index inherited from the combined frame is not stored.
tra.to_parquet(f"output/tra_{version}.parquet", index=False)
val.to_parquet(f"output/val_{version}.parquet", index=False)

In [9]:
# Stop the timer started in the setup cell and report the elapsed time.
tim.stop()
elapsed = str(tim.elapsed)
print(f"Total time taken {elapsed}")

Total time taken 0:00:00.705618
