In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.model_selection import StratifiedKFold
from tqdm.auto import tqdm, trange
from sklearn.preprocessing import LabelEncoder
import pickle
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset, Subset
import torch_optimizer as optim
import os
import gc

In [2]:
experiments_path = './experiments/amex_transformer_pretrain_v4/'
# exist_ok=True makes this cell idempotent and avoids the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(experiments_path, exist_ok=True)

In [3]:
# Load the pre-built arrays, in this order: statement sequences (X), labels (y),
# customer ids, and the engineered per-customer features (X_fe).
# Their shapes are displayed in the next cell.
X, y, customer_list, X_fe = (
    np.load(path)
    for path in (
        "./data/X2.npy",
        "./data/y.npy",
        "./data/customer_list.npy",
        "./data/X_fe.npy",
    )
)

In [5]:
tuple(arr.shape for arr in (X, y, X_fe))

((1383534, 13, 188), (1383534,), (1383534, 1095))

In [6]:
# Column names of the raw table, dropping the first two columns after the index.
# NOTE(review): assumes this matches the column layout of X; confirm against the X2 builder.
features = pd.read_csv("./data/train_data.csv",index_col=0,nrows=1).columns[2:]
cat_features = ['B_30', 'B_38', 'D_114', 'D_116', 'D_117', 'D_120', 'D_126', 'D_63', 'D_64', 'D_66', 'D_68']
dense_features = [col for col in features if col not in cat_features]
# Categorical columns come first, so they occupy indices 0 .. len(cat_features) - 1.
features = cat_features + dense_features

# Map group letter (first character of the column name) -> {'cat' | 'dense': [column indices]}.
# Groups and kinds appear in first-seen order.
features_group = {}
for idx, name in enumerate(features):
    kind = 'cat' if name in cat_features else 'dense'
    features_group.setdefault(name[0], {}).setdefault(kind, []).append(idx)

In [7]:
# One line per (group, kind): how many columns fall in it.
for group, kinds in features_group.items():
    for kind, indices in kinds.items():
        print(group, kind, len(indices))

B cat 2
B dense 38
D cat 9
D dense 87
P dense 3
R dense 28
S dense 21


In [8]:
# Largest code + 1 for each of the 11 leading categorical columns, ignoring NaNs
# (presumably codes are 0-based). The result matches the n_cat list hard-coded in AmexModel.
np.nanmax(X[..., :11].reshape(-1, 11), axis=0) + 1

array([4., 8., 3., 3., 8., 3., 4., 6., 5., 3., 8.], dtype=float32)

In [9]:
BATCH_SIZE = 1024
EPOCHS = 120
# NOTE(review): newly added seed; the run recorded below was unseeded.
SEED = 42
# Seeds torch's CPU and CUDA generators: weight init, dropout and DataLoader shuffling.
# Bitwise-identical GPU runs may additionally need deterministic kernels.
torch.manual_seed(SEED)
# Hard-coded to the second GPU; change here to train on a different device.
device = torch.device('cuda:1')
# Autoencoder pre-training pairs: input sequences X -> engineered features X_fe.
# as_tensor reuses the numpy buffer when it is already float32 and converts otherwise.
dataset = TensorDataset(torch.as_tensor(X, dtype=torch.float32),
                        torch.as_tensor(X_fe, dtype=torch.float32))

In [10]:
# Model / optimizer hyper-parameters consumed by AmexModel and the training cell.
param = dict(
    d_model=768,
    emb_dim=4,                  # embedding width for each categorical column
    n_layers=6,
    n_heads=4,                  # log2 exponent: AmexModel uses 2 ** n_heads attention heads
    activation='relu',
    transformer_act='relu',
    use_cls=False,
    input_norm=False,
    input_layers=0,
    tanh_scale=0.3450840441113073,
    input_dropout=0.10803114983077852,
    hidden_dropout=0.23771720623629086,
    final_dropout=0.1803505437404746,
    transformer_dropout=0.24608225028381883,
    output_layers='mlp2',
    pe_std=0.7232348735328199,
    optimizer='Lamb',
    lr=0.01,
    weight_decay=0.051712665649902206,
    optimizer_alpha=0.07931217763278503,   # betas[0] = 1 - optimizer_alpha
    optimizer_beta=0.0070573347096303885,  # betas[1] = 1 - optimizer_beta
)

In [11]:
class TanhEstimator(nn.Module):
    """Learnable element-wise squashing ``tanh(alpha * inp + beta)``.

    ``alpha`` and ``beta`` hold one value per entry of the last dimension of
    ``inp``; ``alpha`` starts at ``tanh_scale`` and ``beta`` at zero.
    """

    def __init__(self, inp_size, tanh_scale=0.1):
        super().__init__()
        self.alpha = nn.Parameter(tanh_scale * torch.ones(inp_size))
        self.beta = nn.Parameter(torch.zeros(inp_size))

    def forward(self, inp):
        return torch.tanh(self.alpha * inp + self.beta)


class AmexModel(nn.Module):
    """Transformer encoder over statement sequences, trained to reconstruct engineered features.

    Each time step is split into feature groups (group key -> {'cat': column
    indices, 'dense': column indices}). For each group, the categorical columns
    are embedded, concatenated with the dense columns, and projected to
    ``d_model``. The group outputs are then concatenated, batch-normalised,
    projected back to ``d_model``, and given a learned positional embedding.
    The result runs through a pre-norm ``nn.TransformerEncoder``. ``forward``
    returns ``ae_fc`` applied to the last sequence position.

    Args:
        params: hyper-parameter dict. ``n_heads`` is a log2 exponent: the encoder
            uses ``2 ** params['n_heads']`` heads.
        groups: optional feature-group mapping. When omitted, the notebook-level
            ``features_group`` is used, so ``AmexModel(params)`` keeps working.

    Raises:
        ValueError: if ``params['activation']`` or ``params['output_layers']`` is unknown.
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'gelu': nn.GELU, 'mish': nn.Mish}

    def __init__(self, params, groups=None):
        super().__init__()
        d_model = params['d_model']
        emb_dim = params['emb_dim']
        n_layers = params['n_layers']
        n_heads = 2 ** params['n_heads']

        if params['activation'] not in self._ACTIVATIONS:
            raise ValueError(f"unknown activation: {params['activation']!r}")
        # A single stateless instance, shared by the input projection and the fc head.
        activation = self._ACTIVATIONS[params['activation']]()

        self.use_cls = params['use_cls']

        # Cardinality of each categorical column, indexed by column position, and the
        # number of remaining dense columns; both must match the data layout.
        self.n_cat = [4, 8, 3, 3, 8, 3, 4, 6, 5, 3, 8]
        self.n_dense = 177
        self.features_group = features_group if groups is None else groups
        self.inp_emb = nn.ModuleDict()
        for key1, kinds in self.features_group.items():
            self.inp_emb[key1] = nn.ModuleDict()
            for key2, idx in kinds.items():
                if key2 == 'cat':
                    self.inp_emb[key1][key2] = nn.ModuleList(
                        nn.Embedding(self.n_cat[i], emb_dim) for i in idx)
                else:
                    # Input: the group's dense columns plus one emb_dim-wide embedding
                    # per categorical column of the same group.
                    d = len(idx) + len(kinds.get('cat', [])) * emb_dim
                    self.inp_emb[key1][key2] = nn.Sequential(nn.Linear(d, d_model),
                                                             nn.Dropout(0.1),
                                                             nn.Mish(),
                                                             nn.Linear(d_model, d_model))

        # norm and dense_norm are not used by forward() (their calls were commented
        # out); they stay registered so the state_dict keys do not change.
        self.norm = nn.BatchNorm1d(self.n_dense) if params['input_norm'] else nn.Identity()
        self.dense_norm = TanhEstimator(self.n_dense, params['tanh_scale'])
        self.post_norm = nn.BatchNorm1d(len(self.features_group) * d_model)
        proj_layers = [nn.Dropout(params['input_dropout']),
                       nn.Linear(len(self.features_group) * d_model, d_model)]
        for _ in range(params['input_layers']):
            proj_layers += [nn.Dropout(params['hidden_dropout']), activation, nn.Linear(d_model, d_model)]
        self.proj = nn.Sequential(*proj_layers)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model,
                                       params['transformer_dropout'],
                                       activation=params['transformer_act'],
                                       norm_first=True, batch_first=True),
            n_layers)

        # Scalar output head; not used by forward(), which returns the ae_fc output.
        output_layers = params['output_layers']
        if output_layers == 'linear':
            self.fc = nn.Sequential(nn.Dropout(params['final_dropout']), nn.Linear(d_model, 1))
        elif output_layers == 'mlp':
            self.fc = nn.Sequential(nn.Dropout(params['hidden_dropout']), nn.Linear(d_model, d_model), activation,
                                    nn.Dropout(params['final_dropout']), nn.Linear(d_model, 1))
        elif output_layers == 'mlp2':
            self.fc = nn.Sequential(nn.Dropout(params['hidden_dropout']), nn.Linear(d_model, 4 * d_model), activation,
                                    nn.Dropout(params['final_dropout']), nn.Linear(4 * d_model, 1))
        else:
            raise ValueError(f"unknown output_layers: {output_layers!r}")

        # Reconstruction head: 1095 outputs, the width of X_fe.
        self.ae_fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(d_model, 1095))

        # Learned positional embedding for a fixed sequence length of 13.
        self.pe = nn.Parameter(torch.empty([13, d_model]))
        nn.init.normal_(self.pe, std=params['pe_std'])

        # Learned CLS token, appended after the last time step when use_cls is set.
        self.cls = nn.Parameter(torch.empty(d_model))
        nn.init.normal_(self.cls)

    def forward(self, inp):
        """Encode a batch of sequences and return the ae_fc reconstruction.

        Args:
            inp: float tensor of shape (batch, 13, n_columns). NaN marks missing
                values. The sequence length must be 13 to broadcast with ``self.pe``.
                The columns listed under 'cat' hold integer codes. The caller's
                tensor is not modified.

        Returns:
            Tensor of shape (batch, 1095).
        """
        nan_mask = torch.isnan(inp)
        # A time step whose values are all NaN is treated as padding by the encoder.
        missing_nodes_mask = nan_mask.all(dim=-1)
        # Zero-fill NaNs out of place, so the caller's batch is left untouched.
        inp = inp.masked_fill(nan_mask, 0)

        group_outputs = []
        for key, idx in self.features_group.items():
            x = inp[..., idx['dense']]
            if 'cat' in idx:
                embs = [emb(inp[..., col].long())
                        for emb, col in zip(self.inp_emb[key]['cat'], idx['cat'])]
                x = torch.cat([x, *embs], dim=-1)
            group_outputs.append(self.inp_emb[key]['dense'](x))

        X = torch.cat(group_outputs, dim=-1)
        # BatchNorm1d normalises dim 1, so move the feature axis there and back.
        X = self.post_norm(X.transpose(1, 2)).transpose(1, 2)
        X = self.proj(X) + self.pe

        if self.use_cls:
            X = torch.cat([X, self.cls.expand(X.size(0), 1, -1)], dim=1)
            # The CLS position is never padding.
            mask = torch.cat([missing_nodes_mask, missing_nodes_mask.new_zeros(X.size(0), 1)], dim=1)
        else:
            mask = missing_nodes_mask
        X = self.transformer(X, src_key_padding_mask=mask)

        # Pool by taking the last position: the CLS token when use_cls is set,
        # otherwise the last time step.
        # NOTE(review): without CLS, if the last step is itself padding the pooled
        # vector comes from a padded position; confirm the sequences are front-padded.
        return self.ae_fc(X[:, -1])

In [12]:
criterion = nn.HuberLoss(delta=3.0)

def train_one_epoch(model, optimizer, scheduler, train_dataloader, device = torch.device('cpu')):
    """Train ``model`` for one pass over ``train_dataloader``.

    NaN entries in the target are excluded from the loss. The optimizer and the
    scheduler are both stepped once per batch. Batches whose targets are all NaN
    are skipped: their mean loss would be NaN and its gradient would corrupt the weights.

    Returns:
        float: mean Huber loss over all non-NaN target elements seen in the epoch
        (each batch's mean weighted by its count of valid targets), or NaN if
        there were none.
    """
    model.train()
    loss_sum = 0.0
    n_valid = 0
    for X, y in tqdm(train_dataloader, leave=False):
        X = X.to(device)
        y = y.to(device)
        mask = ~torch.isnan(y)
        batch_valid = int(mask.sum())
        if batch_valid == 0:
            continue
        optimizer.zero_grad()

        pred = model(X)
        loss = criterion(pred[mask], y[mask])
        loss.backward()
        optimizer.step()
        scheduler.step()
        # Accumulate as Python numbers so no device tensors outlive the loop.
        loss_sum += loss.item() * batch_valid
        n_valid += batch_valid
    if n_valid == 0:
        return float('nan')
    return loss_sum / n_valid

In [13]:
model_path = os.path.join(experiments_path, "model.pt")
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=True, drop_last=True)
model = AmexModel(param).to(device)

# All supported optimizers are built with the same arguments; param stores
# 1 - beta for each of the two betas.
optimizer_classes = {'Lamb': optim.Lamb, 'AdamW': torch.optim.AdamW, 'Ranger': optim.Ranger}
if param['optimizer'] not in optimizer_classes:
    raise ValueError(f"unknown optimizer: {param['optimizer']!r}")
optimizer = optimizer_classes[param['optimizer']](
    model.parameters(),
    lr=param['lr'],
    weight_decay=param['weight_decay'],
    betas=(1 - param['optimizer_alpha'], 1 - param['optimizer_beta']),
)
# T_0 covers every batch of every epoch, and train_one_epoch steps the scheduler
# once per batch, so this is a single cosine decay over the whole run with no restart.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, EPOCHS * len(train_loader))

best_score = -1.0  # NOTE(review): not read by any cell shown here
for epoch in trange(EPOCHS):
    train_loss = train_one_epoch(model, optimizer, scheduler, train_loader, device)
    print(f"epoch {epoch}")
    print(f"train_loss {train_loss}")

  0%|          | 0/120 [00:00<?, ?it/s]

  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 0
train_loss 0.1177452951669693


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 1
train_loss 0.07917426526546478


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 2
train_loss 0.07464205473661423


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 3
train_loss 0.07237251847982407


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 4
train_loss 0.07083830237388611


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 5
train_loss 0.06975482404232025


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 6
train_loss 0.0689399465918541


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 7
train_loss 0.06820686161518097


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 8
train_loss 0.06763049960136414


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 9
train_loss 0.06712858378887177


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 10
train_loss 0.06665902584791183


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 11
train_loss 0.06617778539657593


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 12
train_loss 0.06579289585351944


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 13
train_loss 0.06540705263614655


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 14
train_loss 0.06506569683551788


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 15
train_loss 0.06479049474000931


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 16
train_loss 0.06451376527547836


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 17
train_loss 0.0642123892903328


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 18
train_loss 0.06396754086017609


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 19
train_loss 0.06371521949768066


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 20
train_loss 0.06350404769182205


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 21
train_loss 0.06326939165592194


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 22
train_loss 0.06304941326379776


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 23
train_loss 0.06285714358091354


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 24
train_loss 0.06265580654144287


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 25
train_loss 0.06250018626451492


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 26
train_loss 0.062264930456876755


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 27
train_loss 0.06205016002058983


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 28
train_loss 0.06186039745807648


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 29
train_loss 0.06164997071027756


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 30
train_loss 0.06150055676698685


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 31
train_loss 0.06133658438920975


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 32
train_loss 0.06114877387881279


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 33
train_loss 0.06093457713723183


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 34
train_loss 0.06080249696969986


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 35
train_loss 0.060595955699682236


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 36
train_loss 0.060415104031562805


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 37
train_loss 0.06017335504293442


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 38
train_loss 0.06001528725028038


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 39
train_loss 0.05981741100549698


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 40
train_loss 0.05971129238605499


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 41
train_loss 0.05954617261886597


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 42
train_loss 0.05937204882502556


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 43
train_loss 0.05920759588479996


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 44
train_loss 0.05906427279114723


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 45
train_loss 0.058937471359968185


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 46
train_loss 0.05879975110292435


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 47
train_loss 0.05863761156797409


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 48
train_loss 0.05853120610117912


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 49
train_loss 0.058353036642074585


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 50
train_loss 0.058209002017974854


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 51
train_loss 0.058077264577150345


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 52
train_loss 0.05792098492383957


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 53
train_loss 0.05777059867978096


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 54
train_loss 0.057596489787101746


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 55
train_loss 0.057464323937892914


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 56
train_loss 0.05727843940258026


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 57
train_loss 0.057184044271707535


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 58
train_loss 0.05704045295715332


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 59
train_loss 0.05683877319097519


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 60
train_loss 0.056730274111032486


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 61
train_loss 0.05655817314982414


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 62
train_loss 0.05645613372325897


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 63
train_loss 0.056306276470422745


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 64
train_loss 0.05616254359483719


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 65
train_loss 0.056019142270088196


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 66
train_loss 0.05589974299073219


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 67
train_loss 0.055804938077926636


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 68
train_loss 0.05562739446759224


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 69
train_loss 0.05547250807285309


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 70
train_loss 0.055401772260665894


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 71
train_loss 0.05522093549370766


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 72
train_loss 0.055060192942619324


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 73
train_loss 0.05493023246526718


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 74
train_loss 0.05476398020982742


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 75
train_loss 0.05465114489197731


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 76
train_loss 0.0545010045170784


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 77
train_loss 0.05438929796218872


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 78
train_loss 0.054273057729005814


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 79
train_loss 0.05412241816520691


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 80
train_loss 0.054007530212402344


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 81
train_loss 0.053913213312625885


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 82
train_loss 0.053762052208185196


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 83
train_loss 0.05364999547600746


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 84
train_loss 0.05350233614444733


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 85
train_loss 0.05340828374028206


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 86
train_loss 0.05328676849603653


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 87
train_loss 0.05321041867136955


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 88
train_loss 0.05311635136604309


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 89
train_loss 0.05299907922744751


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 90
train_loss 0.05289311707019806


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 91
train_loss 0.05278271809220314


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 92
train_loss 0.052706267684698105


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 93
train_loss 0.05262627825140953


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 94
train_loss 0.05249564349651337


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 95
train_loss 0.05244513973593712


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 96
train_loss 0.05236612632870674


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 97
train_loss 0.052221257239580154


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 98
train_loss 0.05217884108424187


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 99
train_loss 0.05212749168276787


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 100
train_loss 0.052026063203811646


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 101
train_loss 0.05195818468928337


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 102
train_loss 0.051908913999795914


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 103
train_loss 0.05182311683893204


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 104
train_loss 0.05180096626281738


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 105
train_loss 0.05172942951321602


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 106
train_loss 0.05170038342475891


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 107
train_loss 0.05164804309606552


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 108
train_loss 0.0515800416469574


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 109
train_loss 0.0515456460416317


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 110
train_loss 0.05153533071279526


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 111
train_loss 0.05149427428841591


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 112
train_loss 0.05147862061858177


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 113
train_loss 0.0514226108789444


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 114
train_loss 0.051457665860652924


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 115
train_loss 0.05143538489937782


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 116
train_loss 0.051399748772382736


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 117
train_loss 0.051378071308135986


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 118
train_loss 0.05140624940395355


  0%|          | 0/1351 [00:00<?, ?it/s]

epoch 119
train_loss 0.05143903195858002


In [14]:
# Persist only the weights; rebuild the model with AmexModel(param) before load_state_dict.
torch.save(obj=model.state_dict(), f=model_path)