In [1]:
import os
import json
import pathlib
from datetime import datetime
import optuna
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.metrics import roc_auc_score
import torch
from typing import List, Dict, Union, Tuple, NamedTuple
from tqdm import tqdm
import scml
from scml import pandasx as pdx
# Wall-clock timer for the whole notebook; it is stopped and reported in the final cell.
tim = scml.Timer()
tim.start()
# Presumably silences the HuggingFace `tokenizers` fork/parallelism warning — TODO confirm it is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
percentiles = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]
# NOTE(review): "use_inf_as_na" is deprecated since pandas 2.1 and later removed;
# confirm the pinned pandas version before upgrading.
pd.set_option("use_inf_as_na", True)
# Very wide feature table (~30k columns), so lift pandas' display/info limits.
_display_limits = (
    ("max_info_columns", 9999),
    ("display.max_columns", 9999),
    ("display.max_rows", 9999),
    ("max_colwidth", 9999),
)
for _option_name, _option_value in _display_limits:
    pd.set_option(_option_name, _option_value)
tqdm.pandas()
scml.seed_everything()

In [2]:
# Each run writes its artifacts (model.json, importance.csv) into a fresh,
# timestamped directory, e.g. models/xgb/20240121_021130.
ts = f"{datetime.now():%Y%m%d_%H%M%S}"
job_dir = f"models/xgb/{ts}"
pathlib.Path(job_dir).mkdir(parents=True, exist_ok=True)

# Prediction target column and XGBoost objective.
label = "generated"
objective: str = "binary:logistic"

# Presumably (low, high) ranges plus a trial count for an Optuna search — TODO confirm.
# NOTE(review): none of these are read by the visible training cell, which
# hard-codes its own params; feature_fraction / min_data_in_leaf are LightGBM
# spellings (the XGBoost equivalents are colsample_bytree / min_child_weight).
num_boost_round: int = 100
lr: Tuple[float, float] = (1e-3, 1e-3)
feature_fraction: Tuple[float, float] = (1, 1)
min_data_in_leaf: Tuple[int, int] = (20, 20)
n_trials: int = 1

In [3]:
# Load the precomputed feature table (about 4.8 GB in memory, per the info() output).
features_parquet = "input/features.parquet"
df = pd.read_parquet(features_parquet)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 43529 entries, 0 to 43528
Columns: 29835 entries, essay_id to tf_Ġzygomatic
dtypes: float32(29820), int16(2), int32(5), int8(1), object(7)
memory usage: 4.8+ GB


In [4]:
# Model inputs are every column belonging to one of the feature families below.
# `str.startswith` accepts a tuple, so each column is tested once and selected at
# most once; the previous nested loop appended a column once per matching prefix,
# which would silently duplicate features if two prefixes ever overlapped.
prefixes = ["ch_", "ws_", "ts_", "va_", "tf_"]
prefix_tuple = tuple(prefixes)
features = sorted(col for col in df.columns if col.startswith(prefix_tuple))
assert features, f"no columns start with any of {prefixes}"
# Guard against the target leaking into the inputs.
assert label not in features, f"label column {label!r} must not be a feature"
print(f"{len(features)} features\n{features[:100]}")

29823 features
['ch_digit_frac', 'ch_len', 'ch_letter_frac', 'ch_punc_frac', 'ch_repeat_char_frac', 'ch_space_frac', 'ch_upper_frac', 'tf_0', 'tf_00', 'tf_000', 'tf_03', 'tf_1', 'tf_10', 'tf_11', 'tf_12', 'tf_13', 'tf_14', 'tf_15', 'tf_16', 'tf_17', 'tf_18', 'tf_19', 'tf_1990', 'tf_2', 'tf_20', 'tf_200', 'tf_21', 'tf_23', 'tf_24', 'tf_25', 'tf_27', 'tf_28', 'tf_3', 'tf_30', 'tf_31', 'tf_32', 'tf_33', 'tf_34', 'tf_38', 'tf_39', 'tf_4', 'tf_40', 'tf_41', 'tf_43', 'tf_45', 'tf_5', 'tf_50', 'tf_538', 'tf_58', 'tf_6', 'tf_60', 'tf_62', 'tf_7', 'tf_70', 'tf_74', 'tf_76', 'tf_79', 'tf_8', 'tf_87', 'tf_9', 'tf_90', 'tf_a', 'tf_aa', 'tf_aae', 'tf_aage', 'tf_aaion', 'tf_ab', 'tf_aban', 'tf_abe', 'tf_abel', 'tf_aber', 'tf_abet', 'tf_abeth', 'tf_abil', 'tf_abilites', 'tf_abilitie', 'tf_abilities', 'tf_ability', 'tf_abill', 'tf_abilty', 'tf_abitable', 'tf_abital', 'tf_abl', 'tf_able', 'tf_abled', 'tf_ables', 'tf_abling', 'tf_ablished', 'tf_ablities', 'tf_ablity', 'tf_ably', 'tf_about', 'tf_abra', '

In [5]:
# Split on `white_sim`: rows at or above the threshold are used for training and
# the low-similarity tail is held out for validation.
# NOTE(review): this is not a random split — the val set is ~2% of rows with a very
# different label mix (see the value counts), and it is also the early-stopping set
# in the training cell, so metrics on it are optimistic.
white_sim_threshold = 0.45
white_sim = df["white_sim"]
tra = df[white_sim >= white_sim_threshold]
val = df[white_sim < white_sim_threshold]
t = len(tra)
v = len(val)
n = t + v
# A NaN similarity satisfies neither comparison, so such rows would silently drop
# out of both splits (and `val%` would still look fine); fail loudly instead.
assert n == len(df), f"{len(df) - n} rows with missing white_sim fell into neither split"
print(f"val%={v/n:.4f}, len(tra)={t:,}, len(val)={v:,}")
dtrain = xgb.DMatrix(tra[features], tra[label], enable_categorical=False)
dval = xgb.DMatrix(val[features], val[label], enable_categorical=False)
pdx.value_counts(val[label])

val%=0.0214, len(tra)=42,596, len(val)=933


Unnamed: 0_level_0,count,percent
generated,Unnamed: 1_level_1,Unnamed: 2_level_1
1,793,0.849946
0,140,0.150054


In [6]:
%%time
model = xgb.train(
   params={
       "objective": objective,
       "learning_rate": 5e-2,
       "min_child_weight": 20,
       "colsample_bytree": 0.5,
       "max_depth": 6,
   },
   dtrain=dtrain,
   num_boost_round=2000,
   evals=[(dtrain, "train"), (dval, "val")],
   verbose_eval=40,
   early_stopping_rounds=100,
)
print(f"best score {model.best_score:.5f} at iteration {model.best_iteration}")
model.save_model(f"{job_dir}/model.json")

[0]	train-logloss:0.58618	val-logloss:0.97015
[40]	train-logloss:0.14982	val-logloss:0.36515
[80]	train-logloss:0.07517	val-logloss:0.23699
[120]	train-logloss:0.04963	val-logloss:0.18333
[160]	train-logloss:0.03766	val-logloss:0.15579
[200]	train-logloss:0.03024	val-logloss:0.13848
[240]	train-logloss:0.02504	val-logloss:0.12711
[280]	train-logloss:0.02149	val-logloss:0.12077
[320]	train-logloss:0.01872	val-logloss:0.11470
[360]	train-logloss:0.01647	val-logloss:0.11067
[400]	train-logloss:0.01476	val-logloss:0.10653
[440]	train-logloss:0.01331	val-logloss:0.10451
[480]	train-logloss:0.01210	val-logloss:0.10301
[520]	train-logloss:0.01103	val-logloss:0.10151
[560]	train-logloss:0.01010	val-logloss:0.10100
[600]	train-logloss:0.00932	val-logloss:0.10048
[640]	train-logloss:0.00865	val-logloss:0.09999
[680]	train-logloss:0.00805	val-logloss:0.09904
[720]	train-logloss:0.00753	val-logloss:0.09878
[760]	train-logloss:0.00707	val-logloss:0.09817
[800]	train-logloss:0.00666	val-logloss:0.09

In [7]:
%%time
y_true = val[label].tolist()
y_pred = model.predict(data=dval, iteration_range=(0, model.best_iteration+1))
auc = roc_auc_score(y_true, y_pred, average="macro")
print(f"auc={auc:.4f}")
print(f"y_pred={y_pred.shape}\n{y_pred[:5]}")

auc=0.9926
y_pred=(933,)
[0.99992895 0.9997522  0.99393684 0.9301336  0.99846566]
CPU times: user 370 ms, sys: 462 ms, total: 831 ms
Wall time: 64.4 ms


In [8]:
%%time
scores = model.get_score(importance_type="gain")
assert len(scores)!=0
rows = []
for feature, score in scores.items():
    rows.append({'importance': score, 'feature': feature})
idf = pd.DataFrame.from_records(rows)
idf = idf.sort_values(["importance"], ascending=False, ignore_index=True)
fp = f"{job_dir}/importance.csv"
idf.to_csv(fp, index=True)
print(f"Saved {fp}")
idf.T.head()

Saved models/xgb/20240121_021130/importance.csv
CPU times: user 22.1 ms, sys: 46.6 ms, total: 68.7 ms
Wall time: 6 ms


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335,336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407
importance,661.597534,580.991089,520.642212,405.889587,380.618866,240.466797,239.409851,238.442108,218.278152,187.281799,184.298508,181.102798,176.987335,132.721313,129.132324,104.596504,101.635033,100.086754,93.471291,92.816116,89.992638,85.454712,83.987701,83.834717,83.344002,82.944618,81.014114,79.86113,76.984772,74.992149,74.731422,73.908203,71.399536,71.09333,70.566414,70.2481,68.868896,67.324066,65.838104,64.030273,62.71357,61.870018,57.530853,51.854958,51.35556,48.884918,48.155899,47.38916,46.347572,46.327637,44.562149,43.605053,41.729317,41.457748,40.443329,40.406376,40.144859,40.034851,39.036804,37.398949,36.848492,36.416351,36.176056,35.987251,35.579456,35.425541,35.103855,35.080502,34.496826,34.133518,33.192463,33.009266,32.06942,31.338507,31.083378,30.853495,30.797331,30.461609,30.381845,30.220469,30.090424,29.481979,29.011768,28.63522,28.425415,28.025711,27.071001,26.896776,26.888142,25.628754,25.48353,25.275902,24.018242,23.696852,23.588829,23.426626,22.660467,22.441732,21.596815,21.554689,21.009773,20.925755,20.504051,20.193283,19.988934,19.944071,19.927662,19.911245,19.811428,19.766273,19.749308,18.282394,17.880573,17.86974,17.673416,17.618467,17.211958,17.182722,17.172173,17.163725,16.976942,16.936079,16.326605,16.122314,16.105356,15.895791,15.562007,15.558497,15.534515,15.273863,15.072861,14.98022,14.947492,14.851101,14.493192,14.252938,14.102311,13.970284,13.9264,13.918358,13.861259,13.493904,13.470053,13.438181,13.337593,13.245678,13.207893,13.118791,12.846616,12.750061,12.54248,12.52393,12.301524,11.99415,11.922365,11.901738,11.842185,11.512276,11.431728,11.384054,11.293309,11.093381,11.012393,10.879386,10.878849,10.830639,10.635188,10.585239,10.170173,10.123187,10.088881,9.934095,9.919516,9.795685,9.609342,9.30803,9.182851,9.158499,8.998846,8.933435,8.846701,8.821812,8.712491,8.708257,8.676237,8.652871,8.645793,8.37716,8.307696,8.267982,8.182672,8.168635,8.04805,8.044981,7.892286,7.854073,7.656933,7.611191,7.552275,7.481313,7.388374,7.209621,7.
118626,7.115748,7.05514,6.978959,6.890601,6.886446,6.645034,6.530092,6.507767,6.349953,6.294907,6.280403,6.208178,6.202746,6.104156,5.964964,5.837977,5.818327,5.815298,5.709992,5.631436,5.620426,5.614038,5.596033,5.529945,5.419085,5.402081,5.380029,5.340683,5.328637,5.256274,5.252212,5.23568,5.208162,5.194965,5.065253,4.885905,4.847053,4.773505,4.769887,4.724706,4.714436,4.70049,4.694871,4.612142,4.607749,4.582759,4.522965,4.480743,4.444437,4.342023,4.305319,4.24401,4.198876,4.120795,4.090305,4.023343,3.994397,3.959758,3.944437,3.885608,3.847145,3.78024,3.763478,3.714287,3.700971,3.526972,3.512095,3.493603,3.466834,3.453146,3.450044,3.433106,3.419986,3.397957,3.395497,3.367174,3.367153,3.365255,3.361254,3.312541,3.234521,3.197531,3.197361,3.196799,3.183557,3.144096,3.066222,2.936248,2.93553,2.911227,2.909382,2.902889,2.888231,2.874432,2.86217,2.849054,2.814121,2.794135,2.793975,2.791943,2.78416,2.773086,2.738538,2.731679,2.711715,2.682313,2.680587,2.666754,2.652306,2.65118,2.642481,2.625511,2.625185,2.616417,2.590504,2.590005,2.567017,2.564539,2.542941,2.501905,2.483153,2.456413,2.449562,2.417475,2.385181,2.373887,2.36267,2.347529,2.33098,2.329677,2.327633,2.265077,2.251067,2.200758,2.193883,2.19388,2.189415,2.165207,2.122313,2.11352,2.111507,2.108935,2.073899,2.062941,2.042671,2.038351,2.037022,2.026715,2.002,1.994819,1.905548,1.902783,1.895117,1.878437,1.869759,1.838896,1.81638,1.790969,1.787349,1.784771,1.781002,1.776031,1.745369,1.716723,1.716033,1.697812,1.693832,1.69218,1.690742,1.672335,1.666259,1.65098,1.637237,1.627404,1.561339,1.559408,1.554674,1.518231,1.497944,1.483841,1.480845,1.459464,1.442078,1.425265,1.293403,1.277447,1.265386,1.251879,1.138635,1.136488,1.079514,1.023715,1.015635,1.003615,0.979889,0.976981,0.897438,0.889641,0.864037,0.756007,0.705514,0.668125,0.633369,0.545553,0.545249
feature,ts_polysyllable_frac,tf_Ġhey,ts_syllables_per_word,ch_space_frac,tf_Ġsuper,tf_Ġdr,tf_Ġgoals,tf_Ġbecause,tf_Ġhuang,tf_Ġachieve,tf_Ġadditionally,tf_Ġessay,ts_smog_index,tf_Ġcultures,tf_Ġnt,tf_Ġconclusion,tf_Ġgrader,tf_Ġsustainable,tf_Ġessential,tf_th,tf_Ġvery,tf_Ġemerson,tf_Ġseas,tf_Ġwould,tf_Ġdear,tf_Ġimportant,ws_sent_len_std,ts_monosyllable_frac,tf_Ġelectors,tf_Ġsuccess,tf_Ġquality,tf_Ġattempt,tf_Ġcowboy,ch_punc_frac,ts_lexicon_count,tf_Ġadventure,tf_Ġtotally,ch_digit_frac,tf_Ġfirstly,tf_Ġconfused,tf_Ġexperiences,tf_Ġensures,tf_Ġemotion,tf_Ġskills,tf_Ġimportance,tf_Ġand,tf_Ġbeyond,tf_Ġelection,tf_Ġultimately,tf_Ġpursue,tf_Ġthen,tf_Ġaddress,tf_Ġsenator,tf_Ġcomputer,tf_Ġpercent,tf_Ġwriting,tf_Ġchina,tf_Ġconcerns,tf_Ġachieving,tf_Ġoverall,tf_Ġpotential,ch_letter_frac,tf_Ġstuff,tf_Ġcar,tf_Ġeurope,tf_Ġlike,tf_Ġchallenges,tf_Ġ10,tf_Ġperspectives,ts_syllable_count,tf_Ġparagraph,tf_Ġvenus,ts_coleman_liau_index,tf_Ġalthough,tf_Ġsignificant,tf_Ġprobably,tf_Ġsmaller,tf_Ġhumans,tf_Ġlead,ts_difficult_words,tf_Ġplus,ch_len,ws_sent_len_delta_mean,ws_sent_len_delta_std,tf_Ġourselves,tf_Ġsecondly,tf_Ġetc,tf_Ġfacial,tf_Ġnasa,tf_Ġanimals,tf_Ġextracurricular,tf_Ġthe,tf_Ġprovide,ts_sentence_count,tf_Ġpossibly,tf_Ġleast,tf_Ġ8,tf_Ġseagoing,tf_Ġcool,tf_Ġmy,tf_Ġschool,tf_Ġapproach,tf_Ġhuman,tf_Ġinformed,tf_Ġensure,tf_Ġargue,tf_Ġcommunity,tf_Ġstates,tf_Ġus,tf_Ġactivity,tf_Ġme,tf_Ġunited,tf_Ġtext,tf_Ġprincipal,tf_Ġbalance,tf_Ġpresident,tf_Ġeveryday,tf_Ġresources,tf_Ġunique,tf_Ġdriving,tf_Ġconsider,tf_Ġtrue,tf_Ġcars,ts_spache_readability,tf_Ġmany,tf_Ġservice,tf_Ġyou,tf_Ġdo,tf_Ġbeneficial,tf_Ġgrade,tf_Ġreduce,tf_Ġshould,tf_Ġstudent,tf_Ġearth,ts_gunning_fog,tf_Ġallows,tf_Ġstudents,ws_sent_len_mean,va_valence_mean,tf_Ġmean,ts_flesch_reading_ease,tf_Ġreducing,tf_Ġdifficult,tf_Ġdesigned,tf_Ġcomputers,tf_Ġalmost,tf_Ġadvice,tf_Ġif,tf_Ġpublic,tf_Ġday,tf_Ġsmog,tf_Ġhand,ts_dale_chall_readability_score,tf_Ġfair,tf_Ġwe,ts_syllables_per_sent,tf_Ġsupport,tf_Ġsport,tf_Ġtechnology,tf_Ġhealth,tf_Ġgo,t
f_Ġdrive,tf_Ġremember,va_arousal_mean,tf_Ġfurthermore,tf_Ġparticipate,tf_Ġopportunities,ch_upper_frac,tf_Ġwhile,tf_Ġget,tf_Ġimpact,tf_Ġfocus,tf_Ġcreate,tf_Ġso,tf_Ġpoint,tf_Ġrisks,tf_Ġlet,tf_Ġi,tf_Ġthough,tf_Ġhowever,tf_Ġwill,tf_Ġmost,tf_Ġmight,ts_mcalpine_eflaw,tf_Ġstate,tf_Ġfinally,tf_Ġexperience,tf_Ġsincerely,tf_Ġwas,tf_Ġcould,va_dominance_mean,tf_Ġlearn,tf_Ġam,tf_Ġlearning,tf_Ġinterests,tf_Ġit,tf_Ġboth,tf_Ġshow,tf_Ġmuch,tf_Ġkids,tf_Ġcell,tf_Ġstudying,tf_Ġlearned,tf_Ġpeople,tf_Ġthey,tf_Ġoptions,tf_Ġunderstand,tf_Ġexplore,tf_Ġoften,tf_Ġroad,tf_Ġclear,tf_Ġexpress,tf_Ġmistakes,tf_Ġmeans,tf_Ġthank,tf_Ġwho,tf_Ġadvantages,tf_Ġmatter,tf_Ġwhat,tf_Ġplanet,tf_Ġyour,tf_Ġmay,tf_Ġfeel,tf_Ġrequired,tf_Ġits,tf_Ġown,tf_Ġbenefits,tf_Ġbelieve,tf_Ġprojects,tf_Ġteacher,tf_Ġ3,tf_Ġis,tf_Ġhard,ts_words_per_sent,tf_Ġschools,tf_Ġbut,tf_Ġnot,tf_Ġagree,tf_Ġall,tf_Ġat,tf_Ġwhether,tf_Ġcan,tf_Ġmake,tf_Ġno,tf_Ġspend,tf_Ġwhy,tf_Ġfirst,tf_Ġexample,tf_Ġgoing,tf_Ġname,tf_Ġable,tf_Ġabout,tf_Ġto,tf_Ġhelping,tf_Ġgiven,tf_Ġfor,tf_Ġactivities,tf_Ġsports,tf_Ġanother,tf_Ġlives,tf_Ġeverything,tf_Ġimagine,ts_flesch_kincaid_grade,tf_Ġsummer,tf_Ġthink,tf_Ġsay,ts_linsear_write_formula,tf_Ġour,tf_Ġbenefit,va_valence_std,tf_Ġstart,tf_Ġfact,va_dominance_std,tf_Ġan,tf_Ġbad,tf_Ġask,va_arousal_std,tf_Ġkeep,tf_Ġparents,tf_Ġgood,tf_Ġlast,tf_Ġstudies,tf_Ġtheir,tf_Ġworth,tf_Ġafter,tf_Ġgives,tf_Ġreason,tf_Ġnow,tf_Ġthese,tf_Ġtherefore,tf_Ġinstead,tf_Ġone,tf_Ġaround,tf_Ġare,tf_Ġtoo,tf_Ġhow,tf_Ġaverage,tf_Ġknow,tf_Ġplace,tf_Ġfun,tf_Ġphone,tf_Ġwere,tf_Ġcause,tf_Ġlife,tf_Ġbeing,tf_Ġhome,tf_Ġa,tf_Ġthan,tf_Ġbe,tf_Ġrather,tf_Ġworld,tf_Ġhaving,ts_automated_readability_index,tf_Ġin,tf_Ġsense,tf_Ġdown,tf_Ġthis,ch_repeat_char_frac,tf_Ġthat,tf_Ġsince,tf_Ġothers,tf_Ġhis,tf_Ġstay,tf_Ġwith,tf_Ġout,tf_Ġbecome,tf_Ġdoing,tf_Ġallow,tf_Ġreally,tf_Ġwhen,tf_Ġtry,tf_Ġhigh,tf_Ġoff,tf_Ġteach,tf_Ġs,tf_Ġjust,tf_Ġthing,tf_Ġtime,tf_Ġway,tf_Ġstill,tf_Ġattention,tf_Ġwhich,tf_Ġtake,tf_Ġput,tf_Ġup,tf_Ġany,tf_Ġthere,tf_Ġwithout,tf_Ġother,tf_Ġhave,tf_Ġmore
,tf_Ġof,tf_Ġeven,tf_Ġor,tf_Ġbeen,tf_Ġtaking,tf_Ġsee,tf_Ġimprove,tf_Ġbetter,tf_Ġdone,tf_Ġsame,tf_Ġon,tf_Ġthem,tf_Ġuse,tf_Ġtimes,tf_Ġsuch,tf_Ġfrom,tf_Ġdifferent,tf_Ġas,tf_Ġwant,tf_Ġgive,tf_Ġidea,tf_Ġhelp,tf_Ġusing,tf_Ġnew,tf_Ġmakes,tf_Ġonly,tf_Ġperson,tf_Ġevery,tf_Ġmade,tf_Ġalso,tf_Ġthings,tf_Ġeach,tf_Ġsome,tf_Ġjob,tf_Ġless,tf_Ġlot,tf_Ġsaid,tf_Ġfriends,tf_Ġactually,tf_Ġduring,tf_Ġgreat,tf_Ġclass,tf_Ġmaybe,tf_Ġthose,tf_Ġnever,tf_Ġwhere,tf_Ġsomeone,tf_Ġsomething,tf_Ġby,tf_Ġaway,tf_Ġmaking,tf_Ġinto,tf_Ġthrough,tf_Ġtell,tf_Ġgas,tf_Ġlikely,tf_Ġfuture


In [9]:
# Stop the notebook-wide timer started in the first cell and report the elapsed time.
tim.stop()
total_elapsed = str(tim.elapsed)
print(f"Total time taken {total_elapsed}")

Total time taken 0:17:24.858528
