In [1]:
import numpy as np
import pandas as pd
import os
import random

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

from sklearn.preprocessing import StandardScaler

In [2]:
target_category = [38, 110, 113, 114, 134, 171, 172, 173, 376, 435, 467, 537, 539, 629, 768]
target_category_str = [str(col) for col in target_category]

DEVICE= 'cpu'
SEED = 42
EPOCHS = 10
BATCH_SIZE = 256
LEARNING_RATE = 1e-32
WEIGHT_DECAY = 1e-5
EARLY_STOPPING = 2
# NUM_TARGETS = len(target_category_str)

In [75]:
input_path = '../datasets/'
train_df = pd.read_csv('../output/train_df.csv')

train_Y = pd.read_csv('../output/train_Y.csv')

In [4]:
train_df.head()

Unnamed: 0,session_id,user_id_x,date,hour,register_number,time_elapsed,month,day,weekday,jp_holiday,...,537_given,539_given,629_given,768_given,child_items,alone_items,cook_items,user_id_y,given_buy_num,avg_qoupon
0,105,CN9sWHXp6RdCuyFkW5aemG,2019-02-14,9,1005,152.0,2,14,3,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,CN9sWHXp6RdCuyFkW5aemG,3.0,
1,106,Wi5hmLRCmUPXMRheu354dd,2019-02-14,9,1010,147.0,2,14,3,0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,Wi5hmLRCmUPXMRheu354dd,2.0,
2,107,kTFrFDLeaaggCoubWZJHpg,2019-02-14,9,1010,177.0,2,14,3,0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,kTFrFDLeaaggCoubWZJHpg,1.0,
3,108,exwdBc8tNJYAjhc4Gd6qtj,2019-02-14,9,1011,247.0,2,14,3,0,...,0.0,0.0,0.0,0.0,0.0,1.0,2.0,exwdBc8tNJYAjhc4Gd6qtj,4.0,1.0
4,109,XUeiScqGsozKQFxcd3RDsD,2019-02-14,9,1013,147.0,2,14,3,0,...,0.0,0.0,0.0,0.0,0.0,0.0,2.0,XUeiScqGsozKQFxcd3RDsD,2.0,1.0


In [5]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

In [76]:
feature_cols = ['age', 'gender', 'hour', 'weekday', 'jp_holiday', 'tgif', 'num_visit', 'month', 'day', 'register_number', 'max_time_avg', 'time_elapsed', 'hanakin', '38_avg', '110_avg', '113_avg', '114_avg', '134_avg', '171_avg', '172_avg', '173_avg', '376_avg', '435_avg', '467_avg', '537_avg', '539_avg', '629_avg', '768_avg', '38_given', '110_given', '113_given', '114_given', '134_given', '171_given', '172_given', '173_given', '376_given', '435_given', '467_given', '537_given', '539_given', '629_given', '768_given', '38_price_avg', '110_price_avg', '113_price_avg', '114_price_avg', '134_price_avg', '171_price_avg', '172_price_avg', '173_price_avg', '376_price_avg', '435_price_avg', '467_price_avg', '537_price_avg', '539_price_avg', '629_price_avg', '768_price_avg', 'child_items_sum', 'child_items_avg', 'child_items', '1_depart_avg', '2_depart_avg', '3_depart_avg', '4_depart_avg', '5_depart_avg', '7_depart_avg', '9_depart_avg', '10_depart_avg', '13_depart_avg', '14_depart_avg', '15_depart_avg', '16_depart_avg', '18_depart_avg', '19_depart_avg', '20_depart_avg', '21_depart_avg', '22_depart_avg', '23_depart_avg', '24_depart_avg', '25_depart_avg', '26_depart_avg', '27_depart_avg', '28_depart_avg', '29_depart_avg', '30_depart_avg', '32_depart_avg', '33_depart_avg', '34_depart_avg', '35_depart_avg', '36_depart_avg', '37_depart_avg', '38_depart_avg', '39_depart_avg', '40_depart_avg', '41_depart_avg', '46_depart_avg', '47_depart_avg', '49_depart_avg', '50_depart_avg', '58_depart_avg', '59_depart_avg', '60_depart_avg', '69_depart_avg', '70_depart_avg', '71_depart_avg', '72_depart_avg', '73_depart_avg', '74_depart_avg', '75_depart_avg', '77_depart_avg', '78_depart_avg', '79_depart_avg', '80_depart_avg', '81_depart_avg', '82_depart_avg', '83_depart_avg', '84_depart_avg', '87_depart_avg', '88_depart_avg', '89_depart_avg', '91_depart_avg', '92_depart_avg', '93_depart_avg', '94_depart_avg', '95_depart_avg', '96_depart_avg', '97_depart_avg', '98_depart_avg', '106_depart_avg', '107_depart_avg', '109_depart_avg', '117_depart_avg', '118_depart_avg', '121_depart_avg', '124_depart_avg', '131_depart_avg', '132_depart_avg', '133_depart_avg', '136_depart_avg', '137_depart_avg', '138_depart_avg', '141_depart_avg', '151_depart_avg', '152_depart_avg', '153_depart_avg', '154_depart_avg', '155_depart_avg', '156_depart_avg', '161_depart_avg', '162_depart_avg', '163_depart_avg', '165_depart_avg', '172_depart_avg', '173_depart_avg', '174_depart_avg', '178_depart_avg', '179_depart_avg', '182_depart_avg', '183_depart_avg', '185_depart_avg', '187_depart_avg', '194_depart_avg', '201_depart_avg', '202_depart_avg', '203_depart_avg', '206_depart_avg', '207_depart_avg', '210_depart_avg', '214_depart_avg', '215_depart_avg', '217_depart_avg', '219_depart_avg', '220_depart_avg', '221_depart_avg', '223_depart_avg', '224_depart_avg', '225_depart_avg', '226_depart_avg', '227_depart_avg', '228_depart_avg', '229_depart_avg', '230_depart_avg', '231_depart_avg', '232_depart_avg', '233_depart_avg', '234_depart_avg', 'cancel_items_sum', 'cancel_items_avg', 'buy_num_items_sum', 'buy_num_items_avg', 'given_buy_num', 'cancel10_items_sum', 'cancel10_items_avg', 'alone_items_sum', 'alone_items_avg', 'alone_items', 'cook_items_sum', 'cook_items_avg', 'cook_items', 'qoupon_avg', 'avg_qoupon', 'category_35_avg', 'category_37_avg', 'category_39_avg', 'category_40_avg', 'category_86_avg', 'category_111_avg', 'category_112_avg', 'category_135_avg', 'category_137_avg', 'category_141_avg', 'category_142_avg', 'category_143_avg', 'category_145_avg', 'category_148_avg', 'category_149_avg', 'category_150_avg', 'category_205_avg', 'category_206_avg', 'category_207_avg', 'category_208_avg', 'category_209_avg', 'category_210_avg', 'category_274_avg', 'category_275_avg', 'category_276_avg', 'category_289_avg', 'category_307_avg', 'category_310_avg', 'category_311_avg', 'category_312_avg', 'category_313_avg', 'category_316_avg', 'category_317_avg', 'category_319_avg', 'category_321_avg', 'category_328_avg', 'category_334_avg', 'category_340_avg', 'category_341_avg', 'category_342_avg', 'category_343_avg', 'category_344_avg', 'category_363_avg', 'category_365_avg', 'category_368_avg', 'category_370_avg', 'category_371_avg', 'category_372_avg', 'category_373_avg', 'category_374_avg', 'category_375_avg', 'category_376_avg', 'category_377_avg', 'category_378_avg', 'category_391_avg', 'category_392_avg', 'category_406_avg', 'category_407_avg', 'category_408_avg', 'category_410_avg', 'category_411_avg', 'category_414_avg', 'category_415_avg', 'category_416_avg', 'category_417_avg', 'category_420_avg', 'category_421_avg', 'category_422_avg', 'category_423_avg', 'category_424_avg', 'category_425_avg', 'category_426_avg', 'category_431_avg', 'category_432_avg', 'category_433_avg', 'category_436_avg', 'category_469_avg', 'category_470_avg', 'category_471_avg', 'category_472_avg', 'category_473_avg', 'category_474_avg', 'category_508_avg', 'category_509_avg', 'category_536_avg', 'category_538_avg', 'category_561_avg', 'category_562_avg', 'category_565_avg', 'category_566_avg', 'category_567_avg', 'category_568_avg', 'category_579_avg', 'category_587_avg', 'category_588_avg', 'category_589_avg', 'category_590_avg', 'category_591_avg', 'category_594_avg', 'category_602_avg', 'category_617_avg', 'category_619_avg', 'category_620_avg', 'category_621_avg', 'category_623_avg', 'category_628_avg', 'category_630_avg', 'category_631_avg', 'category_632_avg', 'category_633_avg', 'category_634_avg', 'category_636_avg', 'category_655_avg', 'category_665_avg', 'category_666_avg', 'category_669_avg', 'category_674_avg', 'category_679_avg', 'category_684_avg', 'category_708_avg', 'category_711_avg', 'category_716_avg', 'category_720_avg', 'category_724_avg', 'category_769_avg', 'category_770_avg', 'category_771_avg', 'similar_110_avg', 'similar_113_avg', 'similar_114_avg', 'similar_134_avg', 'similar_171_avg', 'similar_172_avg', 'similar_173_avg', 'similar_376_avg', 'similar_38_avg', 'similar_435_avg', 'similar_467_avg', 'similar_537_avg', 'similar_539_avg', 'similar_629_avg', 'similar_768_avg', 'category_35_given', 'category_37_given', 'category_39_given', 'category_40_given', 'category_86_given', 'category_111_given', 'category_112_given', 'category_135_given', 'category_136_given', 'category_137_given', 'category_141_given', 'category_142_given', 'category_143_given', 'category_145_given', 'category_148_given', 'category_149_given', 'category_150_given', 'category_205_given', 'category_206_given', 'category_207_given', 'category_208_given', 'category_209_given', 'category_210_given', 'category_274_given', 'category_275_given', 'category_276_given', 'category_289_given', 'category_294_given', 'category_295_given', 'category_299_given', 'category_307_given', 'category_310_given', 'category_311_given', 'category_312_given', 'category_313_given', 'category_314_given', 'category_315_given', 'category_316_given', 'category_317_given', 'category_319_given', 'category_321_given', 'category_328_given', 'category_330_given', 'category_331_given', 'category_334_given', 'category_340_given', 'category_341_given', 'category_342_given', 'category_343_given', 'category_344_given', 'category_346_given', 'category_363_given', 'category_365_given', 'category_366_given', 'category_367_given', 'category_368_given', 'category_370_given', 'category_371_given', 'category_372_given', 'category_373_given', 'category_374_given', 'category_375_given', 'category_376_given', 'category_377_given', 'category_378_given', 'category_391_given', 'category_392_given', 'category_393_given', 'category_406_given', 'category_407_given', 'category_408_given', 'category_410_given', 'category_411_given', 'category_414_given', 'category_415_given', 'category_416_given', 'category_417_given', 'category_418_given', 'category_420_given', 'category_421_given', 'category_422_given', 'category_423_given', 'category_424_given', 'category_425_given', 'category_426_given', 'category_430_given', 'category_431_given', 'category_432_given', 'category_433_given', 'category_434_given', 'category_436_given', 'category_468_given', 'category_469_given', 'category_470_given', 'category_471_given', 'category_472_given', 'category_473_given', 'category_474_given', 'category_508_given', 'category_509_given', 'category_510_given', 'category_536_given', 'category_538_given', 'category_546_given', 'category_547_given', 'category_548_given', 'category_561_given', 'category_562_given', 'category_565_given', 'category_566_given', 'category_567_given', 'category_568_given', 'category_569_given', 'category_579_given', 'category_587_given', 'category_588_given', 'category_589_given', 'category_590_given', 'category_591_given', 'category_594_given', 'category_602_given', 'category_613_given', 'category_615_given', 'category_616_given', 'category_617_given', 'category_619_given', 'category_620_given', 'category_623_given', 'category_628_given', 'category_630_given', 'category_631_given', 'category_632_given', 'category_633_given', 'category_634_given', 'category_635_given', 'category_636_given', 'category_655_given', 'category_662_given', 'category_665_given', 'category_666_given', 'category_667_given', 'category_669_given', 'category_674_given', 'category_679_given', 'category_684_given', 'category_696_given', 'category_708_given', 'category_711_given', 'category_716_given', 'category_720_given', 'category_724_given', 'category_726_given', 'category_756_given', 'category_769_given', 'category_770_given', 'category_771_given', 'similar_110_given', 'similar_113_given', 'similar_114_given', 'similar_134_given', 'similar_171_given', 'similar_172_given', 'similar_173_given', 'similar_376_given', 'similar_38_given', 'similar_435_given', 'similar_467_given', 'similar_537_given', 'similar_539_given', 'similar_629_given', 'similar_768_given']

In [78]:
# feature_cols = ['age', 'gender', 'hour', 'weekday', 'jp_holiday', 'tgif', 'num_visit', 'day', 'register_number', 'max_time_avg']
# categorical_cols = ['gender', 'weekday', 'jp_holiday', 'register_number']
categorical_cols =  ['gender', 'weekday', 'jp_holiday', 'tgif', 'month', 'day', 'register_number', 'hanakin']
numerical_cols = list(set(feature_cols) - set(categorical_cols))
# print(numerical_cols)
# train_df__ = pd.read_csv('../output/train_df.csv')
train_df = train_df[['session_id'] + feature_cols]
test_X['session_id'] = None
test_X = test_X[['session_id'] + feature_cols]

train_df.fillna(-1, inplace=True) # ラグとかでないのばかりだから、-1でいいはず...
test_X.fillna(-1, inplace=True) # ラグとかでないのばかりだから、-1でいいはず...
train_df.head()

Unnamed: 0,session_id,age,gender,hour,weekday,jp_holiday,tgif,num_visit,month,day,...,similar_172_given,similar_173_given,similar_376_given,similar_38_given,similar_435_given,similar_467_given,similar_537_given,similar_539_given,similar_629_given,similar_768_given
0,105,40.0,0.0,9,3,0,0,0,2,14,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,106,60.0,1.0,9,3,0,0,0,2,14,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,107,30.0,0.0,9,3,0,0,0,2,14,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,108,60.0,0.0,9,3,0,0,0,2,14,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,109,70.0,0.0,9,3,0,0,0,2,14,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


# 標準化

In [82]:
train_df['flg'] = 0
test_X['flg'] = 1
feature_df = pd.concat([train_df, test_X])
feature_df.head()

Unnamed: 0,session_id,age,gender,hour,weekday,jp_holiday,tgif,num_visit,month,day,...,similar_173_given,similar_376_given,similar_38_given,similar_435_given,similar_467_given,similar_537_given,similar_539_given,similar_629_given,similar_768_given,flg
0,105,40.0,0.0,9,3,0,0,0,2,14,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
1,106,60.0,1.0,9,3,0,0,0,2,14,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
2,107,30.0,0.0,9,3,0,0,0,2,14,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
3,108,60.0,0.0,9,3,0,0,0,2,14,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
4,109,70.0,0.0,9,3,0,0,0,2,14,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0


In [83]:
scaler = StandardScaler() # 後でtestと一緒にして正規化する
feature_df[numerical_cols] = scaler.fit_transform(feature_df[numerical_cols])
feature_df.head()


Unnamed: 0,session_id,age,gender,hour,weekday,jp_holiday,tgif,num_visit,month,day,...,similar_173_given,similar_376_given,similar_38_given,similar_435_given,similar_467_given,similar_537_given,similar_539_given,similar_629_given,similar_768_given,flg
0,105,-0.360416,0.0,-1.7209,3,0,0,-0.834219,2,14,...,-0.202179,-0.313051,-0.181585,-0.197462,-0.159828,-0.263252,-0.217918,-0.118546,-0.217622,0
1,106,1.00624,1.0,-1.7209,3,0,0,-0.834219,2,14,...,-0.202179,-0.313051,-0.181585,-0.197462,-0.159828,-0.263252,-0.217918,-0.118546,-0.217622,0
2,107,-1.043744,0.0,-1.7209,3,0,0,-0.834219,2,14,...,-0.202179,-0.313051,-0.181585,-0.197462,-0.159828,-0.263252,-0.217918,-0.118546,-0.217622,0
3,108,1.00624,0.0,-1.7209,3,0,0,-0.834219,2,14,...,-0.202179,1.804079,-0.181585,-0.197462,-0.159828,-0.263252,-0.217918,-0.118546,-0.217622,0
4,109,1.689568,0.0,-1.7209,3,0,0,-0.834219,2,14,...,-0.202179,-0.313051,-0.181585,-0.197462,-0.159828,-0.263252,-0.217918,-0.118546,-0.217622,0


In [84]:
train_df = feature_df[feature_df['flg']==0]
test_X = feature_df[feature_df['flg']==1]

# Model

In [85]:
class CustomLinear(nn.Module):
    def __init__(self, in_features,
                 out_features,
                 bias=True, p=0.5):
        super().__init__()
        self.linear = nn.Linear(in_features,
                               out_features,
                               bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(p)
        
    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.drop(x)
        return x

    
def make_model(input_dim, output_dim=1):
    return nn.Sequential(
                    CustomLinear(input_dim, 1024),
                    nn.Linear(1024, output_dim)
                )


# Train

In [88]:
def train_fn(target_col):
    print('*'*100)
    print('train:', target_col)

    true_ids = list(train_Y[train_Y[target_col]==1].session_id.unique())
    false_ids = list(train_Y[train_Y[target_col]!=1].session_id.unique())
    print(target_col, len(true_ids), len(false_ids))

    false_ids = random.sample(false_ids, len(true_ids) * 2) # down samplling
    sampling_ids = true_ids + false_ids
    sampling_ids = sorted(sampling_ids)

    train_idx = sampling_ids[: int(len(sampling_ids)*0.8)]
    val_idx = sampling_ids[int(len(sampling_ids)*0.8): ]

    assert set(train_idx) & set(val_idx) == set()
    train_X = torch.tensor(train_df[train_df.session_id.isin(train_idx)][feature_cols].values, dtype=torch.float32).to(DEVICE)
    train_y = torch.tensor(train_Y[train_Y.session_id.isin(train_idx)][target_col].values, dtype=torch.float32).unsqueeze(-1).to(DEVICE)

    valid_X = torch.tensor(train_df[train_df.session_id.isin(val_idx)][feature_cols].values, dtype=torch.float32).to(DEVICE)
    valid_y = torch.tensor(train_Y[train_Y.session_id.isin(val_idx)][target_col].values, dtype=torch.float32).unsqueeze(-1).to(DEVICE)
    print('train & valid: ', len(train_X), len(valid_X))

    model = make_model(len(feature_cols), 1)
    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="sum")
#     loss_fn = nn.CrossEntropyLoss() # 微妙？という話も聞く
    optimizer = torch.optim.Adam(model.parameters())

    # Dataset
    train = torch.utils.data.TensorDataset(train_X, train_y)
    valid = torch.utils.data.TensorDataset(valid_X, valid_y)

    # DataLoader
    train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False)

    best_score = 0
    stopping_cnt = 0

    for epoch in range(EPOCHS):
        model.train()
        avg_loss = 0.

        # train
        for x_batch, y_batch in tqdm(train_loader, disable=True):
            y_pred = model(x_batch)
            loss = loss_fn(y_pred, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() / len(train_loader)

        # eval
        model.eval()
        valid_preds = np.zeros((valid_X.size(0)))
        avg_val_loss = 0.
        len_pred = 0
        for i, (x_batch, y_batch) in enumerate(valid_loader):
            y_pred = model(x_batch).detach()
            avg_val_loss += loss_fn(y_pred, y_batch).item() / len(valid_loader)
            if len(y_batch) == BATCH_SIZE:
                valid_preds[i * BATCH_SIZE:(i+1) * BATCH_SIZE] = sigmoid(y_pred.cpu().numpy())[:, 0]
            else:
                valid_preds[i * BATCH_SIZE: i * BATCH_SIZE + len(y_batch)] = sigmoid(y_pred.cpu().numpy())[:, 0]
        print(valid_preds)
        valid_preds = np.nan_to_num(valid_preds)
        score = roc_auc_score(valid_y.numpy()[:, 0], valid_preds)
        print(f'Epoch {epoch + 1}/{EPOCHS} \t loss={avg_loss} \t val_loss={avg_val_loss}, auc: {score}')

        if best_score < score:
            best_score = score
            torch.save(model.state_dict(), f'../nn_models/best_model_{target_col}.pth')
            stopping_cnt = 0
        else:
            stopping_cnt += 1
            if stopping_cnt > EARLY_STOPPING:
                print(f'EARLY STOPPING!!, best AUC: {best_score}')
                break
    return best_score


In [89]:
score = train_fn('38')
print('score')

****************************************************************************************************
train: 38
38 30544 374281
train & valid:  73305 18327
[0.19847552 0.21101527 0.10094588 ... 0.24322996 0.23277496 0.14868262]
Epoch 1/10 	 loss=316.2254036295289 	 val_loss=151.18986966874868, auc: 0.7534535645832345
[0.2837815  0.30582806 0.13910113 ... 0.40783998 0.27837837 0.23962276]
Epoch 2/10 	 loss=139.5494948862322 	 val_loss=142.9147369596693, auc: 0.7618861218184177
[0.15998596 0.29231882 0.07395015 ... 0.322065   0.27953574 0.11274681]
Epoch 3/10 	 loss=137.74715610995932 	 val_loss=151.81943511962885, auc: 0.7613037554112754
[0.26424587 0.37347332 0.13733813 ... 0.48576355 0.25245965 0.19076015]
Epoch 4/10 	 loss=134.04815934344043 	 val_loss=142.35090160369876, auc: 0.7635425174374494
[0.2440629  0.33673739 0.13573001 ... 0.48758584 0.29993039 0.18254505]
Epoch 5/10 	 loss=133.4074196233981 	 val_loss=143.14845816294354, auc: 0.7659998910077125
[0.19298522 0.32461435 0.0880

In [None]:
scores = {}
for target in target_category_str:
    score = train_fn(target)
    scores.append(score)

print(sum(scores) / len(scores))


# Inference

In [94]:
# test_loaderは全モデルで同じ、違うのはモデルだけ
test_X = pd.read_csv('../output/test_df.csv')
test_X = torch.tensor(test_X[feature_cols].values, dtype=torch.float32).to(DEVICE)
test = torch.utils.data.TensorDataset(test_X)
test_loader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)


def inference_fn(target_col):
    model = nn.Sequential(CustomLinear(len(feature_cols), 1024), nn.Linear(1024, 1))
    model.load_state_dict(torch.load(f'../nn_models/best_model_{target_col}.pth'))
    model.eval()

    test_preds = np.zeros(len(test_X))

    for i, (x_batch,) in enumerate(test_loader):
        y_pred = model(x_batch).detach()
        if len(x_batch) == BATCH_SIZE:
            test_preds[i * BATCH_SIZE:(i+1) * BATCH_SIZE] = sigmoid(y_pred.cpu().numpy())[:, 0]
        else:
            test_preds[i * BATCH_SIZE: i * BATCH_SIZE + len(x_batch)] = sigmoid(y_pred.cpu().numpy())[:, 0]
    return test_preds

In [95]:
test_df = pd.read_csv(os.path.join(input_path, 'test.csv'))

for target in target_category_str:
    test_df[col] = inference_fn('38')

In [None]:
test_df.drop('session_id', axis=1, inplace=True)

import datetime

now = datetime.datetime.now()
now = now.strftime("%m%d_%H%M")
# test_df.round(5).to_csv(f'../output/submission_{now}_cat.csv', index=None)
test_df.round(5).to_csv(f'../pred/nn.csv', index=None)