# Leash Bio

- RDkit記述子とfingerprintを入れる
- 藤田さんのアーキテクチャ

## ref
- https://www.kaggle.com/code/yyyu54/pytorch-version-belka-1dcnn-starter-with-all-data
- https://www.kaggle.com/code/ahmedelfazouan/belka-1dcnn-starter-with-all-data/notebook

In [1]:
# Experiment identifier; it is embedded in the output directory, the model
# weights directory and the TensorBoard logger name further down.
exp_no = '018'
# When True, the training frame is cut down to a 100,000-row random sample
# and the dataset/model smoke-test cells are executed.
DEBUG = True
# Fraction of the training rows that is sampled when DEBUG is False.
data_ratio = 1/5

In [2]:
# !pip install rdkit
# !pip install mordred

In [3]:
import gc
import os
import pickle
import random
import joblib
import pandas as pd
# import polars as pd
from tqdm import tqdm

import numpy as np
from sklearn.metrics import average_precision_score as APS
from sklearn.model_selection import StratifiedKFold

import torch
from torch.utils.data import TensorDataset, Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor

from pytorch_lightning import LightningModule
from pytorch_lightning import LightningDataModule, Trainer
# seed_everything
from pytorch_lightning.callbacks import (
    ModelCheckpoint, 
    EarlyStopping,
    ModelCheckpoint,
    RichModelSummary,
    RichProgressBar,
)
# 標準化
from sklearn.preprocessing import StandardScaler
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from funcs.utils import find_latest_ckpt_path, del_old_ckpt_path
from funcs.calc_descriptor import (calc_ecfp4_descriptors,
                                   calc_avalonfp_descriptors,
                                   calc_fcfp4_descriptors, 
                                   calc_rdkit_descriptors)
from funcs.tokenize import tokenize_smiles

import warnings
warnings.simplefilter('ignore')

In [4]:
import os
from pathlib import Path

def is_kaggle_kernel():
    """Return True when the Kaggle working directory exists on this machine."""
    kaggle_workdir = '/kaggle/working'
    return os.path.exists(kaggle_workdir)

# Resolve the data and output directories depending on where the notebook runs.
if is_kaggle_kernel():

    BASE_DIR = Path("/kaggle")
    DATA_DIR = BASE_DIR / "input"
    OUTPUT_DIR = BASE_DIR / "working"
    print('on kaggle notebook')

else:
    # Local run: the base directory is the parent of the current working
    # directory, and outputs go to a per-experiment folder.
    BASE_DIR = Path(os.getcwd()) / './../'
    DATA_DIR = BASE_DIR / "data"
    OUTPUT_DIR = BASE_DIR / f"output/exp{exp_no}"
    
# Pick the compute device: MPS is checked first, then CUDA, otherwise CPU.
# The resulting string is later passed to the Lightning Trainer as
# `accelerator` and used for `model.to(device)` during inference.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():    
    device = "cuda"
else:
    device = "cpu"
    
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# NOTE: this count covers CUDA devices only, so it is 0 on MPS or CPU machines.
print('Using', torch.cuda.device_count(), 'GPU(s)')
print('pytorch:', torch.__version__)

Using 0 GPU(s)
pytorch: 2.3.0


In [5]:
class config:
    # Training hyper-parameters, read as plain class attributes
    # (the class is never instantiated in the visible code).

    # Seed passed to my_seed_everything().
    SEED = 2024
    
    # Not referenced anywhere in the visible code.
    PREPROCESS = False
    # Trainer max_epochs, and also T_0 of the cosine warm-restart scheduler.
    EPOCHS = 20
    # Batch size and worker count shared by every DataLoader in this file.
    BATCH_SIZE = 4096
    NUM_WORKERS = 16
    
    # Adam learning rate and weight decay.
    LR = 1e-3
    WEIGHT_DECAY = 1e-6
    # True selects Trainer precision '16-mixed', False selects 32.
    MIXED_PRECISION = True
    
    # Number of folds iterated by the training loop.
    NUM_FOLDS = 5    
    # Not referenced in the visible code (the loop always runs NUM_FOLDS folds).
    USE_NUM_FOLD = 1
    
class paths:    
    # Filesystem locations derived from the module-level DATA_DIR / OUTPUT_DIR.
    DATA_DIR = DATA_DIR
    OUTPUT_DIR = OUTPUT_DIR
    # Directory used for ModelCheckpoint output and checkpoint lookup.
    MODEL_WEIGHTS_DIR = OUTPUT_DIR / f"bio-models-exp{exp_no}"
    
    # Folder holding the index-encoded ("shrunken") parquet files and the
    # pickled lookup dictionaries.
    SHRUNKEN_DATA_DIR = DATA_DIR / "shrunken-train-set"

    TRAIN_PATH = SHRUNKEN_DATA_DIR / "train_fold.parquet"
    # TEST_PATH is not read in the visible code (the test frame is loaded
    # from DATA_DIR / 'test.parquet').
    TEST_PATH = SHRUNKEN_DATA_DIR / "test_fold.parquet"
    
    # Side effect at class-definition time: create the output directory.
    OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

In [6]:
print('fix seed')

def my_seed_everything(seed: int):
    """Seed Python's `random`, NumPy and PyTorch, and export PYTHONHASHSEED."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The three generators are independent, so each is seeded in turn.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    
# Lightning's seed_everything is left disabled; the local helper is used instead.
# seed_everything(config.SEED, workers=True)
my_seed_everything(config.SEED)

fix seed


# **Load Data**

In [7]:
# Columns read from the training parquet. The three buildingblock*_smiles
# columns and bb1_scaffold_idx are later used by BioDataset as integer row
# positions into the descriptor tables; 'fold' drives the CV split.
bb_cols = ['buildingblock1_smiles', 'buildingblock2_smiles','buildingblock3_smiles', 'bb1_scaffold_idx', 'fold']
# Binary target columns, one per protein.
TARGETS = ['binds_BRD4', 'binds_HSA','binds_sEH']

df_train = pd.read_parquet(paths.TRAIN_PATH, columns=bb_cols + TARGETS)
    
# Subsample the training rows: a fixed 100,000 rows in DEBUG mode, otherwise
# a `data_ratio` fraction. No random_state is passed to sample().
# NOTE(review): sample(100000) raises if the frame has fewer rows than that.
if DEBUG:
    df_train = df_train.sample(100000).reset_index(drop=True)
else:
    len_train = int(len(df_train)*data_ratio)
    df_train = df_train.sample(len_train).reset_index(drop=True)

In [8]:
# Building-block lookup dictionaries, loaded from pickle files.
# NOTE: the index <-> SMILES mapping does not seem to match between train and test.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(paths.SHRUNKEN_DATA_DIR / 'train_dicts/BBs_dict_reverse_1.p', 'rb') as file:
    train_dicts_bb1 = pickle.load(file)
with open(paths.SHRUNKEN_DATA_DIR / 'train_dicts/BBs_dict_reverse_2.p', 'rb') as file:
    train_dicts_bb2 = pickle.load(file)
with open(paths.SHRUNKEN_DATA_DIR / 'train_dicts/BBs_dict_reverse_3.p', 'rb') as file:
    train_dicts_bb3 = pickle.load(file)
with open(paths.SHRUNKEN_DATA_DIR / 'train_dicts/BBs_scaffold_dict_reverse_1.p', 'rb') as file:
    train_dicts_bb1_scf = pickle.load(file)

with open(paths.SHRUNKEN_DATA_DIR / 'test_dicts/BBs_dict_reverse_1_test.p', 'rb') as file:
    test_dicts_bb1 = pickle.load(file)
with open(paths.SHRUNKEN_DATA_DIR / 'test_dicts/BBs_dict_reverse_2_test.p', 'rb') as file:
    test_dicts_bb2 = pickle.load(file)
with open(paths.SHRUNKEN_DATA_DIR / 'test_dicts/BBs_dict_reverse_3_test.p', 'rb') as file:
    test_dicts_bb3= pickle.load(file)
# NOTE(review): unlike the other test dictionaries this file name has no
# "_test" suffix -- confirm it is the intended file.
with open(paths.SHRUNKEN_DATA_DIR / 'test_dicts/BBs_scaffold_dict_reverse_1.p', 'rb') as file:
    test_dicts_bb1_scf = pickle.load(file)
    
# Dict that converts a bb1 index into its scaffold index.
with open(paths.SHRUNKEN_DATA_DIR / 'test_dicts/BBs_idx_to_scaffold_idx_dict_1.p', mode='rb') as file:
    test_bb1idx2scfidx= pickle.load(file)

# Invert the loaded test dictionaries (value -> key). The bb1/bb2/bb3 inverses
# are used to map the test frame's building-block columns;
# test_dicts_bb1_scaffold_reverse is not used in the visible code.
test_dicts_bb1_reverse = {val:key for key, val in test_dicts_bb1.items()}
test_dicts_bb2_reverse = {val:key for key, val in test_dicts_bb2.items()}
test_dicts_bb3_reverse = {val:key for key, val in test_dicts_bb3.items()}
test_dicts_bb1_scaffold_reverse = {val:key for key, val in test_dicts_bb1_scf.items()}

In [9]:
# Load the test set and drop the full-molecule SMILES column.
df_test = pd.read_parquet(paths.DATA_DIR / 'test.parquet')
df_test.drop(['molecule_smiles'], axis=1, inplace=True)

# Replace each building-block value by its key in the inverted test dictionary.
# NOTE(review): Series.map yields NaN for values missing from the dictionary.
df_test['buildingblock1_smiles'] = df_test['buildingblock1_smiles'].map(test_dicts_bb1_reverse)
df_test['buildingblock2_smiles'] = df_test['buildingblock2_smiles'].map(test_dicts_bb2_reverse)
df_test['buildingblock3_smiles'] = df_test['buildingblock3_smiles'].map(test_dicts_bb3_reverse)

# Derive the scaffold index from the (already mapped) bb1 index.
df_test['bb1_scaffold_idx'] = df_test['buildingblock1_smiles'].map(test_bb1idx2scfidx)

# **Make Features**

In [10]:
# RDKit descriptor tables, one per building-block dictionary (train and test).
df_train_bb1_rdkit = calc_rdkit_descriptors(train_dicts_bb1)
df_train_bb2_rdkit = calc_rdkit_descriptors(train_dicts_bb2)
df_train_bb3_rdkit = calc_rdkit_descriptors(train_dicts_bb3)
df_train_bb1_scf_rdkit = calc_rdkit_descriptors(train_dicts_bb1_scf)

df_test_bb1_rdkit = calc_rdkit_descriptors(test_dicts_bb1)
df_test_bb2_rdkit = calc_rdkit_descriptors(test_dicts_bb2)
df_test_bb3_rdkit = calc_rdkit_descriptors(test_dicts_bb3)
df_test_bb1_scf_rdkit = calc_rdkit_descriptors(test_dicts_bb1_scf)


# Fingerprint tables. NOTE: the variables are named *_ecfp4, but they are
# produced by calc_avalonfp_descriptors, not calc_ecfp4_descriptors.
df_train_bb1_ecfp4 = calc_avalonfp_descriptors(train_dicts_bb1)
df_train_bb2_ecfp4 = calc_avalonfp_descriptors(train_dicts_bb2)
df_train_bb3_ecfp4 = calc_avalonfp_descriptors(train_dicts_bb3)
df_train_bb1_scf_ecfp4 = calc_avalonfp_descriptors(train_dicts_bb1_scf)

df_test_bb1_ecfp4 = calc_avalonfp_descriptors(test_dicts_bb1)
df_test_bb2_ecfp4 = calc_avalonfp_descriptors(test_dicts_bb2)
df_test_bb3_ecfp4 = calc_avalonfp_descriptors(test_dicts_bb3)
df_test_bb1_scf_ecfp4 = calc_avalonfp_descriptors(test_dicts_bb1_scf)

# input_len = df_train_bb1.shape[1]
# print('input_len:', input_len)  

In [11]:
def remove_std0(df_list):
    """Drop the columns whose standard deviation is 0 over all frames pooled.

    The frames are stacked row-wise and de-duplicated, and the per-column
    standard deviation is computed on that pooled table. Every input frame
    is then restricted to the surviving columns.

    Args:
        df_list: list of DataFrames that share the same columns.

    Returns:
        A list of DataFrames, in the same order as `df_list`, holding only
        the columns whose pooled standard deviation is not 0.
    """
    pooled = pd.concat(df_list, axis=0).drop_duplicates()
    kept_columns = pooled.loc[:, pooled.std() != 0].columns
    return [frame.loc[:, kept_columns] for frame in df_list]


def standardization(df_list):
    """Standardize several DataFrames with one shared StandardScaler.

    The frames are stacked row-wise, de-duplicated, and +/-inf is replaced
    by NaN before the scaler is fitted on the pooled table. Each input frame
    is then transformed with that scaler, the columns that are entirely NaN
    after scaling are dropped, and finally the columns whose pooled standard
    deviation is 0 are removed via `remove_std0`.

    Args:
        df_list: list of DataFrames that share the same columns.

    Returns:
        A list of standardized DataFrames, in the same order as `df_list`,
        each keeping its original index.
    """
    inf_values = [np.inf, -np.inf]

    # Pool all frames so that a single mean/scale is used everywhere.
    df_all = pd.concat(df_list, axis=0)
    df_all.drop_duplicates(inplace=True)
    df_all.replace(inf_values, np.nan, inplace=True)

    # standard scaling
    scaler = StandardScaler()
    df_all_array = scaler.fit_transform(df_all)

    # Remember the columns that are NaN in every row after scaling.
    nan_columns = np.all(np.isnan(df_all_array), axis=0)
    del_cols = df_all.columns[nan_columns]

    standardized_df_list = []
    for df_temp in df_list:
        # Apply the same inf -> NaN replacement that was applied before
        # fitting; otherwise transform() rejects frames containing inf even
        # though the scaler itself was fitted without them.
        df_temp = df_temp.loc[:, df_all.columns].replace(inf_values, np.nan)
        df_temp_std = pd.DataFrame(scaler.transform(df_temp),
                                   index=df_temp.index,
                                   columns=df_temp.columns)
        df_temp_std.drop(columns=del_cols, inplace=True)
        standardized_df_list.append(df_temp_std)

    standardized_df_list = remove_std0(standardized_df_list)

    return standardized_df_list


In [12]:
# Standardize all RDKit descriptor tables jointly (one scaler for train and test).
df_list_rdkit = [
            df_train_bb1_rdkit,df_train_bb2_rdkit, df_train_bb3_rdkit, df_train_bb1_scf_rdkit,
            df_test_bb1_rdkit,df_test_bb2_rdkit,df_test_bb3_rdkit,df_test_bb1_scf_rdkit
            ]
df_train_bb1_rdkit,df_train_bb2_rdkit, df_train_bb3_rdkit, df_train_bb1_scf_rdkit, \
    df_test_bb1_rdkit,df_test_bb2_rdkit,df_test_bb3_rdkit,df_test_bb1_scf_rdkit \
        = standardization(df_list_rdkit)
        
# The fingerprint tables (named *_ecfp4) are not scaled; only the columns
# with zero standard deviation across all tables are removed.
df_list_ecfp4 = [
            df_train_bb1_ecfp4,df_train_bb2_ecfp4, df_train_bb3_ecfp4, df_train_bb1_scf_ecfp4,
            df_test_bb1_ecfp4,df_test_bb2_ecfp4,df_test_bb3_ecfp4,df_test_bb1_scf_ecfp4
            ]
df_train_bb1_ecfp4,df_train_bb2_ecfp4, df_train_bb3_ecfp4, df_train_bb1_scf_ecfp4,\
            df_test_bb1_ecfp4,df_test_bb2_ecfp4,df_test_bb3_ecfp4,df_test_bb1_scf_ecfp4 \
                = remove_std0(df_list_ecfp4)


In [13]:
# Per-building-block feature widths; they are passed to BioModel as
# input_len1 (RDKit descriptors) and input_len2 (fingerprint).
len_rdkit = df_train_bb1_rdkit.shape[1]
len_ecfp4 = df_train_bb1_ecfp4.shape[1]
print(len_rdkit, len_ecfp4)

187 512


# **Dataset & DataModule**

In [71]:
class BioDataset(torch.utils.data.Dataset):
    """Dataset that assembles one feature vector per molecule row.

    Each row of `df` holds integer row positions into the building-block
    tables (bb1, bb2, bb3 and the bb1 scaffold). Two descriptor families are
    supplied (suffix `_1` and `_2`), each with one table per building block.
    An item is the concatenation
    `[bb1_1, bb2_1, bb3_1, scf_1, bb1_2, bb2_2, bb3_2, scf_2]`.

    In 'train' and 'valid' mode the last three selected columns are the
    targets; in 'test' mode a zero vector of length 3 is returned instead.
    In 'train' mode bb2 and bb3 are swapped with probability 0.5.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        df_bb1_1: pd.DataFrame,
        df_bb2_1: pd.DataFrame,
        df_bb3_1: pd.DataFrame,
        df_bb1_scf_1: pd.DataFrame,
        df_bb1_2: pd.DataFrame,
        df_bb2_2: pd.DataFrame,
        df_bb3_2: pd.DataFrame,
        df_bb1_scf_2: pd.DataFrame,
        mode = 'train'
    ):
        super().__init__()

        assert mode in ['train', 'valid', 'test']
        self.mode = mode

        index_cols = ["buildingblock1_smiles", "buildingblock2_smiles", "buildingblock3_smiles", "bb1_scaffold_idx"]
        # Target columns are only selected when labels are available.
        selected_cols = index_cols + TARGETS if self.mode in ('train', 'valid') else index_cols

        # Everything is kept as plain NumPy arrays for fast positional access.
        self.df = df[selected_cols].values
        self.bb1_1, self.bb2_1, self.bb3_1, self.bb1_scf_1 = (
            df_bb1_1.values, df_bb2_1.values, df_bb3_1.values, df_bb1_scf_1.values
        )
        self.bb1_2, self.bb2_2, self.bb3_2, self.bb1_scf_2 = (
            df_bb1_2.values, df_bb2_2.values, df_bb3_2.values, df_bb1_scf_2.values
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return {'X': float32 feature tensor, 'y': float16 target tensor}."""
        record = self.df[index, :]
        idx_bb1, idx_bb2, idx_bb3, idx_scf = record[0], record[1], record[2], record[3]

        # One list per descriptor family, ordered bb1, bb2, bb3, scaffold.
        family_1 = [self.bb1_1[idx_bb1, :], self.bb2_1[idx_bb2, :],
                    self.bb3_1[idx_bb3, :], self.bb1_scf_1[idx_scf, :]]
        family_2 = [self.bb1_2[idx_bb1, :], self.bb2_2[idx_bb2, :],
                    self.bb3_2[idx_bb3, :], self.bb1_scf_2[idx_scf, :]]

        if self.mode == 'train':
            family_1[1], family_1[2], family_2[1], family_2[2] = self.augment(
                family_1[1], family_1[2], family_2[1], family_2[2]
            )

        features = np.concatenate(family_1 + family_2)

        if self.mode in ('train', 'valid'):
            labels = record[-3:]
        else:
            labels = np.zeros(3)

        return {
            'X': torch.tensor(features, dtype=torch.float32),
            'y': torch.tensor(labels, dtype=torch.float16)
        }

    def augment(self, bb2_1, bb3_1, bb2_2, bb3_2):
        """Swap the bb2 and bb3 features (both families) with probability 0.5."""
        if np.random.rand() >= 0.5:
            return bb2_1, bb3_1, bb2_2, bb3_2
        return bb3_1, bb2_1, bb3_2, bb2_2

In [72]:
# Smoke test (DEBUG only): build a 'valid'-mode dataset on the training frame
# and print the shapes of the first item's feature and target tensors.
if DEBUG:
    dataset = BioDataset(df_train, 
                         df_train_bb1_rdkit,df_train_bb2_rdkit, df_train_bb3_rdkit, df_train_bb1_scf_rdkit,
                         df_train_bb1_ecfp4,df_train_bb2_ecfp4, df_train_bb3_ecfp4, df_train_bb1_scf_ecfp4,
                         mode='valid')
    X = dataset[0]['X']
    y = dataset[0]['y']
    print(X.shape)
    print(y.shape)

torch.Size([9568])
torch.Size([3])


In [73]:
# lightning data module
class BioDataModule(LightningDataModule):
    """Lightning data module providing the train/validation loaders of one fold.

    Rows whose 'fold' value equals `fold_id` form the validation set; all
    other rows form the training set. The descriptor tables are taken from
    the module-level `df_train_bb*_rdkit` / `df_train_bb*_ecfp4` frames.

    Args:
        df_train: training frame containing a 'fold' column.
        fold_id: fold held out for validation.
    """

    def __init__(self, df_train, fold_id):
        super().__init__()

        self.train_df = df_train[df_train['fold'] != fold_id]
        self.valid_df = df_train[df_train['fold'] == fold_id]

    def _build_dataloader(self, df, mode):
        """Create the DataLoader for `df` in the given dataset mode."""
        dataset = BioDataset(df,
                             df_train_bb1_rdkit, df_train_bb2_rdkit, df_train_bb3_rdkit, df_train_bb1_scf_rdkit,
                             df_train_bb1_ecfp4, df_train_bb2_ecfp4, df_train_bb3_ecfp4, df_train_bb1_scf_ecfp4,
                             mode=mode)
        is_train = mode == 'train'
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=is_train,
            num_workers=config.NUM_WORKERS,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when num_workers is 0,
            # so only enable it when worker processes actually exist.
            persistent_workers=config.NUM_WORKERS > 0,
            drop_last=is_train,
        )

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        return self._build_dataloader(self.train_df, mode='train')

    def val_dataloader(self):
        """Validation loader in fixed order, keeping every row."""
        return self._build_dataloader(self.valid_df, mode='valid')

# **Model**

In [74]:
class BioModel(nn.Module):
    """MLP over per-building-block descriptor vectors.

    The input is a flat vector laid out as
    `[bb1, bb2, bb3, scaffold]` for descriptor family 1 (each of width
    `input_len1`) followed by the same four parts for family 2 (each of
    width `input_len2`). Every part is embedded by a small FC block, the two
    families are concatenated per building block, each building block is
    embedded again, and the four results are fed to the head. bb2 and bb3
    share their extractors; bb1 and the scaffold each have their own.

    Fixes relative to the previous version:
      * the scaffold branch now uses `feature_extractor_bbscf` (it was routed
        through `feature_extractor_bb23`, leaving `feature_extractor_bbscf`
        unused);
      * the head honours `output_dim` instead of a hard-coded 3.
    Parameter names are unchanged, but checkpoints trained before the first
    fix hold untrained `feature_extractor_bbscf` weights and should be
    retrained.

    Args:
        input_len1: width of one building-block vector of descriptor family 1.
        input_len2: width of one building-block vector of descriptor family 2.
        output_dim: number of output logits (default 3).
    """

    def __init__(self, 
                 input_len1,
                 input_len2,
                 output_dim=3):
        super(BioModel, self).__init__()

        self.input_len1 = input_len1
        self.input_len2 = input_len2
        self.output_dim = output_dim

        # FC block per building block for descriptor family 1
        self.feature_extractor_bb1_desc1 = self._fc_block(self.input_len1, 128)
        self.feature_extractor_bb23_desc1 = self._fc_block(self.input_len1, 128)
        self.feature_extractor_bb1scf_desc1 = self._fc_block(self.input_len1, 128)
        # FC block per building block for descriptor family 2
        self.feature_extractor_bb1_desc2 = self._fc_block(self.input_len2, 128)
        self.feature_extractor_bb23_desc2 = self._fc_block(self.input_len2, 128)
        self.feature_extractor_bb1scf_desc2 = self._fc_block(self.input_len2, 128)

        # FC block per building block over both families concatenated
        self.feature_extractor_bb1 = self._fc_block(128*2, 324)
        self.feature_extractor_bb23 = self._fc_block(128*2, 324)
        self.feature_extractor_bbscf = self._fc_block(128*2, 324)

        # head
        self.head = nn.Sequential(
            nn.Linear(324*4, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(512, self.output_dim),
        )

    @staticmethod
    def _fc_block(in_features, out_features, dropout=0.1):
        """Linear -> BatchNorm1d -> Dropout -> ReLU, the unit used by every extractor."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.Dropout(dropout),
            nn.ReLU(),
        )

    def forward(self, x):
        """Map a (batch, 4*input_len1 + 4*input_len2) tensor to (batch, output_dim) logits."""
        len1, len2 = self.input_len1, self.input_len2
        desc1 = x[:, :len1*4]
        desc2 = x[:, len1*4:]

        # FC per building block and descriptor family
        bb1_1 = self.feature_extractor_bb1_desc1(desc1[:, :len1])
        bb2_1 = self.feature_extractor_bb23_desc1(desc1[:, len1:len1*2])
        bb3_1 = self.feature_extractor_bb23_desc1(desc1[:, len1*2:len1*3])
        bb1_scf_1 = self.feature_extractor_bb1scf_desc1(desc1[:, len1*3:])
        bb1_2 = self.feature_extractor_bb1_desc2(desc2[:, :len2])
        bb2_2 = self.feature_extractor_bb23_desc2(desc2[:, len2:len2*2])
        bb3_2 = self.feature_extractor_bb23_desc2(desc2[:, len2*2:len2*3])
        bb1_scf_2 = self.feature_extractor_bb1scf_desc2(desc2[:, len2*3:])

        bb1 = torch.cat([bb1_1, bb1_2], dim=1)
        bb2 = torch.cat([bb2_1, bb2_2], dim=1)
        bb3 = torch.cat([bb3_1, bb3_2], dim=1)
        bbscf = torch.cat([bb1_scf_1, bb1_scf_2], dim=1)

        # FC per building block; the scaffold goes through its own extractor.
        bb1 = self.feature_extractor_bb1(bb1)
        bb2 = self.feature_extractor_bb23(bb2)
        bb3 = self.feature_extractor_bb23(bb3)
        bbscf = self.feature_extractor_bbscf(bbscf)

        X = torch.cat([bb1, bb2, bb3, bbscf], dim=1)

        output = self.head(X)

        return output

In [75]:
# Smoke test (DEBUG only): instantiate the model, print its parameter count
# and run a random batch of 64 rows through it to check the output shape.
if DEBUG:
    dummy_model = BioModel(input_len1=len_rdkit, input_len2=len_ecfp4)
    total_params = sum(p.numel() for p in dummy_model.parameters())
    print(f"Total number of parameters: {total_params}")

    # Input width: four building-block parts for each of the two families.
    dummy_input = torch.rand((64, (len_rdkit+len_ecfp4)*4), dtype=torch.float32)
    output = dummy_model(dummy_input)
    print(output.shape)
    # print(output)

Total number of parameters: 3030119
torch.Size([64, 3])


# **Lightning Module**

In [76]:
def calc_score(y_preds, y_true):
    """Compute the average precision per target column and their mean.

    Args:
        y_preds: 2-D array of predicted scores; columns 0, 1, 2 are used.
        y_true: 2-D array of labels with the same column layout.

    Returns:
        Tuple `(score_BRD4, score_HSA, score_sEH, score)` where the first
        three are the average precision of columns 0, 1 and 2 and `score`
        is their arithmetic mean.
    """
    score_BRD4, score_HSA, score_sEH = (
        APS(y_true[:, col], y_preds[:, col]) for col in range(3)
    )
    score = (score_BRD4 + score_HSA + score_sEH) / 3

    return score_BRD4, score_HSA, score_sEH, score

In [77]:
class BioModule(LightningModule):
    """LightningModule wrapping BioModel with BCE-with-logits training.

    Validation predictions and targets are buffered per step and scored at
    the end of each validation epoch; the result is logged as `valid_score`
    and the mean step loss as `valid_loss_epoch`.
    """

    def __init__(self):
        """Build the model from the module-level len_rdkit / len_ecfp4 widths."""
        
        super(BioModule, self).__init__()
       
        self.model = BioModel(input_len1=len_rdkit, input_len2=len_ecfp4)
        # Per-step validation outputs, cleared in on_validation_epoch_end.
        self.validation_step_outputs = []
        self.loss_func = nn.BCEWithLogitsLoss()
        
    def forward(self, X):
        """Return the raw logits of the wrapped model."""
        pred = self.model(X)
        return pred
    
    def configure_optimizers(self):
        """Return Adam plus a CosineAnnealingWarmRestarts schedule stepped per epoch."""
        
        # == define optimizer ==
        model_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=config.LR,
            weight_decay=config.WEIGHT_DECAY
        )
        # == define learning rate scheduler ==
        # T_0 equals the number of epochs, T_mult is 1.
        lr_scheduler = CosineAnnealingWarmRestarts(
            model_optimizer,
            T_0=config.EPOCHS,
            T_mult=1,
            eta_min=1e-6,
            last_epoch=-1
        )
        # NOTE(review): 'monitor' is presumably only needed for metric-driven
        # schedulers such as ReduceLROnPlateau -- confirm it is required here.
        return {
            'optimizer': model_optimizer,
            'lr_scheduler': {
                'scheduler': lr_scheduler,
                'interval': 'epoch',
                'monitor': 'valid_loss_epoch',
                'frequency': 1
            }
        }
        
    def training_step(self, batch, batch_idx):
        """Compute and log the BCE-with-logits loss of one training batch."""
        
        X, y = batch.pop('X'), batch.pop('y')
        logits = self(X)
        train_loss = self.loss_func(logits, y)
        
        self.log('train_loss', train_loss,  on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=X.size(0))
        
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Log the batch loss and buffer loss, sigmoid predictions and targets."""
        
        X, y = batch.pop('X'), batch.pop('y')
        logits = self(X)
        preds = torch.sigmoid(logits)
        
        valid_loss = self.loss_func(logits, y)
        
        self.log('valid_loss', valid_loss, on_step=True, on_epoch=False, prog_bar=True, logger=True, batch_size=X.size(0))
        
        self.validation_step_outputs.append({"valid_loss":valid_loss, "preds":preds, "targets":y})
        
        return valid_loss

    
    # NOTE(review): `_train_dataloader` and `_validation_dataloader` are never
    # assigned in this class, so both accessors below would raise
    # AttributeError if called. Training in this file passes a datamodule to
    # Trainer.fit instead -- confirm whether these accessors are still needed.
    def train_dataloader(self):
        return self._train_dataloader

    def validation_dataloader(self):
        return self._validation_dataloader
    
    def calc_score(self, y_preds, y_true):
        """Delegate to the module-level calc_score function."""
        return calc_score(y_preds, y_true)

    
    def on_validation_epoch_end(self):
        """Aggregate the buffered validation outputs and log loss and score."""
        
        outputs = self.validation_step_outputs
        
        # Average the per-iteration losses (unweighted mean over batches).
        avg_loss = torch.stack([x['valid_loss'] for x in outputs]).mean()
        self.log("valid_loss_epoch", avg_loss, prog_bar=True, logger=True)
        
        # Compute the score on all buffered predictions and targets.
        y_preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()
        y_true = torch.cat([x['targets'] for x in outputs]).detach().cpu().numpy()
        
        # The last element returned by calc_score is the mean over the targets.
        score = self.calc_score(y_preds, y_true)[-1]
        self.log("valid_score", score, prog_bar=True, logger=True)
        
        # Reset the buffer for the next validation epoch.
        self.validation_step_outputs.clear()
        
        return {'valid_loss_epoch': avg_loss, "valid_score":score}

# Train & Inference

In [29]:
def predict_in_batches(model, df, 
                       df_bb1_1, df_bb2_1, df_bb3_1, df_bb1_scf_1, 
                       df_bb1_2, df_bb2_2, df_bb3_2, df_bb1_scf_2, 
                       mode):
    """Run `model` over `df` in batches and return sigmoid probabilities.

    Args:
        model: module mapping a feature batch to logits; it is moved to the
            module-level `device` and switched to eval mode.
        df: frame with the building-block index columns (plus targets when
            `mode` is 'valid').
        df_bb1_1 ... df_bb1_scf_2: descriptor tables forwarded to BioDataset.
        mode: BioDataset mode ('train', 'valid' or 'test').

    Returns:
        NumPy array of predictions stacked over all rows, in the order of `df`.
    """
    model.to(device)
    model.eval()

    dataset = BioDataset(df,
                         df_bb1_1, df_bb2_1, df_bb3_1, df_bb1_scf_1,
                         df_bb1_2, df_bb2_2, df_bb3_2, df_bb1_scf_2,
                         mode=mode)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        # Pinned host memory only helps CUDA transfers.
        pin_memory=(device == "cuda"),
        # The loader is iterated exactly once, so there is nothing to gain
        # from persistent workers, and persistent_workers=True is rejected
        # by DataLoader when num_workers is 0.
        persistent_workers=False,
        drop_last=False,
    )

    all_preds = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['X'].to(device)
            logits = model(inputs)
            preds = torch.sigmoid(logits)
            all_preds.append(preds.cpu().numpy())

    return np.concatenate(all_preds, axis=0)

In [30]:
def run_training(fold_id, df, infer=False):
    """Train (or reload) the model of one fold and predict OOF and test rows.

    Args:
        fold_id: fold held out for validation.
        df: training frame containing the 'fold' column and the targets.
        infer: when True, skip training and load the latest checkpoint found
            for this fold in `paths.MODEL_WEIGHTS_DIR`.

    Returns:
        Tuple `(preds_test, score_dict, valid_df)`:
          * preds_test: predictions for the module-level `df_test`;
          * score_dict: average precision under keys 'BRD4', 'HSA', 'sEH'
            and their mean under 'all';
          * valid_df: validation rows with an added `<target>_pred` column
            per target.
    """
    print(f"======== Running training for fold {fold_id} =============")

    # == init data module and model ==
    model = BioModule()
    datamodule = BioDataModule(df, fold_id)

    # == init callback ==
    # Both callbacks track 'valid_score' (higher is better);
    # 'valid_loss_epoch' would be the alternative metric.
    checkpoint_callback = ModelCheckpoint(
        monitor='valid_score',
        dirpath=paths.MODEL_WEIGHTS_DIR,
        save_top_k=1,
        save_last=False,
        save_weights_only=True,
        filename=f"fold_{fold_id}",
        mode='max'
    )
    early_stop_callback = EarlyStopping(
        monitor='valid_score',
        mode="max",
        patience=5,
        verbose=True
    )
    callbacks_to_use = [checkpoint_callback,
                        early_stop_callback,
                        RichModelSummary(),
                        RichProgressBar(),
                        ]

    # == init trainer ==
    trainer = Trainer(
        max_epochs=config.EPOCHS,
        callbacks=callbacks_to_use,
        accelerator=device,
        # "auto" uses every available device of the chosen accelerator; the
        # previous value -1 is rejected by the CPU accelerator, which made
        # the CPU fallback of `device` unusable.
        devices="auto",
        deterministic=False,
        precision='16-mixed' if config.MIXED_PRECISION else 32,
        logger=TensorBoardLogger('lightning_logs', name=f'exp{exp_no}_fold{fold_id}'),
    )

    if not infer:
        # == Training ==
        trainer.fit(model, datamodule=datamodule)
        weights = torch.load(checkpoint_callback.best_model_path)['state_dict']
    else:
        ckpt_path = find_latest_ckpt_path(fold_id, paths.MODEL_WEIGHTS_DIR)
        weights = torch.load(ckpt_path)['state_dict']

    model.load_state_dict(weights)

    # Work on an explicit copy so the prediction columns added below are not
    # written onto the filtered frame owned by the datamodule.
    valid_df = datamodule.valid_df.copy()

    preds_oof = predict_in_batches(model, valid_df,
                                   df_train_bb1_rdkit, df_train_bb2_rdkit, df_train_bb3_rdkit, df_train_bb1_scf_rdkit,
                                   df_train_bb1_ecfp4, df_train_bb2_ecfp4, df_train_bb3_ecfp4, df_train_bb1_scf_ecfp4,
                                   mode='valid')
    y_oof = valid_df[TARGETS].values

    score_BRD4, score_HSA, score_sEH, score = calc_score(preds_oof, y_oof)

    valid_df[[f'{target}_pred' for target in TARGETS]] = preds_oof

    print(f'fold:{fold_id} | CV score = {score}')

    df_test_temp = df_test.drop(['id'], axis=1)
    preds_test = predict_in_batches(model, df_test_temp,
                                    df_test_bb1_rdkit, df_test_bb2_rdkit, df_test_bb3_rdkit, df_test_bb1_scf_rdkit,
                                    df_test_bb1_ecfp4, df_test_bb2_ecfp4, df_test_bb3_ecfp4, df_test_bb1_scf_ecfp4,
                                    mode='test')

    # Free the large objects before the next fold starts.
    del model, datamodule, trainer, preds_oof, y_oof
    gc.collect()

    score_dict = {
        'BRD4':score_BRD4,
        "HSA":score_HSA,
        "sEH":score_sEH,
        "all":score
    }

    return preds_test, score_dict, valid_df

In [31]:
# training
# torch.set_float32_matmul_precision('high')

# Silence the tokenizers parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Per-fold accumulators: test predictions and CV scores (overall and per target).
all_preds = []
score_list = []
score_list_BRD4 = []
score_list_HSA = []
score_list_sEH = []

def save_list_by_text(score_list, filename):
    """Write `score_list` as one comma-separated line to OUTPUT_DIR/<filename>.txt.

    An existing file with the same name is overwritten.
    """
    target_path = paths.OUTPUT_DIR / f'{filename}.txt'
    serialized = ', '.join(map(str, score_list))
    with open(target_path, 'w') as file:
        file.write(serialized)
    

# Train every fold in turn, persisting scores and OOF predictions after each
# fold so that partial results survive an interrupted run.
for fold_id in range(config.NUM_FOLDS):
    
    preds_test, score_dict, df_oof = run_training(fold_id, df_train, infer=False)
    
    # save score
    score_list_BRD4.append(score_dict['BRD4'])
    score_list_HSA.append(score_dict['HSA'])
    score_list_sEH.append(score_dict['sEH'])
    score_list.append(score_dict['all'])
    
    # The score files are rewritten with the scores of all folds finished so far.
    save_list_by_text(score_list, 'cv_all')
    save_list_by_text(score_list_BRD4, 'cv_BRD4')
    save_list_by_text(score_list_HSA, 'cv_HSA')
    save_list_by_text(score_list_sEH, 'cv_sEH')
    
    # save preds (one entry per fold)
    all_preds.append(preds_test) 
    
    df_oof.to_parquet(paths.OUTPUT_DIR / f"oof_fold_{fold_id}.parquet")
    
    del df_oof
    gc.collect()
    

# Merge the per-fold OOF files into a single parquet file.
df_oof_all = pd.DataFrame()
for fold_id in range(config.NUM_FOLDS):
    df_temp = pd.read_parquet(paths.OUTPUT_DIR / f"oof_fold_{fold_id}.parquet")
    df_oof_all = pd.concat([df_oof_all, df_temp], axis=0)

df_oof_all.to_parquet(paths.OUTPUT_DIR / f"oof_all.parquet")

RuntimeError: Pin memory thread exited unexpectedly

# **Submission**

In [None]:
# Average the test predictions over all trained folds.
preds = np.mean(all_preds, 0)

# Each test row asks about a single protein: pick the matching prediction
# column (0: BRD4, 1: HSA, 2: sEH). The column is initialised as float so
# that assigning probabilities does not rely on silent int -> float upcasting.
df_test['binds'] = 0.0
for pred_col, protein in enumerate(['BRD4', 'HSA', 'sEH']):
    protein_mask = (df_test['protein_name'] == protein).to_numpy()
    df_test.loc[protein_mask, 'binds'] = preds[protein_mask, pred_col]

# NOTE(review): `fold_id` is the value left over from the last loop that ran,
# although the predictions are a fold average; the name is kept unchanged so
# the output path stays the same.
df_test[['id', 'binds']].to_csv(paths.OUTPUT_DIR / f'submission_fold{fold_id}.csv', index = False)

In [None]:
# Delete old checkpoint files and the per-fold OOF files.
for fold in range(config.NUM_FOLDS):
    del_old_ckpt_path(fold, paths.MODEL_WEIGHTS_DIR)

    oof_path = paths.OUTPUT_DIR / f'oof_fold_{fold}.parquet'
    # A fold's file may be missing when training was interrupted; skip it
    # instead of aborting the whole cleanup.
    oof_path.unlink(missing_ok=True)