In [1]:
import os
import glob
from joblib import Parallel, delayed
import pandas as pd
import numpy as np
import scipy as sc
from sklearn.model_selection import KFold
import warnings
warnings.filterwarnings('ignore')
pd.set_option('max_columns', 300)
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder, StandardScaler
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from torch.utils.data import DataLoader

from utils import create_save_folder, EarlyStopping

os.environ["CUDA_VISIBLE_DEVICES"]="0"

fm_path = create_save_folder()

In [2]:
class denoising_model(nn.Module):
    def __init__(self, num_columns):
        super(denoising_model,self).__init__()
        self.encoder=nn.Sequential(
            nn.Linear(num_columns,256),
            # nn.BatchNorm1d(256),
            nn.SiLU(True),
            nn.Linear(256,128),
            # nn.BatchNorm1d(128),
            nn.SiLU(True),
        )
        
        self.decoder=nn.Sequential(
            nn.Linear(128,256),
            # nn.BatchNorm1d(256),
            nn.SiLU(True),
            nn.Linear(256, num_columns),
            # nn.BatchNorm1d(num_columns),
            nn.SiLU(True),
        )

    def forward(self,x):
        x=self.encoder(x)
        x=self.decoder(x)
        
        return x
    
    def encode(self, x):
        return self.encoder(x)

In [3]:
class DataSet:
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # noise = torch.randn(self.data.shape[1]).cuda()
        # clean = self.data[index]
        # dirty = self.data[index] + noise
        # return clean, dirty
        return self.data[index]


In [4]:
with open('/home/yoshikawa/work/kaggle/OPVP/output/feature_model/20210824/0/train.pkl', 'rb') as f:
    df_train = pickle.load(f)

In [5]:
train = df_train.drop(['row_id', 'target'], axis=1)
for col in train.columns.to_list():
    train[col] = train[col].fillna(train[col].mean())

scales = train.drop(["stock_id"], axis = 1).columns.to_list()

scaler = StandardScaler()
scaler.fit(train[scales])
train[scales] = scaler.transform(train[scales])
le = LabelEncoder()
le.fit(train["stock_id"])
train["stock_id"] = le.transform(train["stock_id"])

In [6]:
train_data = torch.tensor(train.values.astype(np.float32)).cuda()
train_data.shape[1]

230

In [12]:
criterion = nn.MSELoss()
kf = KFold(n_splits=5, shuffle=True, random_state=55)
epochs = 10000

cv = 0
models = []
for fold, (train_idx, val_idx) in enumerate(kf.split(train_data)):
    print('fold: ', fold)
    print('='*100)
    train_dataset = DataSet(train_data[train_idx].cuda())
    val_dataset = DataSet(train_data[val_idx].cuda())
    train_loader = DataLoader(train_dataset, 4096, shuffle=True)
    val_loader = DataLoader(val_dataset, 4096)
    
    model = denoising_model(train_data.shape[1]).cuda()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    earlystopping = EarlyStopping(patience=10, verbose=True, path=fm_path+'/checkpoint.pth')

    for i in range(epochs):
        model.train()
        train_loss, val_loss = 0, 0

        for j, data in enumerate(train_loader):
            noise = torch.randn(data.shape).cuda()
            dirty = data + noise
            output = model(dirty)
            loss = criterion(output, data)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.shape[0]
        
        train_loss /= len(train_dataset)
        
        for j, data in enumerate(val_loader):
            noise = torch.randn(data.shape).cuda()
            dirty = data + noise
            output = model(dirty)
            loss = criterion(output, data)
            val_loss += loss.item() * data.shape[0]
        
        val_loss /= len(val_dataset)
        if (i+1) % 10 == 0:
            print(i+1, " epoch - train_loss: ", round(train_loss, 4), ", val_loss: ", round(val_loss, 4))
        earlystopping(val_loss, model)
        if earlystopping.early_stop:
            print("Early Stopping!!")
            break
    cv += val_loss
    model.load_state_dict(torch.load(fm_path+'/checkpoint.pth'))
    models.append(model)
cv /= 5
print("cv: ", round(cv, 4))

fold:  0
Validation loss decreased (inf --> 0.426466).  Saving model ...
Validation loss decreased (0.426466 --> 0.341160).  Saving model ...
Validation loss decreased (0.341160 --> 0.304301).  Saving model ...
Validation loss decreased (0.304301 --> 0.287294).  Saving model ...
Validation loss decreased (0.287294 --> 0.279119).  Saving model ...
Validation loss decreased (0.279119 --> 0.274693).  Saving model ...
Validation loss decreased (0.274693 --> 0.269382).  Saving model ...
Validation loss decreased (0.269382 --> 0.261899).  Saving model ...
Validation loss decreased (0.261899 --> 0.259224).  Saving model ...
10  epoch - train_loss:  0.2604 , val_loss:  0.2558
Validation loss decreased (0.259224 --> 0.255815).  Saving model ...
Validation loss decreased (0.255815 --> 0.254173).  Saving model ...
EarlyStopping counter: 1 out of 10
Validation loss decreased (0.254173 --> 0.249711).  Saving model ...
EarlyStopping counter: 1 out of 10
Validation loss decreased (0.249711 --> 0.2468

In [13]:
output = torch.zeros((train_data.shape[0], 128))
for i, model in enumerate(models):
    # train_dataset = DataSet(train_data)
    # train_loader = DataLoader(train_dataset, 4096, shuffle=False)
    # for j, data in enumerate(train_loader):
    output += model.encode(train_data).cpu() / 5
    torch.save(model.state_dict(), fm_path+'/model-'+str(i))


In [14]:
output

tensor([[1.4571, 1.2547, 1.9901,  ..., 2.6923, 1.7736, 0.6860],
        [1.3585, 0.8547, 1.8584,  ..., 2.7255, 1.8111, 0.5490],
        [1.3632, 1.1603, 2.0882,  ..., 2.9494, 1.4200, 1.2114],
        ...,
        [3.2910, 1.8508, 4.6874,  ..., 2.9970, 3.4341, 2.3833],
        [3.0671, 1.9353, 4.5361,  ..., 2.8554, 2.9525, 2.6529],
        [3.1853, 1.4961, 4.5273,  ..., 2.9214, 3.0446, 2.6488]],
       grad_fn=<AddBackward0>)

In [15]:
output.shape

torch.Size([428932, 128])

In [16]:
df_output = pd.DataFrame(output.detach().numpy())
df_output.columns = ['DAE_'+ str(i) for i in df_output.columns]
df_output.describe()

Unnamed: 0,DAE_0,DAE_1,DAE_2,DAE_3,DAE_4,DAE_5,DAE_6,DAE_7,DAE_8,DAE_9,DAE_10,DAE_11,DAE_12,DAE_13,DAE_14,DAE_15,DAE_16,DAE_17,DAE_18,DAE_19,DAE_20,DAE_21,DAE_22,DAE_23,DAE_24,DAE_25,DAE_26,DAE_27,DAE_28,DAE_29,DAE_30,DAE_31,DAE_32,DAE_33,DAE_34,DAE_35,DAE_36,DAE_37,DAE_38,DAE_39,DAE_40,DAE_41,DAE_42,DAE_43,DAE_44,DAE_45,DAE_46,DAE_47,DAE_48,DAE_49,DAE_50,DAE_51,DAE_52,DAE_53,DAE_54,DAE_55,DAE_56,DAE_57,DAE_58,DAE_59,DAE_60,DAE_61,DAE_62,DAE_63,DAE_64,DAE_65,DAE_66,DAE_67,DAE_68,DAE_69,DAE_70,DAE_71,DAE_72,DAE_73,DAE_74,DAE_75,DAE_76,DAE_77,DAE_78,DAE_79,DAE_80,DAE_81,DAE_82,DAE_83,DAE_84,DAE_85,DAE_86,DAE_87,DAE_88,DAE_89,DAE_90,DAE_91,DAE_92,DAE_93,DAE_94,DAE_95,DAE_96,DAE_97,DAE_98,DAE_99,DAE_100,DAE_101,DAE_102,DAE_103,DAE_104,DAE_105,DAE_106,DAE_107,DAE_108,DAE_109,DAE_110,DAE_111,DAE_112,DAE_113,DAE_114,DAE_115,DAE_116,DAE_117,DAE_118,DAE_119,DAE_120,DAE_121,DAE_122,DAE_123,DAE_124,DAE_125,DAE_126,DAE_127
count,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0,428932.0
mean,2.024674,1.408132,2.937426,2.922235,1.404105,2.793133,2.509424,0.836562,3.257326,3.626128,1.956815,1.648261,2.811153,3.20883,1.484894,2.376671,3.860445,1.67905,1.981682,1.66134,2.739988,1.137162,2.801373,2.568943,3.491256,4.054435,1.839545,2.412157,2.417036,0.990277,0.898654,1.914059,2.411544,2.159241,2.735686,2.314098,1.951012,1.348031,3.207363,2.653598,1.316775,3.112838,1.679902,3.475783,2.062476,0.473169,0.911093,2.108604,1.854478,2.301869,3.907174,2.412498,2.093194,2.144174,2.688028,2.562392,0.542321,2.804337,4.070835,2.631412,3.579354,0.839137,2.388687,3.143724,1.632429,2.864942,5.138174,3.223232,1.306428,3.053071,0.973733,1.451537,0.411263,2.231162,2.314613,2.73199,1.031546,2.230681,3.490454,1.545617,2.844125,2.217495,2.502552,0.874606,1.035058,4.537579,1.466348,2.93563,2.836405,1.596715,2.341712,1.80735,2.057317,2.665547,3.989386,3.086629,3.927931,1.440984,2.123539,1.637443,2.927938,2.510748,1.107941,3.701484,4.635833,0.525721,1.134774,2.043727,3.750197,3.535721,1.566113,1.331328,2.827906,1.259785,3.121304,1.35991,1.340247,3.293372,4.014331,2.090719,1.842517,1.283807,3.29571,2.128874,2.004336,2.287117,1.753305,1.867003
std,1.017644,0.951307,1.306224,1.491263,0.769466,1.364367,1.119864,0.78635,1.459839,1.64188,1.332849,0.882851,1.600111,1.297801,0.69305,0.802069,1.967745,0.899305,1.037621,0.839485,1.533602,1.220607,1.484159,1.36032,1.689755,1.93158,0.98372,1.196447,1.160527,0.896995,0.690273,0.717938,1.295802,1.37452,1.28001,1.200422,0.975409,0.463264,1.424147,0.825631,0.728766,1.511979,0.897245,1.939461,1.334288,0.383691,0.545298,1.108517,0.976878,1.215984,2.184227,0.924319,0.891804,1.138509,1.355214,1.2939,0.511135,1.157645,1.971765,1.186085,2.114194,0.642434,1.161383,1.706791,1.047712,1.429129,3.182129,1.50265,0.901202,1.649843,0.700494,0.570764,0.589004,0.948936,1.210597,1.108601,1.13503,0.820134,2.18601,0.730072,1.240039,1.463566,1.101996,0.394449,0.858471,1.737618,0.901441,1.402081,1.082996,0.724913,0.810419,0.697783,0.996474,0.818546,1.682098,1.141823,1.934143,1.159958,1.182307,0.881222,1.357506,1.25662,1.149014,1.779679,2.497866,0.577746,0.657375,0.918154,1.617673,1.759838,0.587936,0.739148,1.285038,0.988571,1.126258,0.78708,0.931314,1.330197,1.75618,1.174962,0.873349,1.243757,1.718334,0.978994,0.886139,0.729555,0.748291,0.698339
min,0.062018,-0.02786,0.812468,-0.017151,-0.026841,0.413867,-0.107802,-0.04365,0.407278,0.54384,0.428674,0.009789,0.00531,0.384764,0.260447,0.561378,0.375714,0.428371,0.042687,-0.126296,-0.026918,-0.087905,0.20829,0.238699,0.296108,0.555157,0.144452,0.216937,-0.074775,-0.19384,-0.18503,0.496002,0.011655,-0.202392,0.031893,0.503008,0.424149,-0.133708,0.626102,0.554607,-0.091645,0.257875,0.469841,0.014262,0.167054,-0.16934,-0.061915,0.316752,-0.034165,0.728749,0.165143,0.533683,0.268449,0.171979,0.017963,0.181411,-0.19809,0.564466,0.237165,0.425648,-0.063608,-0.127744,0.207885,0.154899,0.330341,0.412181,-0.07104,0.522588,0.139419,-0.105705,-0.165483,-0.135702,-0.238229,0.335615,0.121074,0.733622,-0.223446,0.040887,-0.110965,0.060668,0.162604,-0.02215,0.490635,-0.14819,-0.103424,1.658968,0.222788,0.346246,0.520826,0.29463,0.515861,0.408061,0.029962,1.165267,0.690166,1.097157,0.406276,-0.223543,0.114398,0.034191,-0.10388,-0.037165,-0.163817,0.524843,0.522679,-0.193335,0.046562,0.12973,0.644681,0.51251,0.019789,0.183774,0.383844,-0.138195,0.060404,0.166807,-0.034977,0.842711,0.526916,-0.187488,0.219643,-0.024299,0.127327,0.236253,0.017489,0.418045,0.155278,0.471288
25%,1.274684,0.785676,1.880312,1.679957,0.818609,1.780044,1.592092,0.425845,2.09415,2.173276,1.199851,0.996286,1.575969,2.263134,1.004348,1.771035,2.307271,1.047377,1.173922,1.087964,1.556829,0.332557,1.56957,1.570461,2.089021,2.467633,1.042034,1.388354,1.617658,0.300286,0.371273,1.38817,1.390193,1.114027,1.63956,1.523963,1.311072,1.08304,2.038432,2.012628,0.785273,1.942211,1.123655,1.88547,1.025364,0.178673,0.529758,1.286473,1.170368,1.611055,2.079033,1.692532,1.404396,1.185068,1.598881,1.625813,0.204318,1.878807,2.343534,1.77164,1.717301,0.369828,1.607864,1.654081,1.023143,1.806286,2.388594,1.924301,0.805316,1.677071,0.491612,1.074723,-0.027381,1.455519,1.415537,1.881561,0.212787,1.573933,1.508174,1.00039,1.88621,1.089769,1.688613,0.636378,0.478779,3.043893,0.938102,1.788649,2.082508,1.084822,1.712558,1.264002,1.286618,2.085054,2.641004,2.131509,2.363745,0.687904,1.334411,0.922157,1.899177,1.418818,0.3268,1.998526,2.311509,0.162656,0.76188,1.360384,2.209044,1.993971,1.214948,0.83567,1.837018,0.490853,2.3022,0.884167,0.751478,2.205729,2.557132,1.258192,1.128842,0.519243,1.866449,1.352503,1.392505,1.76894,1.197701,1.392847
50%,1.903999,1.213563,2.695305,2.857519,1.307921,2.602454,2.382053,0.670372,3.009487,3.387092,1.621315,1.562461,2.67205,3.129374,1.344065,2.274192,3.604648,1.538631,1.930358,1.532042,2.713564,0.840364,2.743024,2.386667,3.336596,4.0107,1.742618,2.29847,2.378002,0.876855,0.77497,1.765796,2.320869,2.010129,2.597027,2.073592,1.84484,1.268389,3.099233,2.536582,1.237081,2.995427,1.488312,3.34238,1.793241,0.366063,0.787764,1.947137,1.761132,2.045162,3.674017,2.281202,1.987781,2.05508,2.742975,2.437018,0.472236,2.701438,3.912577,2.448956,3.539217,0.703858,2.210787,3.017775,1.361162,2.649114,5.012832,3.074674,1.132413,2.928179,0.858776,1.347187,0.218637,2.169994,2.185149,2.533827,0.692259,2.032061,3.251299,1.53562,2.771674,2.049525,2.401886,0.798529,0.907009,4.226443,1.280626,2.825978,2.702868,1.48442,2.24947,1.751781,2.032765,2.512692,3.829549,2.932856,3.744698,1.293214,1.976575,1.541849,2.837799,2.649225,0.717668,3.611367,4.505921,0.361123,0.980042,2.021122,3.601254,3.388418,1.532064,1.18233,2.772798,1.066043,3.10476,1.203931,1.21443,3.117557,3.994386,1.973584,1.644768,0.856074,3.288159,2.051226,2.061236,2.236703,1.635326,1.822362
75%,2.549024,1.73755,3.884579,4.069175,1.938737,3.44742,3.468974,1.018479,4.554482,5.058271,2.203326,2.154535,3.995942,4.280066,1.80923,2.897684,5.609864,2.103645,2.668835,2.080106,3.823456,1.669661,3.917695,3.408428,4.82183,5.71994,2.491667,3.422202,3.212009,1.45416,1.361463,2.366352,3.420475,2.952448,3.853345,2.790218,2.371203,1.510947,4.200063,3.229191,1.829378,4.152659,1.966339,5.220947,2.836496,0.700508,1.175785,2.720881,2.41754,2.575374,5.706619,2.985958,2.711728,2.968689,3.683208,3.278931,0.752838,3.697036,5.807392,3.338212,5.447196,1.166043,3.021713,4.654884,1.9043,3.828641,8.047491,4.484698,1.536645,4.429292,1.284847,1.782634,0.629469,2.975945,3.028923,3.476605,1.599723,2.781033,5.467836,2.055202,3.762518,2.999243,3.145068,1.021765,1.380819,5.890095,1.743795,4.014479,3.402704,1.95223,2.918702,2.3038,2.676783,3.119419,5.389834,3.895959,5.52986,1.909896,2.741665,2.222348,4.007038,3.35872,1.541842,5.157278,6.714577,0.701417,1.279493,2.652293,4.97568,5.041458,1.869686,1.622834,3.781793,1.655608,3.906224,1.60147,1.690426,4.331448,5.293569,2.724885,2.525229,1.623036,4.600068,2.769329,2.696338,2.649653,2.229302,2.230606
max,18.02446,22.601524,27.918371,28.655079,23.750246,26.07711,16.938271,52.428352,17.541622,18.071888,37.471142,34.626419,24.344589,19.606474,19.50095,33.188717,37.896343,39.378742,55.689831,19.999136,18.624996,64.264793,31.533251,34.17728,36.544731,33.914032,39.658615,18.896595,35.787132,23.288439,27.969589,25.4825,26.554855,29.295904,16.079582,38.024303,49.020081,16.002151,31.777962,15.599897,16.129927,25.712074,24.844378,22.816452,28.650955,14.522635,19.833723,26.017067,26.406811,52.894966,20.034935,25.888702,21.011089,16.331638,28.040071,36.123993,31.370787,43.852356,37.088123,51.38628,24.444849,12.860329,27.972198,24.344355,44.405643,29.702198,19.409576,30.127403,43.5494,39.111294,25.76815,13.165955,12.171367,25.081585,36.03661,45.021458,30.732347,21.328964,14.233596,20.204519,25.459734,33.055,25.477468,16.207867,58.689289,33.643768,39.74609,23.192484,42.448872,18.263023,22.673632,26.464434,10.437089,24.937464,16.69725,30.774929,28.307819,33.689285,38.154816,41.86359,23.254513,54.6814,22.906368,17.541052,19.323534,26.69101,29.991415,20.578934,30.252144,28.147213,29.936543,17.901352,15.673507,18.643692,22.426868,32.037056,15.735491,16.876789,27.618805,41.005379,39.843224,27.570932,36.73328,14.316,27.130157,24.707687,13.929532,41.308521


In [19]:
df = pd.concat([df_train, df_output], axis=1)
pickle.dump(df, open(os.path.join(fm_path, "train.pkl"), 'wb'))