# Install

# Import

In [1]:
import sys
import os
import gc
import copy
import yaml
import pickle
import random
import joblib 
import shutil
from time import time
import typing as tp
from pathlib import Path
import psutil

import numpy as np
import pandas as pd
import scipy

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from sklearn.metrics import average_precision_score as APS
import duckdb


import torch
import torchvision
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp
from torch.nn import BCELoss
from torch.utils.data import Dataset


import timm
from mamba_ssm import Mamba
from transformers import AutoModel, AutoTokenizer

import albumentations as A
from albumentations.pytorch import ToTensorV2


# use one device only
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
con = duckdb.connect()

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.11 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


20000_50per_CLM.parquet

In [2]:
class CFG:
    TEST_No = 1
    TEST_NUM = int(1674896/32 * TEST_No)
    TEST_OFFSET = int(TEST_NUM * (TEST_No-1))
    TRAIN_CLM_PATH = Path('/root/Kaggle_NeurIPS2024/data/processed/20000_50per_CLM.parquet')
    TEST_CLM_PATH = Path(f'/root/Kaggle_NeurIPS2024/data/processed/test_CLM_{TEST_OFFSET}_to_{TEST_NUM}.parquet')
    TRAIN_ENC_PATH = Path('/root/Kaggle_NeurIPS2024/data/external/train_enc.parquet')
    TEST_ENC_PATH = Path('/root/Kaggle_NeurIPS2024/data/external/test_enc.parquet')
    TRAIN_PATH = Path('/root/Kaggle_NeurIPS2024/data/raw/train.parquet')
    TEST_PATH = Path('/root/Kaggle_NeurIPS2024/data/raw/test.parquet')
    folds = 2
    max_epoch = 9             # number of max epoch. 1epoch means going around the training dataset.
    batch_size = 2048           # batch size. Number of samples passed to the network in one training step
    lr = 1.0e-03              # learning rate. determine step size when updating model's weight
    weight_decay = 1.0e-02    # weight decay. Append regularization term for prevent over fitting
    es_patience = 5           # Timing for early stopping. If there is no improvement within this number of epochs, training will be stopped early.
    seed = 1086               # Random number seed
    deterministic = True      # Enable/disable deterministic behavior. If enabled, the program will produce the same results every time it starts with the same initial conditions and inputs.
    enable_amp = False        # Enable/disable Automatic Mixed Precision. Optimizations for floating point etc.
    device = "cuda" 
    n_classes = 3
    clm = False
    enc = True

In [3]:
# train = pd.read_parquet(CFG.TRAIN_ENC_PATH)
# test = pd.read_parquet(CFG.TEST_ENC_PATH)
# train.head()
# print(len(train))
# print(len(test))

In [4]:
# display(train.head())
# print(len(train))
# print(len(test))

In [5]:
# train = con.query(f"""(SELECT *
#                         FROM parquet_scan('{CFG.TRAIN_PATH}')
#                         LIMIT 60000)""").df()
# test = con.query(f"""(SELECT *
#                         FROM parquet_scan('{CFG.TRAIN_ENC_PATH}')
#                         LIMIT 1674896)""").df()

In [6]:
if CFG.clm:
    train_clm = con.query(f"""(SELECT *
                            FROM parquet_scan('{CFG.TRAIN_CLM_PATH}')
                            )""").df()

In [7]:
# about 2min 30s with 1/10 data
if CFG.enc:
    train_enc = con.query(f"""(SELECT *
                            FROM parquet_scan('{CFG.TRAIN_ENC_PATH}')
                            ORDER BY random()
                            LIMIT {int(295246830 / 3 / 10)}
                            )""").df()
    
    # train_enc = con.query(f"""(SELECT *
    #                         FROM parquet_scan('{CFG.TRAIN_ENC_PATH}')
    #                         WHERE binds = 0
    #                         ORDER BY random()
    #                         LIMIT {int((295246830 / 3 / 100)*0.9)})
    #                         UNION ALL
    #                         (SELECT *
    #                         FROM parquet_scan('{CFG.TRAIN_ENC_PATH}')
    #                         WHERE binds = 1
    #                         ORDER BY random()
    #                         LIMIT {int((295246830 / 3 / 100)*0.1)})""").df()

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [8]:
test = con.query(f"""(SELECT *
                        FROM parquet_scan('{CFG.TEST_PATH}')
                        LIMIT 10000)""").df()

In [9]:
if CFG.clm:
    test_clm = con.query(f"""(SELECT *
                            FROM parquet_scan('{CFG.TEST_CLM_PATH}')
                            )""").df()

In [10]:
if CFG.enc:
    test_enc = con.query(f"""(SELECT *
                        FROM parquet_scan('{CFG.TEST_ENC_PATH}')
                        )""").df()

In [11]:
from sklearn.model_selection import KFold

def split_fold(df:pd.DataFrame):
    # config
    N_FOLDS = CFG.folds
    RANDAM_SEED = 42
    df['fold'] = -1

    # object
    skf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDAM_SEED)

    for i, (train_index, test_index) in enumerate(skf.split(df)):
        df.loc[test_index, 'fold'] = i
    
    return df
if CFG.clm:
    train_clm = split_fold(train_clm)
if CFG.enc:
    train_enc = split_fold(train_enc)

In [12]:
if CFG.clm:
    train_clm.head()
if CFG.enc:
    train_enc.head()

In [13]:
if CFG.clm:
    display(train_clm.head())
    display(train_clm.tail())
if CFG.enc:
    display(train_enc.head())
    display(train_enc.tail())

Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,enc37,enc38,enc39,enc40,enc41,enc42,enc43,enc44,enc45,enc46,enc47,enc48,enc49,enc50,enc51,enc52,enc53,enc54,enc55,enc56,enc57,enc58,enc59,enc60,enc61,enc62,enc63,enc64,enc65,enc66,enc67,enc68,enc69,enc70,enc71,enc72,enc73,enc74,enc75,enc76,enc77,enc78,enc79,enc80,enc81,enc82,enc83,enc84,enc85,enc86,enc87,enc88,enc89,enc90,enc91,enc92,enc93,enc94,enc95,enc96,enc97,enc98,enc99,enc100,enc101,enc102,enc103,enc104,enc105,enc106,enc107,enc108,enc109,enc110,enc111,enc112,enc113,enc114,enc115,enc116,enc117,enc118,enc119,enc120,enc121,enc122,enc123,enc124,enc125,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141,bind1,bind2,bind3,fold
0,8,22,8,8,29,8,3,3,5,32,17,8,8,17,26,28,19,33,29,30,2,32,19,33,12,27,35,12,17,33,8,8,35,18,12,12,12,35,12,18,26,28,19,35,12,17,33,12,18,12,12,12,12,17,8,33,4,8,8,33,17,8,19,8,8,4,19,12,18,19,35,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,8,28,12,27,12,12,12,12,17,8,17,26,28,19,33,29,30,2,32,19,12,27,33,12,27,35,12,17,33,8,12,18,12,12,12,12,12,18,33,18,8,8,28,8,8,18,19,35,12,17,33,8,12,18,35,35,35,17,8,19,12,18,8,17,7,19,7,19,35,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
2,28,26,8,17,8,8,33,12,27,35,12,17,33,8,8,12,18,35,12,12,17,31,9,19,10,18,19,35,12,17,33,8,18,17,8,17,26,28,19,33,29,30,2,32,19,8,8,8,8,18,19,35,27,19,33,12,27,12,12,35,12,12,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,8,12,27,12,12,17,28,19,12,12,17,8,19,12,27,33,12,27,35,12,17,33,8,12,18,35,12,4,12,12,12,12,12,4,12,17,26,28,19,29,35,5,32,18,19,35,12,17,33,12,18,12,17,28,8,17,7,19,17,7,19,7,19,12,12,12,12,18,8,17,26,28,19,33,29,30,2,32,19,35,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,8,35,27,12,12,17,8,18,17,8,33,12,4,35,12,17,33,8,12,25,12,12,35,35,25,8,19,35,12,17,33,29,8,3,5,32,17,8,12,25,12,12,12,17,8,1,19,12,17,8,1,19,12,25,19,8,17,26,28,19,33,29,30,2,32,19,35,4,19,8,8,8,28,18,19,12,35,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1


Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,enc37,enc38,enc39,enc40,enc41,enc42,enc43,enc44,enc45,enc46,enc47,enc48,enc49,enc50,enc51,enc52,enc53,enc54,enc55,enc56,enc57,enc58,enc59,enc60,enc61,enc62,enc63,enc64,enc65,enc66,enc67,enc68,enc69,enc70,enc71,enc72,enc73,enc74,enc75,enc76,enc77,enc78,enc79,enc80,enc81,enc82,enc83,enc84,enc85,enc86,enc87,enc88,enc89,enc90,enc91,enc92,enc93,enc94,enc95,enc96,enc97,enc98,enc99,enc100,enc101,enc102,enc103,enc104,enc105,enc106,enc107,enc108,enc109,enc110,enc111,enc112,enc113,enc114,enc115,enc116,enc117,enc118,enc119,enc120,enc121,enc122,enc123,enc124,enc125,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141,bind1,bind2,bind3,fold
9841556,28,26,8,17,33,29,30,2,32,19,12,27,12,12,12,12,17,28,8,17,7,19,17,7,19,7,19,12,27,33,12,27,35,12,17,33,8,8,17,26,28,19,33,18,8,8,8,28,18,19,35,12,17,33,8,12,18,12,12,12,17,28,12,4,12,12,12,35,12,4,19,12,17,7,19,12,18,19,35,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
9841557,8,12,27,12,12,12,17,8,29,8,3,5,32,17,33,12,18,35,12,17,33,8,12,4,12,17,7,19,12,12,12,12,4,33,4,8,8,8,8,4,19,35,12,17,33,12,4,35,35,12,10,4,19,35,18,19,8,17,26,28,19,33,29,30,2,32,19,12,12,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
9841558,28,26,8,12,27,12,17,8,1,19,35,12,17,33,12,18,35,12,17,33,12,4,12,12,12,25,12,17,12,4,19,8,33,8,25,26,28,19,35,12,17,33,29,8,3,5,32,17,8,12,4,12,12,12,17,29,33,14,32,17,26,28,19,29,28,36,32,19,12,12,4,19,8,17,26,28,19,33,29,30,2,32,19,35,18,19,35,12,27,8,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
9841559,8,28,8,17,26,28,19,12,27,35,12,17,8,1,19,12,17,8,1,19,35,12,27,33,12,27,35,12,17,33,12,18,12,12,12,17,6,17,8,19,17,26,28,19,26,28,19,12,12,18,19,35,12,17,33,12,18,12,17,8,17,26,28,19,33,29,30,2,32,19,12,12,35,12,18,28,8,19,35,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
9841560,8,28,8,17,26,28,19,8,12,27,35,12,17,33,12,18,35,12,17,33,12,4,12,12,17,8,19,12,12,12,4,7,19,35,12,17,33,29,8,3,3,5,32,4,8,33,17,8,17,26,28,19,28,8,17,8,19,17,8,19,8,19,8,29,8,3,5,32,4,8,17,26,28,19,33,29,30,2,32,19,35,18,19,10,12,27,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [14]:
if CFG.clm:
    display(test_clm.head())
    display(test_clm.tail())
if CFG.enc:
    display(test_enc.head())
    display(test_enc.tail())

Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,enc37,enc38,enc39,enc40,enc41,enc42,enc43,enc44,enc45,enc46,enc47,enc48,enc49,enc50,enc51,enc52,enc53,enc54,enc55,enc56,enc57,enc58,enc59,enc60,enc61,enc62,enc63,enc64,enc65,enc66,enc67,enc68,enc69,enc70,enc71,enc72,enc73,enc74,enc75,enc76,enc77,enc78,enc79,enc80,enc81,enc82,enc83,enc84,enc85,enc86,enc87,enc88,enc89,enc90,enc91,enc92,enc93,enc94,enc95,enc96,enc97,enc98,enc99,enc100,enc101,enc102,enc103,enc104,enc105,enc106,enc107,enc108,enc109,enc110,enc111,enc112,enc113,enc114,enc115,enc116,enc117,enc118,enc119,enc120,enc121,enc122,enc123,enc124,enc125,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141
0,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,35,12,35,12,4,12,18,35,12,35,4,8,8,17,8,19,28,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,35,12,35,12,4,12,18,35,12,35,4,8,8,17,8,19,28,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,enc37,enc38,enc39,enc40,enc41,enc42,enc43,enc44,enc45,enc46,enc47,enc48,enc49,enc50,enc51,enc52,enc53,enc54,enc55,enc56,enc57,enc58,enc59,enc60,enc61,enc62,enc63,enc64,enc65,enc66,enc67,enc68,enc69,enc70,enc71,enc72,enc73,enc74,enc75,enc76,enc77,enc78,enc79,enc80,enc81,enc82,enc83,enc84,enc85,enc86,enc87,enc88,enc89,enc90,enc91,enc92,enc93,enc94,enc95,enc96,enc97,enc98,enc99,enc100,enc101,enc102,enc103,enc104,enc105,enc106,enc107,enc108,enc109,enc110,enc111,enc112,enc113,enc114,enc115,enc116,enc117,enc118,enc119,enc120,enc121,enc122,enc123,enc124,enc125,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141
1674891,8,28,8,27,8,8,8,17,8,8,33,12,18,35,12,17,33,12,4,35,13,12,25,12,12,12,17,7,19,12,12,4,25,19,35,12,17,33,29,8,3,3,5,32,17,8,8,8,33,26,29,33,14,32,26,29,33,36,32,19,8,17,26,28,19,33,29,30,2,32,19,35,18,19,8,8,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1674892,8,28,8,27,8,8,8,17,8,8,33,12,18,35,12,17,33,12,4,35,13,12,25,12,12,12,17,7,19,12,12,4,25,19,35,12,17,33,29,8,3,3,5,32,17,8,8,8,33,26,29,33,14,32,26,29,33,36,32,19,8,17,26,28,19,33,29,30,2,32,19,35,18,19,8,8,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1674893,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,8,12,18,12,12,12,10,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1674894,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,8,12,18,12,12,12,10,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1674895,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,8,12,18,12,12,12,10,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


# Dataset

In [15]:
class EXDataset(Dataset):
    def __init__(
        self,
        train: pd.DataFrame,
        label: pd.DataFrame = pd.DataFrame(),
        is_test: bool = False,
        transform = None
    ):
        self.train = train
        self.label = label
        self.is_test = is_test
        self.transform = transform
        
    def __len__(self):
        # return total num of data
        return len(self.train)
    
    def __getitem__(self, index:int):
        # return data and target assosiated with index
        X = self.train.iloc[index]
        X = self._apply_transform(X)
        
        if self.is_test:
            # y = np.argmax(np.zeros(CFG.n_classes))
            y = torch.tensor([0, 0, 0], dtype=torch.float)
        else:
            # y = np.argmax(self.label.iloc[index].values)
            y = torch.tensor(self.label.iloc[index].values, dtype=torch.float)

        return X, y
    
    def _apply_transform(self, X):
        if self.transform:
            X = self.transform(X)
        return X

# Model

In [16]:
import torch
import torch.nn as nn

# class MambaModel(nn.Module):
#     def __init__(self, 
#                  dim_model=384, # Model dimension d_model (embedding size)
#                  d_state=16, # SSM state expansion factor
#                  d_conv=4, # Local convolution width
#                  expand=2, # Block expansion factor
#                  output=3, # number of classes (or output number simply)
#                  is_test=False,
#                 ):
import torch.nn as nn
import torch

class MambaModel(nn.Module):
    def __init__(self, 
                 dim_model=384, # Model dimension d_model (embedding size)
                 d_state=64, # SSM state expansion factor
                 d_conv=8, # Local convolution width
                 expand=4, # Block expansion factor
                 output=3, # number of classes (or output number simply)
                 dropout_rate=0.1, # Dropout rate
                 is_test=False,
                ):
        super().__init__()
        self.model = Mamba(
            d_model=dim_model,  
            d_state=d_state,  
            d_conv=d_conv,    
            expand=expand,    
        ).to("cuda")
        self.output = nn.Linear(dim_model, output)
        self.dropout = nn.Dropout(dropout_rate)
        self.sigmoid = nn.Sigmoid()
        self.is_test = is_test

    def forward(self, x):
        # Add the length dimension if input has only 2 dimensions
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
            
        x = self.model(x)
        x = self.dropout(x)  # Apply dropout
        x = self.output(x)
        if self.is_test:
            x = self.sigmoid(x)
        x = x.squeeze()
        
        return x



In [17]:
# class MambaModel(nn.Module):
#     def __init__(self, 
#                  dim_model=384, # Model dimension d_model (embedding size)
#                 #  dim_model=142, # Model dimension d_model (embedding size)
#                  d_state=16, # SSM state expansion factor
#                  d_conv=4, # Local convolution width
#                  expand=2, # Block expansion factor
#                  output = 3, # number of classes (or output number simply)
#                 #  is_test=False
#                  ):
#         super().__init__()
#         self.model = Mamba(
#             d_model=dim_model,  
#             d_state=d_state,  
#             d_conv=d_conv,    
#             expand=expand,    
#         ).to("cuda")
#         # mamba pass trought input size as is.
#         self.output = nn.Linear(dim_model, output)
#         self.softmax = nn.Softmax(dim=-1)
#         # self.is_test = is_test

#     def forward(self, x):
#         # Add the length dimension if input has only 2 dimensions
#         if len(x.shape) == 2:
#             x = x.unsqueeze(1)
            
#         x = self.model(x)
#         x = self.output(x)
#         x = x.squeeze()
#         x = torch.pow(x, 2) # Square to prevent minus value
        
#         # if self.is_test:
#             # x = self.softmax(x)
#         return x

##### set seeds

In [18]:
def set_random_seed(seed: int = 42, deterministic: bool = False):
    """Set seeds"""
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = deterministic  # type: ignore

##### move tensors to device

In [19]:
def to_device(
    tensors: tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]],
    device: torch.device, *args, **kwargs
):
    if isinstance(tensors, tuple):
        return (t.to(device, *args, **kwargs) for t in tensors)
    elif isinstance(tensors, dict):
        return {
            k: t.to(device, *args, **kwargs) for k, t in tensors.items()}
    else:
        return tensors.to(device, *args, **kwargs)

##### transform values to tensor

In [20]:
def to_tensor(x):
    return torch.tensor(x.values, dtype=torch.float32)

# Train

In [21]:
def train_one_fold(CFG,
                   val_fold: int,
                   train: pd.DataFrame,
                   output_path
                   ):
    if CFG.clm:
        feature_columns = [str(i) for i in range(384)]
    if CFG.enc:
        feature_columns = [f'enc{i}' for i in range(142)]

    label_columns = ['bind1', 'bind2', 'bind3']

    set_random_seed(CFG.seed, deterministic=CFG.deterministic)
    device = torch.device(CFG.device)
    train_dataset = EXDataset(train = train[feature_columns][train['fold']!=val_fold].reset_index(drop=True), 
                              label = train[label_columns][train['fold']!=val_fold].reset_index(drop=True), 
                              transform = to_tensor)
    val_dataset = EXDataset(train = train[feature_columns][train['fold']==val_fold].reset_index(drop=True), 
                            label = train[label_columns][train['fold']==val_fold].reset_index(drop=True), 
                            transform = to_tensor)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)

    if CFG.clm:
        model = MambaModel(dim_model=384)
    if CFG.enc:
        model = MambaModel(dim_model=142)
    model.to(device)

    optimizer = optim.AdamW(params=model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer=optimizer, epochs=CFG.max_epoch,
        pct_start=0.0, steps_per_epoch=len(train_loader),
        max_lr=CFG.lr, div_factor=25, final_div_factor=4.0e-01
    )
    # loss_func = nn.CrossEntropyLoss()
    loss_func = nn.BCEWithLogitsLoss()
    loss_func.to(device)
    # loss_func_val = nn.CrossEntropyLoss()
    loss_func_val = nn.BCEWithLogitsLoss()

    use_amp = CFG.enable_amp
    scaler = amp.GradScaler(enabled=use_amp)

    best_val_loss = 1.0e+09
    best_epoch = 0
    train_loss = 0
    val_loss = 0
    
    for epoch in range(1, CFG.max_epoch + 1):
        epoch_start = time()
        model.train()
        for batch in train_loader:
            
            x, t = batch
            # print(x)
            # print(t)
            x = to_device(x, device)
            t = to_device(t, device)
            # sys.exit()
                
            optimizer.zero_grad()
            with amp.autocast(use_amp):
                y = model(x)
                loss = loss_func(y, t)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()
            scheduler.step()
            
        train_loss /= len(train_loader)
            
        model.eval()
        for batch in val_loader:
            x, t = batch
            x = to_device(x, device)
            with torch.no_grad(), amp.autocast(use_amp):
                y = model(x)
#                 y = torch.sigmoid(y)
            y = y.detach().cpu().to(torch.float32)
            loss = loss_func_val(y, t)
            val_loss += loss.item()
        val_loss /= len(val_loader)
        
        if val_loss < best_val_loss:
            best_epoch = epoch
            best_val_loss = val_loss
            # print("save model")
            torch.save(model.state_dict(), str(output_path / f'snapshot_epoch_{epoch}.pth'))
        
        elapsed_time = time() - epoch_start
        print(
            f"[epoch {epoch}] train loss: {train_loss: .6f}, val loss: {val_loss: .6f}, elapsed_time: {elapsed_time: .3f}")
        
        if epoch - best_epoch > CFG.es_patience:
            print("Early Stopping!")
            break
            
        train_loss = 0
        val_loss = 0
            
    return val_fold, best_epoch, best_val_loss

Do the above function actually, and save the best model of each epoch.



In [22]:
# 82minute with 10% data
score_list = []
for fold_id in range(CFG.folds):
    output_path = Path(f"fold{fold_id}")
    output_path.mkdir(exist_ok=True)
    print(f"[fold{fold_id}]")
    if CFG.clm:
        score_list.append(train_one_fold(CFG, fold_id, train_clm, output_path))
    if CFG.enc:
        score_list.append(train_one_fold(CFG, fold_id, train_enc, output_path))

[fold0]
[epoch 1] train loss:  0.031685, val loss:  0.028290, elapsed_time:  272.968
[epoch 2] train loss:  0.027712, val loss:  0.027059, elapsed_time:  271.284
[epoch 3] train loss:  0.026359, val loss:  0.026210, elapsed_time:  272.240
[epoch 4] train loss:  0.025349, val loss:  0.025240, elapsed_time:  278.794
[epoch 5] train loss:  0.024337, val loss:  0.024663, elapsed_time:  280.658
[epoch 6] train loss:  0.023358, val loss:  0.024075, elapsed_time:  275.108
[epoch 7] train loss:  0.022461, val loss:  0.023637, elapsed_time:  279.453
[epoch 8] train loss:  0.021703, val loss:  0.023414, elapsed_time:  271.969
[epoch 9] train loss:  0.021188, val loss:  0.023313, elapsed_time:  270.887
[fold1]
[epoch 1] train loss:  0.031921, val loss:  0.028498, elapsed_time:  270.049
[epoch 2] train loss:  0.027807, val loss:  0.026934, elapsed_time:  277.722
[epoch 3] train loss:  0.026424, val loss:  0.026039, elapsed_time:  271.025
[epoch 4] train loss:  0.025362, val loss:  0.025101, elapse

Check the result.

In [23]:
print(score_list)

[(0, 9, 0.023312912175693066), (1, 9, 0.023351748770818307)]


Delete models without best:

In [24]:
# select the best model and delete others
best_log_list = []
for (fold_id, best_epoch, _) in score_list:
    
    # select the best model
    exp_dir_path = Path(f"fold{fold_id}")
    best_model_path = exp_dir_path / f"snapshot_epoch_{best_epoch}.pth"
    # copy to new place
    copy_to = f"./best_model_fold{fold_id}.pth"
    shutil.copy(best_model_path, copy_to)
    
    for p in exp_dir_path.glob("*.pth"):
        # delete
        p.unlink()

# Infer

In [25]:
def run_inference_loop(model, loader, device):
    model.to(device)
    model.eval()
    pred_list = []
    with torch.no_grad():
        for batch in tqdm(loader):
            x = to_device(batch[0], device)
            y = model(x)
            pred_list.append(y.detach().cpu().numpy())
    
    # concatenate to vertical (to df that like long scroll)
    pred_arr = np.vstack(pred_list)
    del pred_list
    return pred_arr

Do inference actually.

In [26]:
def inference(test):
    test_pred_arr = np.zeros((CFG.folds, len(test), CFG.n_classes))
    score_list = []

    for fold_id in range(CFG.folds):
        print(f"\n[fold {fold_id}]")
        device = torch.device(CFG.device)
        
        if CFG.clm:
            feature_columns = [str(i) for i in range(384)]
        if CFG.enc:
            feature_columns = [f'enc{i}' for i in range(142)]
            
        test_dataset = EXDataset(test[feature_columns],
                                 transform = to_tensor,
                                 is_test = True)
    
        # test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, num_workers=4, shuffle=False, drop_last=False)

        # get model
        model_path = f"./best_model_fold{fold_id}.pth"
        if CFG.clm:
            model = MambaModel(dim_model=384, is_test=True)
        if CFG.enc:
            model = MambaModel(dim_model=142, is_test=True)
        model.load_state_dict(torch.load(model_path, map_location=device))

        # inference
        test_pred = run_inference_loop(model, test_loader, device)
        test_pred_arr[fold_id] = test_pred

        del model, test_loader
        torch.cuda.empty_cache()
        gc.collect()
    return test_pred_arr
if CFG.clm:
    test_preds_arr = inference(test_clm)
if CFG.enc:
    test_preds_arr = inference(test_enc)


[fold 0]


  0%|          | 0/52341 [00:00<?, ?it/s]


[fold 1]


  0%|          | 0/52341 [00:00<?, ?it/s]

mean each fold's predict.

In [27]:
test_pred = test_preds_arr.mean(axis=0)
test_pred = pd.DataFrame(test_pred)
# test_pred = pd.concat([test_clm[['id', 'protein_name']], test_pred], axis=1)
# test_pred = pd.concat([test_enc['id'], test_pred], axis=1)

In [28]:
display(test_pred.head(20))
display(test_pred.tail(20))
print(len(test_pred))
test_pred.to_csv('test_pred_raw.csv')


Unnamed: 0,0,1,2
0,0.004066,0.001274,0.000788
1,0.004066,0.001274,0.000788
2,0.004066,0.001274,0.000788
3,0.000783,0.001515,0.000293
4,0.000783,0.001515,0.000293
5,0.000783,0.001515,0.000293
6,0.011155,0.001198,0.000741
7,0.011155,0.001198,0.000741
8,0.011155,0.001198,0.000741
9,0.003612,0.00165,0.001141


Unnamed: 0,0,1,2
1674876,0.001658,0.00822,0.006957
1674877,0.001658,0.00822,0.006957
1674878,0.00049,0.001573,0.000351
1674879,0.00049,0.001573,0.000351
1674880,0.00049,0.001573,0.000351
1674881,0.000435,0.000498,8e-06
1674882,0.000435,0.000498,8e-06
1674883,0.000435,0.000498,8e-06
1674884,0.000114,0.001134,0.000366
1674885,0.000114,0.001134,0.000366


1674896


# Submission

In [29]:
# normalized_test_pred = test_pred.copy()
# for column in test_pred.columns:
#     min_val = test_pred[column].min()
#     max_val = test_pred[column].max()
#     normalized_test_pred[column] = (test_pred[column] - min_val) / (max_val - min_val)


In [30]:
# display(normalized_test_pred.head(20))
# display(normalized_test_pred.tail(20))
# print(len(test_pred))

In [31]:
# tst.head()

In [32]:
# import numpy as np
# import pandas as pd

# # Ensure test_pred is a numpy array
# test_pred = np.array(test_pred)

# # Read the Parquet file into a DataFrame
# tst = pd.read_parquet('/root/Kaggle_NeurIPS2024/data/raw/test.parquet')

# # Add a new column 'binds' initialized to 0
# tst['binds'] = 0

# # Assign predictions for rows where 'protein_name' is 'BRD4'
# brd4_indices = np.where(tst['protein_name'] == 'BRD4')[0]
# tst.loc[tst['protein_name'] == 'BRD4', 'binds'] = test_pred[brd4_indices, 0]

# # Assign predictions for rows where 'protein_name' is 'HSA'
# hsa_indices = np.where(tst['protein_name'] == 'HSA')[0]
# tst.loc[tst['protein_name'] == 'HSA', 'binds'] = test_pred[hsa_indices, 1]

# # Assign predictions for rows where 'protein_name' is 'sEH'
# seh_indices = np.where(tst['protein_name'] == 'sEH')[0]
# tst.loc[tst['protein_name'] == 'sEH', 'binds'] = test_pred[seh_indices, 2]

# # Create a CSV file with 'id' and 'binds' columns
# tst[['id', 'binds']].to_csv('submission.csv', index=False)


In [33]:
# tst.head()

In [34]:
tst = pd.read_parquet('/root/Kaggle_NeurIPS2024/data/raw/test.parquet')

tst['binds'] = 0
test_pred = np.array(test_pred)

tst.loc[tst['protein_name']=='BRD4', 'binds'] = test_pred[(tst['protein_name']=='BRD4').values, 0]
tst.loc[tst['protein_name']=='HSA', 'binds'] = test_pred[(tst['protein_name']=='HSA').values, 1]
tst.loc[tst['protein_name']=='sEH', 'binds'] = test_pred[(tst['protein_name']=='sEH').values, 2]
final_sub = tst[['id', 'binds']]
tst[['id', 'binds']].to_csv('/root/Kaggle_NeurIPS2024/submission_tst_10.csv', index = False)

  tst.loc[tst['protein_name']=='BRD4', 'binds'] = test_pred[(tst['protein_name']=='BRD4').values, 0]


In [35]:
display(tst.head())
display(tst.tail())

Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name,binds
0,295246830,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,BRD4,0.004066
1,295246831,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,HSA,0.001274
2,295246832,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,sEH,0.000788
3,295246833,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,BRD4,0.000783
4,295246834,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,HSA,0.001515


Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name,binds
1674891,296921721,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1noc2ccc(F)cc12,COC1CCC(CCN)CC1,COC1CCC(CCNc2nc(Nc3noc4ccc(F)cc34)nc(N[C@@H](C...,HSA,0.000119
1674892,296921722,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1noc2ccc(F)cc12,COC1CCC(CCN)CC1,COC1CCC(CCNc2nc(Nc3noc4ccc(F)cc34)nc(N[C@@H](C...,sEH,7.8e-05
1674893,296921723,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1noc2ccc(F)cc12,NCc1cccs1,[N-]=[N+]=NCCC[C@H](Nc1nc(NCc2cccs2)nc(Nc2noc3...,BRD4,0.000482
1674894,296921724,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1noc2ccc(F)cc12,NCc1cccs1,[N-]=[N+]=NCCC[C@H](Nc1nc(NCc2cccs2)nc(Nc2noc3...,HSA,0.001734
1674895,296921725,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1noc2ccc(F)cc12,NCc1cccs1,[N-]=[N+]=NCCC[C@H](Nc1nc(NCc2cccs2)nc(Nc2noc3...,sEH,0.001201


In [36]:
# def extract_values(df):
#     extracted_values = []
#     num_rows = len(df)
#     for i in tqdm(range(num_rows)):
#         column_index = i % 3  # Cycle through columns 0, 1, 2
#         extracted_values.append(df.iloc[i, column_index])
#     return extracted_values

# reshaped_test_pred = extract_values(test_pred)
# reshaped_test_pred = pd.DataFrame({'binds': reshaped_test_pred})

In [37]:
# display(reshaped_test_pred.head(5))
# display(reshaped_test_pred.tail(5))
# print(len(reshaped_test_pred))

In [38]:
# display(reshaped_test_pred.describe())
# display(reshaped_test_pred.head())
# display(reshaped_test_pred.tail())

In [39]:
# reshaped_test_pred.to_csv('/root/Kaggle_NeurIPS2024/10_enc_submission.csv', index=False)



In [40]:
# !pwd

In [41]:
# df = np.sqrt(np.sqrt(reshaped_test_pred))
# df = np.clip(df, None, 1)

# display(df.describe())


In [42]:
# df.to_csv('clm_submission_root_2times.csv', index=False)


In [43]:
# df = np.sqrt(reshaped_test_pred) * 10
# df = np.clip(df, None, 1)

# display(df.describe())


In [44]:
# df.to_csv('clm_submission_sqrt1_times_10.csv', index=False)
