In [None]:
import random
import os
import pickle
import warnings
from collections import defaultdict
import glob

import pandas as pd
import numpy as np
import optuna
import lightgbm as lgb
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from pymatreader import read_mat

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm

from tqdm import tqdm
tqdm.pandas()

from cwt import CWT

# Notebook-wide configuration.
# All Python warnings are suppressed for the rest of the session.
warnings.simplefilter('ignore')
# Root directory that holds one `subject<N>` folder per subject.
data_dir = '../../../data/train'

CROP_LEN_ = 250
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# exist_ok=True so that re-running this cell does not raise FileExistsError.
os.makedirs('output', exist_ok=True)

In [None]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Sets ``PYTHONHASHSEED``, seeds ``random``, NumPy and PyTorch (CPU and the
    current CUDA device), and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_data(data_dir):
    """
    Build the cross-validation folds, the labels and the differential
    (channel-pair) time-series configuration.

    Every entry matched by ``{data_dir}/subject{0..4}/*`` is loaded with
    ``read_mat``. Each event in a recording yields one sample key
    ``"<path>_<start_index>"`` and one label; the sample is assigned to a
    fold according to which session tag (``train1``/``train2``/``train3``)
    the file name contains, using a subject-specific rotation.

    Args:
        data_dir: Directory containing the ``subject<N>`` folders.

    Returns:
        Tuple ``(data_dict, label_dict, diff_list, use_ch_dict)``:
        ``data_dict`` maps fold -> list of sample keys, ``label_dict`` maps
        fold -> list of integer labels (last digit of the event type minus
        one), ``diff_list`` is the list of ``"chA_chB"`` channel pairs, and
        ``use_ch_dict`` maps each channel name used in ``diff_list`` to its
        row index in the recording.

    Raises:
        FileNotFoundError: If no recording is found under ``data_dir``.
    """
    data_dict = defaultdict(list)
    label_dict = defaultdict(list)

    # Session tag -> fold, checked in this order (same priority as the
    # original if/elif chains). The rotation differs per subject.
    rotation_123 = (('train1', 0), ('train2', 1), ('train3', 2))
    rotation_231 = (('train2', 0), ('train3', 1), ('train1', 2))
    rotation_312 = (('train3', 0), ('train1', 1), ('train2', 2))
    fold_rotation = {
        0: rotation_123,
        3: rotation_123,
        1: rotation_231,
        4: rotation_231,
        2: rotation_312,
    }

    ch_labels = None
    for subject in range(5):
        # NOTE: glob order is filesystem-dependent; the sample order inside
        # each fold follows it.
        for mat_data in glob.glob(f'{data_dir}/subject{subject}/*'):
            data = read_mat(mat_data)
            ch_labels = data['ch_labels']

            # Match the tag against the file name only: the directory part
            # already contains the substring "train".
            file_name = os.path.basename(mat_data)
            fold = next(
                (f for tag, f in fold_rotation[subject] if tag in file_name),
                None,
            )
            if fold is None:
                # Not one of the known sessions: contributes no samples.
                continue

            # presumably init_time is in seconds and the recording is sampled
            # every 2 ms, so this is the sample index 0.2 s after each event
            # - TODO confirm against the dataset description.
            start_indexes = (data['event']['init_time'] + 0.2) * 1000 // 2
            labels = data['event']['type']

            for start_index, label in zip(start_indexes, labels):
                data_dict[fold].append(f'{mat_data}_{int(start_index)}')
                # The class is encoded in the last digit of the event type.
                label_dict[fold].append(int(str(int(label))[-1]) - 1)

    if ch_labels is None:
        raise FileNotFoundError(
            f'No recordings found under {data_dir!r} (expected subject0..subject4 folders)'
        )

    # Channel names come from the last recording that was loaded.
    ch_names = [c.replace(' ', '') for c in ch_labels]
    diff_list = [
        # horizontal direction
        'F3_F4',
        'FCz_FC1', 'FCz_FC2', 'FCz_FC3', 'FCz_FC4', 'FCz_FC5', 'FCz_FC6', 'FC1_FC2', 'FC3_FC4', 'FC5_FC6',
        'Cz_C1', 'Cz_C2', 'Cz_C3', 'Cz_C4', 'Cz_C5', 'Cz_C6', 'C1_C2', 'C3_C4', 'C5_C6',
        'CPz_CP1', 'CPz_CP2', 'CPz_CP3', 'CPz_CP4', 'CPz_CP5', 'CPz_CP6', 'CP1_CP2', 'CP3_CP4', 'CP5_CP6',
        'P3_P4',
        # vertical direction
        'Cz_FCz', 'C1_FC1', 'C2_FC2', 'C3_FC3', 'C4_FC4', 'C5_FC5', 'C6_FC6',
        'Cz_CPz', 'C1_CP1', 'C2_CP2', 'C3_CP3', 'C4_CP4', 'C5_CP5', 'C6_CP6',
        'FCz_CPz', 'FC1_CP1', 'FC2_CP2', 'FC3_CP3', 'FC4_CP4', 'FC5_CP5', 'FC6_CP6',
    ]

    # Every channel that appears on either side of a pair.
    use_ch = set()
    for item in diff_list:
        ch1, ch2 = item.split('_')[:2]
        use_ch.add(ch1)
        use_ch.add(ch2)

    use_ch_dict = {name: idx for idx, name in enumerate(ch_names) if name in use_ch}

    return data_dict, label_dict, diff_list, use_ch_dict


class SkateDataset(Dataset):
    """
    Preprocessing part.

    Loads every recording referenced by ``data_list`` once, normalises it
    with the per-fold median/IQR scaler files, builds one difference signal
    per channel pair in ``diff_list`` and serves fixed-length crops.

    Args:
        fold: Fold number used to pick the ``iqr{fold}.npy`` /
            ``median{fold}.npy`` scaler files.
        data_list: Sample keys of the form ``"<file path>_<start index>"``.
        label: Labels aligned with ``data_list``.
        diff_list: Channel pairs written as ``"chA_chB"``.
        use_ch_dict: Mapping channel name -> row index in the recording.
    """
    def __init__(self, fold, data_list, label, diff_list, use_ch_dict):
        self.label = label
        self.diff_list = diff_list
        self.use_ch_dict = use_ch_dict
        self.crop_len = 250

        keys = [self._split_key(item) for item in data_list]
        self.file_list = [path for path, _ in keys]
        self.index_list = [start for _, start in keys]

        # NOTE(review): allow_pickle=True executes pickled objects on load;
        # only use scaler files from a trusted source.
        self.iqr = np.load(f'../../../data/scaler/iqr{fold}.npy', allow_pickle=True)
        self.median = np.load(f'../../../data/scaler/median{fold}.npy', allow_pickle=True)
        # Column vectors so they broadcast along the second axis of a recording.
        self.iqr = self.iqr.reshape(72, 1)
        self.median = self.median.reshape(72, 1)

        # file path -> stacked difference signals, one row per pair in diff_list
        self.data_dict = {}
        file_name_list = list(set(self.file_list))

        for file_name in tqdm(file_name_list):
            # assumes the recording is (72, n_samples) - TODO confirm
            data = read_mat(file_name)['data']
            data = (data - self.median) / self.iqr
            eeg_signal = []
            for channels in self.diff_list:
                ch1_name, ch2_name = channels.split('_')[:2]
                ch1 = data[self.use_ch_dict[ch1_name]]
                ch2 = data[self.use_ch_dict[ch2_name]]
                eeg_signal.append(ch1 - ch2)

            self.data_dict[file_name] = np.stack(eeg_signal)

    @staticmethod
    def _split_key(item):
        """Split ``"<path>_<start index>"`` on its LAST underscore.

        Splitting on the first underscore would truncate any path that
        itself contains an underscore.
        """
        path, start = item.rsplit('_', 1)
        return path, start

    def __len__(self):
        """Return the number of samples (one per entry of ``data_list``)."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Return ``(crop, label)`` for sample ``index``.

        The crop covers ``crop_len`` columns starting at the sample's start
        index, across all difference signals.
        """
        file_name = self.file_list[index]
        idx = int(self.index_list[index])
        label = self.label[index]

        data = self.data_dict[file_name]
        eeg_signal = data[:, idx:idx+self.crop_len]

        return eeg_signal, label
    
  
class StakeModel(nn.Module):
    """timm classifier that also returns the output of its global pooling layer.

    ``forward`` returns ``(logits, pooled_features)``; the pooled features
    are captured with a forward hook on the backbone's global pooling module.
    """
    def __init__(self):
        super(StakeModel, self).__init__()
        self.model_name = 'efficientvit_b0.r224_in1k'
        self.model = timm.create_model(self.model_name, pretrained=True, in_chans=1, num_classes=3, drop_rate=0.2)#, drop_path_rate=0.1

    def _pool_layer(self):
        """Return the global pooling module for the configured architecture."""
        if 'tf_efficientnet' in self.model_name:
            return self.model.global_pool
        if 'efficientvit' in self.model_name:
            return self.model.head.global_pool
        raise ValueError(f'No pooling hook defined for model {self.model_name!r}')

    def forward(self, x):
        """Run the backbone and return ``(output, pooled_features)``."""
        self.features = []
        hook = self._pool_layer().register_forward_hook(self.hook_fn)
        try:
            x = self.model(x)
        finally:
            # Without this, every call would leave one more hook attached to
            # the pooling layer.
            hook.remove()

        return x, self.features[0]

    def hook_fn(self, module, input, output):
        """Forward hook: store the hooked module's output."""
        self.features.append(output)


def get_model():
    """Create and return a new ``StakeModel`` instance."""
    return StakeModel()
    

class MyPipeline(nn.Module):
    """Turn a batch of multi-channel signals into a single-channel 224x224 image.

    The input goes through the ``CWT`` module, the per-channel outputs are
    concatenated along dim 1, and the result is resized to 224x224 with
    bilinear interpolation.

    Args:
        samplerate: Passed to ``CWT`` as ``fs``.
        lowcut: Passed to ``CWT`` as ``lower_freq``.
        highcut: Passed to ``CWT`` as ``upper_freq``.
        wavelet_width: Passed to ``CWT`` as ``wavelet_width``.
        n_scales: Passed to ``CWT`` as ``n_scales``.
        stride: Passed to ``CWT`` as ``stride``.
        n_channels: Number of leading channels (dim 1 of the CWT output) to
            concatenate. Defaults to 50, the previously hard-coded value.
    """
    def __init__(self, samplerate, lowcut, highcut, wavelet_width, n_scales, stride, n_channels=50):
        super().__init__()
        self.n_channels = n_channels
        self.cwt = CWT(wavelet_width=wavelet_width, fs=samplerate, lower_freq=lowcut, upper_freq=highcut, n_scales=n_scales, stride=stride, border_crop=1)

    def forward(self, x):
        """Return the resized image tensor of shape ``(batch, 1, 224, 224)``."""
        x = self.cwt(x)
        # assumes the CWT output is 4-D with channels on dim 1 - TODO confirm.
        # Concatenate the per-channel maps along dim 1, then add a singleton
        # image-channel dimension.
        x = torch.cat([x[:, i, :, :] for i in range(self.n_channels)], dim=1).unsqueeze(1)
        x = nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)

        return x

In [4]:
# Build the per-fold sample keys and labels, plus the channel-pair list and
# the channel-name -> row-index mapping used by SkateDataset.
data_dict, label_dict, diff_list, use_ch_dict = get_data(data_dir)

In [None]:
# Training part.
# For each outer fold: extract pooled CNN features with the saved model,
# tune LightGBM on them with Optuna (inner 3-fold stratified CV), then refit
# with the best parameters and save the models and out-of-fold probabilities.
seed_everything(42)
pipeline = MyPipeline(samplerate=500, lowcut=1, highcut=180, wavelet_width=1, n_scales=16, stride=1).to(device)

# The fitted LightGBM models are pickled into this directory; create it up
# front so the run cannot fail on a missing folder after the Optuna search.
os.makedirs('lgb_model', exist_ok=True)

for FOLD in range(3):
    fold_keys = data_dict[FOLD]
    fold_labels = label_dict[FOLD]
    dataset = SkateDataset(FOLD, fold_keys, fold_labels, diff_list, use_ch_dict)
    dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=False)

    model = get_model()
    load_weights = torch.load(f'model/model{FOLD}.pth', map_location=device)
    model.load_state_dict(load_weights)
    model = model.eval().to(device)

    feature_list, label_list = [], []
    with tqdm(dataloader) as t:
        for X, batch_label in t:
            X = X.to(device, non_blocking=True).float()

            with torch.no_grad():
                X = pipeline(X)
                _, feature = model(X)

            # Move each batch to the CPU right away instead of keeping every
            # feature tensor alive on the device until the end of the loop.
            feature_list.append(feature.detach().cpu())
            label_list.append(batch_label.long())

    X = torch.vstack(feature_list).numpy()
    y = torch.hstack(label_list).numpy()

    data = pd.DataFrame(data=X)
    data['label'] = y

    skf = StratifiedKFold(n_splits=3, random_state=42, shuffle=True)

    for fold, (train_index, valid_index) in enumerate(skf.split(X, y)):
        data.loc[valid_index, 'fold'] = fold

    target = data['label']
    use_columns = list(range(X.shape[1]))

    def split_fold(fold):
        """Return (X_train, y_train, X_valid, y_valid, valid_idx) for an inner fold."""
        train_idx = data.index[data['fold'] != fold]
        valid_idx = data.index[data['fold'] == fold]
        return (
            data.loc[train_idx, use_columns], target[train_idx],
            data.loc[valid_idx, use_columns], target[valid_idx],
            valid_idx,
        )

    def objective(trial):
        """Optuna objective: mean inner-CV multi_logloss for one parameter set."""
        params = {
            'verbose': -1,
            'random_state': 42,
            'objective': 'multiclass',
            'metric': 'multi_logloss',
            'boosting_type': 'gbdt',
            'class_weight': {0: 2.0, 1: 2.0, 2: 1.0},
            'max_depth': trial.suggest_int('max_depth', 2, 4),
            'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.1),
            'num_leaves': trial.suggest_int('num_leaves', 3, 255),
            'min_child_samples': trial.suggest_int('min_child_samples', 3, 150),
            'colsample_bytree': trial.suggest_float('colsample_bytree', 0.1, 1),
            'subsample_freq': trial.suggest_int('subsample_freq', 0, 10),
            'subsample': trial.suggest_float('subsample', 0.1, 1),
            # suggest_loguniform is deprecated; this is the same distribution.
            'reg_alpha': trial.suggest_float('reg_alpha', 1e-9, 10.0, log=True),
            'reg_lambda': trial.suggest_float('reg_lambda', 1e-9, 10.0, log=True),
        }

        score = []

        for fold in range(3):
            print(f'CV fold: {fold}')
            X_train, y_train, X_valid, y_valid, _ = split_fold(fold)

            # Early stopping is configured once, through the callback below.
            clf = lgb.LGBMClassifier(
                **params, n_estimators=1000, force_row_wise=True, deterministic=True)
            # `verbose` is a boolean flag (-1 is truthy); keep the search quiet.
            callbacks = [lgb.early_stopping(stopping_rounds=20, verbose=False)]
            clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], callbacks=callbacks)
            y_pred = clf.predict(X_valid)
            accuracy = accuracy_score(y_valid, y_pred)
            print(accuracy)

            score.append(clf.best_score_['valid_0']['multi_logloss'])

        return np.mean(score)

    # Seed the sampler so the search is reproducible, like the rest of the run.
    study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=42))
    study.optimize(objective, n_trials=100)
    print('Best Log_Score', study.best_value)
    print('Best_params', study.best_params)

    params = {
        'verbose': -1,
        'random_state': 42,
        'objective': 'multiclass',
        'boosting_type': 'gbdt',
        'metric': 'multi_logloss',
        'class_weight': {0: 2.0, 1: 2.0, 2: 1.0},
        'verbosity': 0,
        }

    params.update(study.best_params)
    np.save(f'output/lgb_best_param_fold{FOLD}', params)

    oof = np.zeros((len(data), 3))
    metric_evaluation = []

    for fold in range(3):
        X_train, y_train, X_valid, y_valid, valid_idx = split_fold(fold)

        clf = lgb.LGBMClassifier(
            **params, n_estimators=1000, force_row_wise=True, deterministic=True)

        # verbose=True keeps the best-iteration report that the original
        # (truthy) verbose=-1 produced for the final fits.
        callbacks = [lgb.early_stopping(stopping_rounds=20, verbose=True),
                    lgb.log_evaluation(period=300, show_stdv=False)]

        clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], callbacks=callbacks)
        metric_evaluation.append(clf.best_score_['valid_0']['multi_logloss'])
        oof[valid_idx] = clf.predict_proba(X_valid, num_iteration=clf.best_iteration_)

        # Context manager so the file is flushed and closed deterministically.
        with open(f'lgb_model/lgb_fold{FOLD}_{fold}', 'wb') as f:
            pickle.dump(clf, f)

    print(f'{np.mean(metric_evaluation)}({np.std(metric_evaluation)})')
    np.save(f'output/lgb_oof_fold{FOLD}', oof)
