In [1]:
import os
import json
import gc
from datetime import datetime
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
from tqdm import tqdm
from typing import Dict, List, Set, Tuple, NamedTuple, Callable
import scml
from scml import pandasx as pdx

In [2]:
# Run configuration.
version = "06"  # suffix used in the output file names (tra_<version>, val_<version>)
# Input corpora; they are concatenated in this order, so on duplicate text
# the row from the earlier file is the one that survives de-duplication.
files = [
    Path("input/comp_02.parquet"),
    Path("input/persuade_02.parquet"),
]
# True  -> topic-stratified split (every topic appears in both train and val).
# False -> score-stratified split grouped by topic (val topics unseen in train).
persuade_topic_classification = False
# 15 topics exist in corpus; topic-classification mode uses 25 finer folds instead.
n_splits = 25 if persuade_topic_classification else 15

In [3]:
# Notebook-wide timer; it is stopped and reported in the final cell.
tim = scml.Timer()
tim.start()

# Presumably read by the HuggingFace `tokenizers` library to disable its
# internal parallelism (avoids fork warnings) — set before any tokenizer loads.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Quantiles for summary statistics (not referenced in the cells shown here).
percentiles = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]

# Effectively unlimited pandas display so .info() / frames are never truncated.
for option_name, option_value in (
    ("max_info_columns", 9999),
    ("display.max_columns", 9999),
    ("display.max_rows", 9999),
    ("max_colwidth", 9999),
):
    pd.set_option(option_name, option_value)

tqdm.pandas()  # registers .progress_apply on pandas objects
scml.seed_everything()

# Print the int16 range for reference.
info = np.iinfo(np.int16)
print(f"int16, min={info.min}, max={info.max}")

int16, min=-32768, max=32767


In [4]:
# Columns kept from every input file.
cols = [
    "essay_id",
    "score",
    #"ctq_Qwen2-1.5B-Instruct",  # alternative feature column, currently disabled
    "ctq_3_Qwen2-1.5B-Instruct",
    "topic",
    "full_text",
]
# Read each file (only the needed columns, in `cols` order), tag its rows with the
# file stem, and concatenate once. A single pd.concat avoids the quadratic cost of
# concatenating inside the loop. `.assign` returns a new frame, so there is no
# chained-assignment warning on the column subset.
frames = [
    pd.read_parquet(filepath, columns=cols)[cols].assign(source=filepath.stem)
    for filepath in files
]
df = pd.concat(frames, ignore_index=True)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 43303 entries, 0 to 43302
Data columns (total 6 columns):
 #   Column                     Non-Null Count  Dtype 
---  ------                     --------------  ----- 
 0   essay_id                   43303 non-null  object
 1   score                      43303 non-null  int8  
 2   ctq_3_Qwen2-1.5B-Instruct  43303 non-null  object
 3   topic                      43303 non-null  object
 4   full_text                  43303 non-null  object
 5   source                     43303 non-null  object
dtypes: int8(1), object(5)
memory usage: 1.7+ MB


In [5]:
# Remove essays with identical full_text. keep="first" means the row from the
# earlier file in `files` wins, because the files were concatenated in that order.
n_rows_before = len(df)
df = df.drop_duplicates(subset=["full_text"], keep="first", ignore_index=True)
print(f"{n_rows_before - len(df):,} rows dropped: duplicate text")

12,875 rows dropped: duplicate text


In [6]:
# Topic distribution after de-duplication (count and share of rows per topic).
topic_counts = pdx.value_counts(df["topic"])
topic_counts

Unnamed: 0_level_0,count,percent
topic,Unnamed: 1_level_1,Unnamed: 2_level_1
driverless cars,3499,0.114993
facial action coding system,3043,0.100007
exploring venus,3015,0.099086
distance learning,2157,0.070889
face mars,2094,0.068818
electoral college work,2046,0.067241
car free cities,1963,0.064513
summer projects,1750,0.057513
mandatory extracurricular activities,1671,0.054917
cell phones school,1656,0.054424


# Train/Test Split

In [7]:
# Hold out the first fold of a K-fold split as the validation set.
# - Topic-classification mode: plain stratified split on `topic`, so every topic
#   appears in both splits.
# - Default mode: stratify on `score` while grouping by `topic`, so validation
#   essays come only from topics that are absent from training.
# NOTE(review): no random_state is passed, so the shuffle draws from NumPy's global
# RNG (presumably seeded by scml.seed_everything() in the setup cell). Confirm this
# if the split must be reproducible.
dummy = np.zeros(len(df))  # features are ignored by the splitters; only the length matters
if persuade_topic_classification:
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True)
    folds = splitter.split(dummy, y=df["topic"])
else:
    splitter = StratifiedGroupKFold(n_splits=n_splits, shuffle=True)
    folds = splitter.split(dummy, y=df["score"], groups=df["topic"])
# Only the first fold is needed.
ti, vi = next(folds)
tra = df.iloc[ti]
val = df.iloc[vi]
if not persuade_topic_classification:
    # Group-split invariant: no topic may appear on both sides.
    assert set(tra["topic"]).isdisjoint(val["topic"]), "topic leaked across train/val"
tra.info()

<class 'pandas.core.frame.DataFrame'>
Index: 28334 entries, 0 to 30427
Data columns (total 6 columns):
 #   Column                     Non-Null Count  Dtype 
---  ------                     --------------  ----- 
 0   essay_id                   28334 non-null  object
 1   score                      28334 non-null  int8  
 2   ctq_3_Qwen2-1.5B-Instruct  28334 non-null  object
 3   topic                      28334 non-null  object
 4   full_text                  28334 non-null  object
 5   source                     28334 non-null  object
dtypes: int8(1), object(5)
memory usage: 1.3+ MB


In [8]:
# Schema and row-count check of the held-out validation split.
val.info()

<class 'pandas.core.frame.DataFrame'>
Index: 2094 entries, 1 to 17287
Data columns (total 6 columns):
 #   Column                     Non-Null Count  Dtype 
---  ------                     --------------  ----- 
 0   essay_id                   2094 non-null   object
 1   score                      2094 non-null   int8  
 2   ctq_3_Qwen2-1.5B-Instruct  2094 non-null   object
 3   topic                      2094 non-null   object
 4   full_text                  2094 non-null   object
 5   source                     2094 non-null   object
dtypes: int8(1), object(5)
memory usage: 100.2+ KB


In [9]:
# Validate both splits BEFORE persisting, so a frame with missing values is never written.
assert tra.notna().all(axis=None), "train split contains missing values"
assert val.notna().all(axis=None), "validation split contains missing values"
# Create the output directory if it is missing, so the write doesn't fail on a fresh checkout.
output_dir = Path("output")
output_dir.mkdir(parents=True, exist_ok=True)
# index=False drops the row labels inherited from `df`, which are non-contiguous after iloc.
tra.to_parquet(output_dir / f"tra_{version}.parquet", index=False)
val.to_parquet(output_dir / f"val_{version}.parquet", index=False)

In [10]:
# Stop the timer started in the setup cell and report the elapsed time.
tim.stop()
elapsed_text = str(tim.elapsed)
print(f"Total time taken {elapsed_text}")

Total time taken 0:00:00.861992
