In [None]:
import random
import os
import warnings
from collections import defaultdict
import glob

import numpy as np
from sklearn.metrics import accuracy_score
from pymatreader import read_mat

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import timm

from tqdm import tqdm
# Registers tqdm's progress_* methods on pandas objects (pandas itself is not
# used in the visible cells).
tqdm.pandas()

# Silence every warning for the rest of the session.
warnings.simplefilter('ignore')
# Root folder that holds the per-subject training recordings (subject0..subject4).
data_dir = '../../../data/train'

# NOTE(review): not referenced in the visible cells; presumably the crop length
# in samples -- confirm before removing.
CROP_LEN_ = 250
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
def seed_everything(seed):
    """Seed every random number generator used here so runs are repeatable.

    Seeds Python's ``random``, NumPy and PyTorch (CPU generator plus the
    current CUDA device), exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic mode with its autotuner disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# For each subject, the session tags in the order they are tested against the
# file path.  The position of the first tag found in the path is the CV fold,
# so the three sessions are rotated across folds differently per subject.
_SESSION_ORDER_BY_SUBJECT = {
    0: ('train1', 'train2', 'train3'),
    3: ('train1', 'train2', 'train3'),
    1: ('train2', 'train3', 'train1'),
    4: ('train2', 'train3', 'train1'),
    2: ('train3', 'train1', 'train2'),
}


def _fold_for_file(subject, mat_path):
    """Return the CV fold (0-2) for a recording, or None if no session tag matches."""
    for fold, session in enumerate(_SESSION_ORDER_BY_SUBJECT.get(subject, ())):
        if session in mat_path:
            return fold
    return None


def get_data(data_dir):
    """Split the recordings into cross-validation folds and set up channel pairs.

    Every file under ``{data_dir}/subject{0..4}/`` is read with ``read_mat``.
    Each event in a file becomes one sample, identified by the key
    ``"<file path>_<start index>"``, and is assigned to a fold according to the
    subject and the session tag (``train1``/``train2``/``train3``) found in the
    file path.  Files whose path contains none of the tags contribute no
    samples.

    Args:
        data_dir: Directory containing the ``subject0`` .. ``subject4`` folders.

    Returns:
        A tuple ``(data_dict, label_dict, diff_list, use_ch_dict)``:
        ``data_dict`` maps fold -> list of sample keys, ``label_dict`` maps
        fold -> list of integer labels (last decimal digit of the event type,
        minus one), ``diff_list`` is the list of ``"chA_chB"`` channel pairs,
        and ``use_ch_dict`` maps each channel name used in ``diff_list`` to its
        row index in the recording's channel list.

    Raises:
        FileNotFoundError: If no recording is found under ``data_dir``.
    """
    data_dict = defaultdict(list)
    label_dict = defaultdict(list)
    ch_labels = None

    for subject in range(5):
        # NOTE(review): glob order is OS-dependent; it is left unsorted because
        # other saved per-fold outputs may rely on the existing sample order --
        # confirm before sorting.
        for mat_data in glob.glob(f'{data_dir}/subject{subject}/*'):
            data = read_mat(mat_data)
            # The channel layout of the last file read is the one used below.
            ch_labels = data['ch_labels']

            fold = _fold_for_file(subject, mat_data)
            if fold is None:
                continue

            # Start of each crop: 0.2 after the event's init_time, converted
            # with *1000 // 2 (presumably seconds -> samples at 500 Hz --
            # confirm).  The crop length itself is applied by the dataset.
            start_indexes = (data['event']['init_time'] + 0.2) * 1000 // 2
            labels = data['event']['type']

            for start_index, label in zip(start_indexes, labels):
                data_dict[fold].append(f'{mat_data}_{int(start_index)}')
                # Class id = last decimal digit of the event type, zero-based.
                label_dict[fold].append(int(str(int(label))[-1]) - 1)

    if ch_labels is None:
        raise FileNotFoundError(
            f'no recordings found under {data_dir}/subject[0-4]/'
        )

    ch_names = [c.replace(' ', '') for c in ch_labels]
    diff_list = [
        # horizontal direction
        'F3_F4',
        'FCz_FC1', 'FCz_FC2', 'FCz_FC3', 'FCz_FC4', 'FCz_FC5', 'FCz_FC6', 'FC1_FC2', 'FC3_FC4', 'FC5_FC6',
        'Cz_C1', 'Cz_C2', 'Cz_C3', 'Cz_C4', 'Cz_C5', 'Cz_C6', 'C1_C2', 'C3_C4', 'C5_C6',
        'CPz_CP1', 'CPz_CP2', 'CPz_CP3', 'CPz_CP4', 'CPz_CP5', 'CPz_CP6', 'CP1_CP2', 'CP3_CP4', 'CP5_CP6',
        'P3_P4',
        # vertical direction
        'Cz_FCz', 'C1_FC1', 'C2_FC2', 'C3_FC3', 'C4_FC4', 'C5_FC5', 'C6_FC6',
        'Cz_CPz', 'C1_CP1', 'C2_CP2', 'C3_CP3', 'C4_CP4', 'C5_CP5', 'C6_CP6',
        'FCz_CPz', 'FC1_CP1', 'FC2_CP2', 'FC3_CP3', 'FC4_CP4', 'FC5_CP5', 'FC6_CP6',
    ]

    # Every channel that appears on either side of a pair.
    use_ch = {ch for item in diff_list for ch in item.split('_')}
    use_ch_dict = {name: idx for idx, name in enumerate(ch_names) if name in use_ch}

    return data_dict, label_dict, diff_list, use_ch_dict


class SkateDataset(Dataset):
    """Preprocessing: in-memory dataset of robustly scaled EEG crops.

    Each entry of ``data_list`` is a key ``"<file path>_<start index>"``.  On
    construction every referenced file is read once with ``read_mat``, a
    ``crop_len``-sample window starting at the given index is cut from its
    ``'data'`` array, and each channel (row) of the window is scaled as
    ``(x - median) / IQR``.  A zero IQR is replaced by 1 to avoid dividing by
    zero.

    Args:
        data_list: Sample keys as described above.
        label_list: Labels, index-aligned with ``data_list``.
        diff_list: Channel-pair names; stored but not used by this class.
        use_ch_dict: Channel-name -> row-index map; stored but not used by
            this class.
    """

    def __init__(self, data_list, label_list, diff_list, use_ch_dict):
        self.label_list = label_list
        self.diff_list = diff_list
        self.use_ch_dict = use_ch_dict
        self.crop_len = 250

        keys = [self._split_key(item) for item in data_list]
        self.file_list = [path for path, _ in keys]
        self.index_list = [start for _, start in keys]

        # Read every distinct recording exactly once.
        recordings = {
            file_name: read_mat(file_name)['data']
            for file_name in set(self.file_list)
        }

        if recordings:
            n_channels = next(iter(recordings.values())).shape[0]
        else:
            n_channels = 72  # historical default; only matters for an empty dataset

        data = np.empty((len(self.file_list), n_channels, self.crop_len))
        samples = zip(self.file_list, self.index_list)
        for i, (file_name, idx) in tqdm(enumerate(samples), total=len(self.file_list)):
            start = int(idx)
            eeg_signal = recordings[file_name][:, start:start + self.crop_len]

            # Per-channel robust scaling over the time axis.
            median = np.median(eeg_signal, axis=1, keepdims=True)
            q1 = np.percentile(eeg_signal, 25, axis=1, keepdims=True)
            q3 = np.percentile(eeg_signal, 75, axis=1, keepdims=True)
            iqr = q3 - q1
            iqr = np.where(iqr == 0, 1, iqr)  # constant channel: avoid 0-division

            data[i] = (eeg_signal - median) / iqr
        self.data = data

    @staticmethod
    def _split_key(item):
        """Split ``"<path>_<start>"`` on the LAST underscore.

        Splitting from the right keeps paths that themselves contain
        underscores intact.  The start index is returned as a string.
        """
        path, start = item.rsplit('_', 1)
        return path, start

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(scaled crop, label)`` for the sample at ``index``."""
        label = self.label_list[index]
        data = self.data[index]

        return data, label
    

class EEG1DTemporal(nn.Module):
    """Convolutional front-end that turns an EEG crop into a one-channel 2D map.

    A convolution along the time axis (125 filters, kernel width
    ``num_samples // 2``, 'same' padding) is followed by a grouped convolution
    spanning ``num_channels`` rows (250 filters), each with batch norm.  After
    ELU, the first row of the result is taken, the 250 filter outputs become
    the height of a single-channel image, and 3 zeros are padded on every side.

    Args:
        num_channels: Height of the spatial kernel (number of EEG channels).
        num_samples: Number of time samples; sets the temporal kernel width.
    """

    def __init__(self, num_channels=50, num_samples=250):
        super().__init__()
        temporal_kernel = (1, num_samples // 2)
        spatial_kernel = (num_channels, 1)

        self.pad = nn.ZeroPad2d((3, 3, 3, 3))
        self.conv2d_freq = nn.Conv2d(1, 125, temporal_kernel, padding='same', bias=False)
        self.bn_freq = nn.BatchNorm2d(125)
        self.conv2d_depth = nn.Conv2d(
            125, 250, spatial_kernel, groups=125, bias=False, padding='valid'
        )
        self.bn_depth = nn.BatchNorm2d(250)
        self.elu = nn.ELU()

    def forward(self, x):
        """Map ``x`` (batch, channels, samples) to a padded single-channel image."""
        temporal = self.bn_freq(self.conv2d_freq(x.unsqueeze(1)))
        spatial = self.elu(self.bn_depth(self.conv2d_depth(temporal)))

        # Keep row 0 of the collapsed channel axis, then add an image-channel dim.
        image = spatial.select(2, 0).unsqueeze(1)
        return self.pad(image)
        
    
class EEG2DCNN(nn.Module):
    """timm image classifier that also returns its global-pool output.

    Wraps ``efficientvit_b1.r256_in1k`` (randomly initialised, 1 input
    channel, 3 classes).  ``forward`` returns both the class logits and the
    output of the backbone's global pooling layer, captured with a forward
    hook.
    """

    def __init__(self):
        super().__init__()
        self.model_name = 'efficientvit_b1'
        # Alternatives noted by the original author: drop_path_rate=0.1,
        # efficientnetv2_rw_s.ra2_in1k, efficientvit_b3.r256_in1k
        self.model = timm.create_model(
            'efficientvit_b1.r256_in1k', pretrained=False, in_chans=1, num_classes=3
        )

    def _pooling_module(self):
        """Return the global-pool layer whose output is exposed as features."""
        if 'tf_efficientnet' in self.model_name:
            return self.model.global_pool
        if 'efficientvit' in self.model_name:
            return self.model.head.global_pool
        raise ValueError(f'no pooling layer known for model {self.model_name!r}')

    def forward(self, x):
        """Return ``(logits, pooled_features)`` for the input batch ``x``."""
        self.features = []
        handle = self._pooling_module().register_forward_hook(self.hook_fn)
        try:
            x = self.model(x)
        finally:
            # Without this, every forward call would leave one more hook
            # attached, so hooks (and captured tensors) pile up over time.
            handle.remove()

        return x, self.features[0]

    def hook_fn(self, module, input, output):
        """Forward-hook callback: record the hooked layer's output."""
        self.features.append(output)
    

class StakeModel(nn.Module):
    """Full network: input dropout -> ``EEG1DTemporal`` -> ``EEG2DCNN``.

    ``forward`` returns whatever the 2D encoder returns, i.e. the pair
    ``(logits, pooled_features)``.
    """

    def __init__(self):
        super().__init__()
        self.oned_encoder = EEG1DTemporal(num_channels=72)
        self.twod_encoder = EEG2DCNN()
        self.first_dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Apply dropout to the raw input, then run both encoders in sequence."""
        image = self.oned_encoder(self.first_dropout(x))
        return self.twod_encoder(image)
    

def get_model():
    """Build and return a freshly initialised ``StakeModel``."""
    return StakeModel()

In [5]:
# Build the per-fold sample keys and labels once; the fold loop below reuses them.
data_dict, label_dict, diff_list, use_ch_dict = get_data(data_dir)

In [None]:
# For each CV fold, find the weight `alpha` that maximises the accuracy of
#   alpha * (neural-network probabilities) + (1 - alpha) * (LightGBM predictions)
# and save the three per-fold weights to disk.
seed_everything(42)
best_alpha_list = []
for FOLD in range(3):
    data = data_dict[FOLD]
    label = label_dict[FOLD]
    dataset = SkateDataset(data, label, diff_list, use_ch_dict)
    # shuffle=False keeps predictions in dataset order (presumably so they line
    # up row-by-row with the saved LightGBM predictions loaded below -- confirm).
    dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=False)

    # Load the checkpoint saved for this fold and switch to inference mode.
    model = get_model()
    load_weights = torch.load(f'model/model{FOLD}.pth', map_location=device)
    model.load_state_dict(load_weights)
    model = model.eval().to(device)

    y_pred_list, feature_list, label_list = [], [], []
    with tqdm(dataloader) as t:
        # NOTE: the loop variable `label` rebinds the name used for the fold's
        # label list above; the dataset keeps its own reference, so it is unaffected.
        for i, (X, label) in enumerate(t):
            X = X.to(device, non_blocking=True).float()
            label = label.to(device, non_blocking=True).long()

            with torch.no_grad():
                y_pred, feature = model(X)
            
            # Convert logits to class probabilities per sample.
            y_pred_list.append(nn.functional.softmax(y_pred, dim=1))
            # NOTE(review): features are collected but not used later in this cell.
            feature_list.append(feature)
            label_list.append(label)

    # Concatenate the per-batch results into one array per fold.
    y_pred_list = torch.vstack(y_pred_list)
    feature_list = torch.vstack(feature_list)
    label_list = torch.hstack(label_list).detach().cpu().numpy()

    nn_pred = y_pred_list.detach().cpu().numpy()
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only load files from a trusted source.
    lgb_pred = np.load(f'output/lgb_oof_fold{FOLD}.npy', allow_pickle=True)

    # Grid search over alpha in steps of 0.05 starting at 0.  Because the
    # comparison is strict, ties keep the smallest alpha, and best_alpha stays
    # at the sentinel -0.1 if no alpha scores above 0.
    base_score = 0
    best_alpha = -0.1
    for alpha in np.arange(0, 1.05, 0.05):
        alpha = round(alpha, 3)  # strip floating-point noise from the grid value
        preds_= alpha * nn_pred + (1 - alpha) * lgb_pred
        preds_ = np.argmax(preds_, axis=1)
        accuracy = accuracy_score(label_list, preds_)

        if accuracy > base_score:
            base_score = accuracy
            best_alpha = alpha
    # Prints "<best alpha> -> <best accuracy>" for this fold.
    print(f'best score: {best_alpha} -> {base_score:.5f}')
    best_alpha_list.append(best_alpha)

# np.save appends ".npy" when the file name has no such extension.
np.save('output/weight', np.array(best_alpha_list))


100%|██████████| 795/795 [00:01<00:00, 653.69it/s]
100%|██████████| 25/25 [00:01<00:00, 23.88it/s]


best score: 0.25 -> 0.88553


100%|██████████| 799/799 [00:01<00:00, 647.28it/s]
100%|██████████| 25/25 [00:00<00:00, 25.38it/s]


best score: 0.05 -> 0.86108


100%|██████████| 797/797 [00:01<00:00, 647.92it/s]
100%|██████████| 25/25 [00:01<00:00, 23.88it/s]

best score: 0.15 -> 0.86700



