## Import

In [3]:
import pandas as pd
import numpy as np
import os
import random
import polars as pl
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import duckdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

## Setting

In [4]:
CFG = {
    'BATCH_SIZE': 4096,
    'EPOCHS': 10,
    'LEARNING_RATE': 1e-3,
    'SEED' : 42
}
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED']) # Seed 고정

## Data Load

In [6]:
train = pl.read_parquet('../data/train.parquet')
train.head()

gender,age_group,inventory_id,day_of_week,hour,seq,l_feat_1,l_feat_2,l_feat_3,l_feat_4,l_feat_5,l_feat_6,l_feat_7,l_feat_8,l_feat_9,l_feat_10,l_feat_11,l_feat_12,l_feat_13,l_feat_14,l_feat_15,l_feat_16,l_feat_17,l_feat_18,l_feat_19,l_feat_20,l_feat_21,l_feat_22,l_feat_23,l_feat_24,l_feat_25,l_feat_26,l_feat_27,feat_e_1,feat_e_2,feat_e_3,feat_e_4,…,history_a_2,history_a_3,history_a_4,history_a_5,history_a_6,history_a_7,history_b_1,history_b_2,history_b_3,history_b_4,history_b_5,history_b_6,history_b_7,history_b_8,history_b_9,history_b_10,history_b_11,history_b_12,history_b_13,history_b_14,history_b_15,history_b_16,history_b_17,history_b_18,history_b_19,history_b_20,history_b_21,history_b_22,history_b_23,history_b_24,history_b_25,history_b_26,history_b_27,history_b_28,history_b_29,history_b_30,clicked
str,str,str,str,str,str,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i32
"""1.0""","""7.0""","""36""","""5""","""13""","""9,18,269,516,57,97,527,74,317,…",1.0,2.0,1.0,23.0,1.0,1.0,193.0,2.0,50.0,118.0,743.0,2877.0,2.0,1591.0,1058.0,2.0,50.0,1.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,5.0,2.0,65.0,-4230.666504,23.863636,-0.05,…,-0.055556,0.02439,-326.857147,-0.014493,-183.285721,-13.596154,0.115821,0.138626,0.047507,0.050622,0.026479,0.001558,0.024922,0.051401,0.004673,0.021028,0.072428,0.007009,0.028816,0.05841,0.000779,0.072428,0.016355,0.011682,0.010124,0.002336,0.008567,0.070092,0.070092,0.011682,0.004673,0.087226,0.049843,0.015576,0.040498,0.051401,0
"""1.0""","""7.0""","""2""","""5""","""08""","""9,144,269,57,516,97,527,74,315…",2.0,2.0,3.0,17.0,193.0,116.0,164.0,2.0,14.0,109.0,674.0,218.0,2.0,122.0,751.0,1.0,14.0,1.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,16.0,1.0,65.0,-1346.648193,4.545455,-0.05,…,0.0,0.0,-382.285706,0.0,-176.0,-11.442307,0.068794,0.072179,0.049471,0.052715,0.027574,0.001622,0.025952,0.053526,0.004866,0.021897,0.075423,0.007299,0.030007,0.060825,0.000811,0.075423,0.017031,0.012165,0.010543,0.002433,0.008921,0.07299,0.07299,0.012165,0.004866,0.045416,0.051904,0.01622,0.042172,0.026763,0
"""1.0""","""7.0""","""36""","""5""","""11""","""269,516,57,97,165,527,74,77,31…",1.0,2.0,1.0,7.0,675.0,85.0,227.0,2.0,362.0,212.0,1029.0,3916.0,1.0,2924.0,2304.0,2.0,362.0,3.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,16.0,1.0,64.638885,-3195.388916,22.727272,-0.05,…,-0.111111,0.097561,-409.0,-0.014493,-224.714279,-13.942307,0.112947,0.169634,0.038753,0.041295,0.0216,0.001271,0.02033,0.04193,0.003812,0.017153,0.059083,0.005718,0.023506,0.047647,0.000635,0.059083,0.026683,0.00953,0.008259,0.001906,0.006988,0.057177,0.057177,0.00953,0.003812,0.035577,0.081318,0.012706,0.033036,0.062898,0
"""1.0""","""8.0""","""37""","""5""","""11""","""269,57,516,21,214,269,561,214,…",2.0,2.0,2.0,7.0,294.0,442.0,130.0,2.0,163.0,179.0,102.0,789.0,2.0,2169.0,439.0,1.0,163.0,1.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,4.0,1.0,65.0,-4029.962891,3.863636,-0.05,…,-0.055556,0.02439,-274.428558,-0.014493,-127.85714,-9.846154,0.159843,0.198657,0.068082,0.072546,0.037947,0.002232,0.035715,0.073663,0.006697,0.030135,0.103797,0.010045,0.041296,0.083707,0.002232,0.103797,0.023438,0.016741,0.014509,0.003348,0.012277,0.100449,0.100449,0.016741,0.006697,0.062502,0.07143,0.022322,0.058037,0.073659,0
"""2.0""","""7.0""","""37""","""5""","""07""","""144,269,57,516,35,479,57,516,5…",2.0,2.0,3.0,24.0,497.0,435.0,171.0,2.0,193.0,131.0,690.0,110.0,1.0,2084.0,106.0,1.0,193.0,1.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,15.0,4.0,65.0,-2106.407471,8.522727,-0.05,…,0.0,0.0,-407.571442,0.0,-199.142853,-14.019231,0.056166,0.063795,0.043725,0.046592,0.024371,0.001434,0.022938,0.047309,0.004301,0.019354,0.066662,0.006451,0.026522,0.05376,0.000717,0.066662,0.015053,0.010752,0.009318,0.00215,0.007885,0.064512,0.064512,0.010752,0.004301,0.040141,0.045875,0.014336,0.037274,0.023654,0


In [7]:
test = pl.read_parquet('../data/test.parquet')
test.head()

ID,gender,age_group,inventory_id,day_of_week,hour,seq,l_feat_1,l_feat_2,l_feat_3,l_feat_4,l_feat_5,l_feat_6,l_feat_7,l_feat_8,l_feat_9,l_feat_10,l_feat_11,l_feat_12,l_feat_13,l_feat_14,l_feat_15,l_feat_16,l_feat_17,l_feat_18,l_feat_19,l_feat_20,l_feat_21,l_feat_22,l_feat_23,l_feat_24,l_feat_25,l_feat_26,l_feat_27,feat_e_1,feat_e_2,feat_e_3,…,history_a_1,history_a_2,history_a_3,history_a_4,history_a_5,history_a_6,history_a_7,history_b_1,history_b_2,history_b_3,history_b_4,history_b_5,history_b_6,history_b_7,history_b_8,history_b_9,history_b_10,history_b_11,history_b_12,history_b_13,history_b_14,history_b_15,history_b_16,history_b_17,history_b_18,history_b_19,history_b_20,history_b_21,history_b_22,history_b_23,history_b_24,history_b_25,history_b_26,history_b_27,history_b_28,history_b_29,history_b_30
str,str,str,str,str,str,str,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""TEST_0000000""","""2.0""","""6.0""","""46""","""7""","""13""","""321,269,57,516,479,516,57,479,…",2.0,2.0,2.0,19.0,1047.0,161.0,265.0,2.0,232.0,96.0,1513.0,4944.0,2.0,2065.0,1713.0,1.0,232.0,1.0,2.0,2.0,2.0,2.0,1.0,1.0,1674.0,16.0,1.0,65.0,-2554.314697,15.909091,…,0.060337,-0.166667,0.060976,-355.0,-0.043478,-180.428574,-13.326923,0.236312,0.281649,0.048257,0.051422,0.026897,0.001582,0.025315,0.052213,0.004747,0.02136,0.073572,0.014241,0.029271,0.059333,0.000791,0.073572,0.016613,0.011866,0.010284,0.00712,0.008702,0.071199,0.071199,0.011866,0.004747,0.044302,0.05063,0.015822,0.041137,0.104432
"""TEST_0000001""","""2.0""","""8.0""","""29""","""7""","""21""","""57,35,479,57,463,212,193,151,4…",2.0,2.0,2.0,7.0,1024.0,297.0,47.0,2.0,220.0,57.0,1422.0,4758.0,1.0,1082.0,675.0,1.0,220.0,1.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,15.0,1.0,65.0,-684.240723,6.818182,…,0.024875,0.0,0.0,-143.428574,0.0,-63.857143,-3.942308,0.19903,0.198657,0.136158,0.145086,0.075891,0.004464,0.071427,0.147319,0.013393,0.060267,0.207585,0.020089,0.082588,0.167407,0.002232,0.207585,0.046874,0.033482,0.029017,0.006696,0.024553,0.200889,0.200889,0.033482,0.013393,0.124998,0.142854,0.044642,0.116069,0.073659
"""TEST_0000002""","""1.0""","""6.0""","""37""","""7""","""19""","""57,516,97,74,527,77,318,315,31…",2.0,2.0,3.0,7.0,562.0,107.0,63.0,2.0,118.0,157.0,90.0,2100.0,2.0,1337.0,1329.0,1.0,118.0,2.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,11.0,1.0,64.638885,-2239.722168,4.318182,…,0.031765,0.0,0.0,-112.285713,0.0,-72.14286,-5.173077,0.151852,0.175891,0.120554,0.128459,0.067194,0.003953,0.063242,0.130436,0.011858,0.05336,0.183796,0.017787,0.073123,0.148223,0.001976,0.183796,0.041502,0.029645,0.025692,0.005929,0.021739,0.177867,0.177867,0.029645,0.011858,0.110673,0.126483,0.039526,0.102768,0.065218
"""TEST_0000003""","""2.0""","""7.0""","""41""","""7""","""09""","""144,321,57,479,57,479,35,57,51…",2.0,2.0,2.0,7.0,444.0,56.0,212.0,2.0,171.0,36.0,481.0,2377.0,1.0,102.0,1157.0,1.0,171.0,4.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,1.0,5.0,64.638885,-1097.240723,11.363636,…,0.02606,-0.055556,0.02439,-411.0,-0.014493,-237.428574,-17.76923,0.08865,0.107031,0.036679,0.039085,0.020444,0.001203,0.019242,0.039686,0.003608,0.016235,0.055921,0.005412,0.022248,0.045098,0.000601,0.055921,0.012627,0.009019,0.007817,0.001804,0.006614,0.108234,0.054117,0.009019,0.003608,0.033673,0.038483,0.012026,0.031268,0.039686
"""TEST_0000004""","""1.0""","""8.0""","""2""","""7""","""18""","""269,57,516,342,516,403,173,457…",2.0,2.0,3.0,8.0,709.0,738.0,22.0,2.0,338.0,100.0,1153.0,3886.0,2.0,2930.0,2381.0,1.0,338.0,3.0,2.0,2.0,2.0,2.0,1.0,1.0,1129.0,8.0,1.0,65.0,-3665.592529,2.272727,…,0.091343,-0.055556,0.146341,-508.142853,0.0,-240.428574,-16.73077,0.047072,0.105696,0.036222,0.038597,0.020189,0.001188,0.019002,0.039191,0.003563,0.016033,0.055223,0.010688,0.021971,0.044535,0.000594,0.055223,0.01247,0.008907,0.007719,0.001781,0.006532,0.053442,0.053442,0.008907,0.003563,0.033253,0.038003,0.011876,0.030878,0.039191


In [8]:
print(train.head())         # 상위 5개 행 출력
print(train.columns)        # 컬럼명 출력
print(train.shape)          # (행, 열) 개수


shape: (5, 119)
┌────────┬───────────┬────────────┬────────────┬───┬────────────┬────────────┬───────────┬─────────┐
│ gender ┆ age_group ┆ inventory_ ┆ day_of_wee ┆ … ┆ history_b_ ┆ history_b_ ┆ history_b ┆ clicked │
│ ---    ┆ ---       ┆ id         ┆ k          ┆   ┆ 28         ┆ 29         ┆ _30       ┆ ---     │
│ str    ┆ str       ┆ ---        ┆ ---        ┆   ┆ ---        ┆ ---        ┆ ---       ┆ i32     │
│        ┆           ┆ str        ┆ str        ┆   ┆ f32        ┆ f32        ┆ f32       ┆         │
╞════════╪═══════════╪════════════╪════════════╪═══╪════════════╪════════════╪═══════════╪═════════╡
│ 1.0    ┆ 7.0       ┆ 36         ┆ 5          ┆ … ┆ 0.015576   ┆ 0.040498   ┆ 0.051401  ┆ 0       │
│ 1.0    ┆ 7.0       ┆ 2          ┆ 5          ┆ … ┆ 0.01622    ┆ 0.042172   ┆ 0.026763  ┆ 0       │
│ 1.0    ┆ 7.0       ┆ 36         ┆ 5          ┆ … ┆ 0.012706   ┆ 0.033036   ┆ 0.062898  ┆ 0       │
│ 1.0    ┆ 8.0       ┆ 37         ┆ 5          ┆ … ┆ 0.022322   ┆ 0.058037 

In [9]:
sample_train = train.sample(n=10000,seed=42)
print(sample_train.describe())

shape: (9, 120)
┌────────────┬────────┬───────────┬────────────┬───┬────────────┬───────────┬───────────┬──────────┐
│ statistic  ┆ gender ┆ age_group ┆ inventory_ ┆ … ┆ history_b_ ┆ history_b ┆ history_b ┆ clicked  │
│ ---        ┆ ---    ┆ ---       ┆ id         ┆   ┆ 28         ┆ _29       ┆ _30       ┆ ---      │
│ str        ┆ str    ┆ str       ┆ ---        ┆   ┆ ---        ┆ ---       ┆ ---       ┆ f64      │
│            ┆        ┆           ┆ str        ┆   ┆ f64        ┆ f64       ┆ f64       ┆          │
╞════════════╪════════╪═══════════╪════════════╪═══╪════════════╪═══════════╪═══════════╪══════════╡
│ count      ┆ 9986   ┆ 9986      ┆ 10000      ┆ … ┆ 9986.0     ┆ 9986.0    ┆ 9986.0    ┆ 10000.0  │
│ null_count ┆ 14     ┆ 14        ┆ 0          ┆ … ┆ 14.0       ┆ 14.0      ┆ 14.0      ┆ 0.0      │
│ mean       ┆ null   ┆ null      ┆ null       ┆ … ┆ 0.085875   ┆ 0.220603  ┆ 0.194192  ┆ 0.0204   │
│ std        ┆ null   ┆ null      ┆ null       ┆ … ┆ 0.35605    ┆ 0.925508 

In [10]:
print(train.null_count())
print(train.n_unique)

shape: (1, 119)
┌────────┬───────────┬────────────┬────────────┬───┬────────────┬────────────┬───────────┬─────────┐
│ gender ┆ age_group ┆ inventory_ ┆ day_of_wee ┆ … ┆ history_b_ ┆ history_b_ ┆ history_b ┆ clicked │
│ ---    ┆ ---       ┆ id         ┆ k          ┆   ┆ 28         ┆ 29         ┆ _30       ┆ ---     │
│ u32    ┆ u32       ┆ ---        ┆ ---        ┆   ┆ ---        ┆ ---        ┆ ---       ┆ u32     │
│        ┆           ┆ u32        ┆ u32        ┆   ┆ u32        ┆ u32        ┆ u32       ┆         │
╞════════╪═══════════╪════════════╪════════════╪═══╪════════════╪════════════╪═══════════╪═════════╡
│ 17208  ┆ 17208     ┆ 0          ┆ 0          ┆ … ┆ 17208      ┆ 17208      ┆ 17208     ┆ 0       │
└────────┴───────────┴────────────┴────────────┴───┴────────────┴────────────┴───────────┴─────────┘
<bound method DataFrame.n_unique of shape: (10_704_179, 119)
┌────────┬───────────┬────────────┬────────────┬───┬────────────┬────────────┬───────────┬─────────┐
│ gender ┆ age

In [16]:
print(train['seq'])

shape: (10_704_179,)
Series: 'seq' [str]
[
	"9,18,269,516,57,97,527,74,317,…
	"9,144,269,57,516,97,527,74,315…
	"269,516,57,97,165,527,74,77,31…
	"269,57,516,21,214,269,561,214,…
	"144,269,57,516,35,479,57,516,5…
	…
	"9,144,269,57,516,417,227,27,22…
	"9,57,516,97,74,527,318,463,212…
	"9,516,57,195,27,516,173,457,40…
	"9,57,516,97,74,527,318,77,317,…
	"9,57,74,317,269,479,311,35,57,…
]


## Data Column Setting

## Define Custom Dataset

## Define Model Architecture

## Train / Validation

## Run!!

## Submission