In [None]:
import random
import os
import warnings
from collections import defaultdict
import glob

import numpy as np
from sklearn.metrics import accuracy_score
from pymatreader import read_mat

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
# Enable tqdm's pandas integration.
# NOTE(review): no pandas call appears in the visible cells — confirm this is still needed.
tqdm.pandas()

# Silence every warning for the rest of the session.
warnings.simplefilter('ignore')
# Root directory that get_data() scans for the per-subject training recordings.
data_dir = '../../../data/train'

# NOTE(review): not referenced by the visible cells; SkateDataset sets its own
# crop length of 250 — presumably the two are meant to stay in sync.
CROP_LEN_ = 250
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def seed_everything(seed):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (through both ``torch.manual_seed`` and ``torch.cuda.manual_seed``),
    stores the seed in the ``PYTHONHASHSEED`` environment variable, and puts
    cuDNN into deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for each library-level generator.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_data(data_dir):
    """
    Build the cross-validation folds, the labels and the differential-channel setup.

    Every file matched by ``{data_dir}/subject{0..4}/*`` is read with
    ``read_mat``. Each event in a file becomes one sample, identified by the
    key ``'<file path>_<start index>'``. The fold (0, 1 or 2) a file belongs to
    depends on the subject number and on which of ``train1`` / ``train2`` /
    ``train3`` appears in its path; files matching none of them contribute no
    samples.

    Args:
        data_dir: Directory containing the ``subject0`` ... ``subject4`` folders.

    Returns:
        tuple: ``(data_dict, label_dict, diff_list, use_ch_dict)`` where

        * ``data_dict[fold]`` is the list of sample keys of that fold,
        * ``label_dict[fold]`` is the parallel list of integer labels
          (last decimal digit of the event type, minus one),
        * ``diff_list`` is the list of ``'CH1_CH2'`` channel pairs whose
          difference is used as a model input (50 pairs),
        * ``use_ch_dict`` maps each channel name used in ``diff_list`` to its
          position in the recording's ``ch_labels`` (spaces removed).

    Raises:
        FileNotFoundError: If no file at all is found under ``data_dir``.
    """
    # (session token, fold) pairs per subject, tested in this order; the first
    # token contained in the file path decides the fold.
    rotation_a = (('train1', 0), ('train2', 1), ('train3', 2))
    rotation_b = (('train2', 0), ('train3', 1), ('train1', 2))
    rotation_c = (('train3', 0), ('train1', 1), ('train2', 2))
    fold_rules = {0: rotation_a, 3: rotation_a, 1: rotation_b, 4: rotation_b, 2: rotation_c}

    data_dict = defaultdict(list)
    label_dict = defaultdict(list)
    data = None  # last recording read; its channel labels are used below

    for subject in range(5):
        for mat_data in glob.glob(f'{data_dir}/subject{subject}/*'):
            data = read_mat(mat_data)

            fold = next((f for token, f in fold_rules[subject] if token in mat_data), None)
            if fold is None:
                continue

            # Start of each sample: 0.2 s after the event time, converted to an
            # index via * 1000 // 2 (presumably a 500 Hz recording — TODO confirm).
            start_indexes = (data['event']['init_time'] + 0.2)*1000 // 2
            labels = data['event']['type']

            for start_index, label in zip(start_indexes, labels):
                data_dict[fold].append(f'{mat_data}_{int(start_index)}')
                label_dict[fold].append(int(str(int(label))[-1])-1)

    if data is None:
        raise FileNotFoundError(
            f'no recordings found under {data_dir}/subject0 ... subject4'
        )

    ch_names = [c.replace(' ', '') for c in data['ch_labels']]
    diff_list = [
        # horizontal direction
        'F3_F4',
        'FCz_FC1', 'FCz_FC2', 'FCz_FC3', 'FCz_FC4', 'FCz_FC5', 'FCz_FC6', 'FC1_FC2', 'FC3_FC4', 'FC5_FC6',
        'Cz_C1', 'Cz_C2', 'Cz_C3', 'Cz_C4', 'Cz_C5', 'Cz_C6', 'C1_C2', 'C3_C4', 'C5_C6',
        'CPz_CP1', 'CPz_CP2', 'CPz_CP3', 'CPz_CP4', 'CPz_CP5', 'CPz_CP6', 'CP1_CP2', 'CP3_CP4', 'CP5_CP6',
        'P3_P4',
        # vertical direction
        'Cz_FCz', 'C1_FC1', 'C2_FC2', 'C3_FC3', 'C4_FC4', 'C5_FC5', 'C6_FC6',
        'Cz_CPz', 'C1_CP1', 'C2_CP2', 'C3_CP3', 'C4_CP4', 'C5_CP5', 'C6_CP6',
        'FCz_CPz', 'FC1_CP1', 'FC2_CP2', 'FC3_CP3', 'FC4_CP4', 'FC5_CP5', 'FC6_CP6',
    ]

    # Every channel that takes part in at least one pair.
    use_ch = set()
    for item in diff_list:
        parts = item.split('_')
        use_ch.add(parts[0])
        use_ch.add(parts[1])

    use_ch_dict = {name: idx for idx, name in enumerate(ch_names) if name in use_ch}

    return data_dict, label_dict, diff_list, use_ch_dict


class SkateDataset(Dataset):
    """
    Preprocessing: scaled, channel-differenced signals served as fixed-length crops.

    On construction every distinct recording referenced by ``data_list`` is
    read once, scaled with the fold's stored median / IQR, turned into one
    difference signal per entry of ``diff_list`` and cached in memory.
    Indexing then only slices the cached array.

    Args:
        fold: Fold number; selects ``iqr{fold}.npy`` / ``median{fold}.npy``
            under ``../../../data/scaler/``.
        data_list: Sample keys of the form ``'<file path>_<start index>'``.
        label_list: Labels, parallel to ``data_list``.
        diff_list: Channel pairs written as ``'CH1_CH2'``; each yields the
            signal ``CH1 - CH2``.
        use_ch_dict: Mapping from channel name to its row in the recording.
    """
    def __init__(self, fold, data_list, label_list, diff_list, use_ch_dict):
        self.label_list = label_list
        self.diff_list = diff_list
        self.use_ch_dict = use_ch_dict
        self.crop_len = 250

        # Split on the LAST underscore only: the start index is always the
        # final component, while the file path itself may contain underscores.
        split_items = [item.rsplit('_', 1) for item in data_list]
        self.file_list = [parts[0] for parts in split_items]
        self.index_list = [parts[1] for parts in split_items]

        # NOTE(review): allow_pickle=True can execute code from a crafted file;
        # only load scaler files that this project produced.
        self.iqr = np.load(f'../../../data/scaler/iqr{fold}.npy', allow_pickle=True)
        self.median = np.load(f'../../../data/scaler/median{fold}.npy', allow_pickle=True)
        # One value per row so the statistics broadcast along the time axis
        # (assumes recordings are (72, n_samples) — TODO confirm).
        self.iqr = self.iqr.reshape(72, 1)
        self.median = self.median.reshape(72, 1)

        self.data_dict = {}
        file_name_list = list(set(self.file_list))

        for file_name in tqdm(file_name_list):
            data = read_mat(file_name)['data']
            data = (data - self.median) / self.iqr

            eeg_signal = []
            for channels in self.diff_list:
                names = channels.split('_')
                ch1 = data[self.use_ch_dict[names[0]]]
                ch2 = data[self.use_ch_dict[names[1]]]
                eeg_signal.append(ch1 - ch2)

            # Rows follow the order of diff_list.
            self.data_dict[file_name] = np.stack(eeg_signal)

    def __len__(self):
        """Number of samples (one per entry of ``data_list``)."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Return ``(crop, label)``: ``crop_len`` columns starting at the sample's start index."""
        file_name = self.file_list[index]
        idx = int(self.index_list[index])
        label = self.label_list[index]

        data = self.data_dict[file_name]
        eeg_signal = data[:, idx:idx+self.crop_len]

        return eeg_signal, label
    

class ResidualBlock(nn.Module):
    """
    1-D residual block: two conv + batch-norm stages plus a 1x1-conv shortcut.

    The main branch is conv -> BN -> GELU -> dropout -> conv -> BN. The
    shortcut is a kernel-size-1 convolution followed by batch norm. Both are
    summed, passed through GELU, max-pooled (kernel 2, stride 2) and dropped
    out again with the same dropout module.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        kernel_size: Kernel size of both main-branch convolutions.
        stride: Stride of the first convolution and of the shortcut convolution.
        padding: Padding of both main-branch convolutions.
        dropout: Dropout probability.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dropout=0.2):
        super().__init__()

        # Main branch (attribute names are part of the checkpoint format).
        self.conv1 = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(
            out_channels, out_channels,
            kernel_size=kernel_size, stride=1, padding=padding,
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Shortcut branch: matches the channel count of the main branch.
        self.residual = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn_residual = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Apply the block to ``x`` and return the pooled activation."""
        shortcut = self.bn_residual(self.residual(x))

        main = self.dropout(self.gelu(self.bn1(self.conv1(x))))
        main = self.bn2(self.conv2(main))

        merged = self.gelu(main + shortcut)
        return self.dropout(self.pool(merged))
    

class ConvEncoder(nn.Module):
    """
    Stack of four ``ResidualBlock`` stages widening the channels to 512.

    The stages map ``in_channels -> 64 -> 128 -> 256 -> 512`` and use the
    default ``ResidualBlock`` settings.

    Args:
        in_channels: Channels of the input tensor.
    """
    def __init__(self, in_channels):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)
        stages = [
            ResidualBlock(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ]
        self.conv_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through all four stages in order."""
        encoded = self.conv_layers(x)
        return encoded
    

class StakeModel(nn.Module):
    """
    Classifier: input dropout -> ``ConvEncoder`` -> average pooling -> linear head.

    Args:
        encoder: Accepted for compatibility but not used; a new ``ConvEncoder``
            is always created.
        in_channels: Channels of the input tensor (default 50).

    ``forward`` returns both the 3-way logits and the 512-dimensional pooled
    feature vector the logits were computed from.
    """
    def __init__(self, encoder=None, in_channels=50):
        super().__init__()
        self.first_dropout = nn.Dropout(p=0.4)
        self.encoder = ConvEncoder(in_channels)  # the `encoder` argument is ignored
        self.pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.classifier = nn.Sequential(
            nn.Linear(in_features=512, out_features=256, bias=False),
            nn.Dropout(p=0.4),
            nn.Linear(in_features=256, out_features=3, bias=False),
        )

    def forward(self, x):
        """Return ``(logits, features)`` for the batch ``x``."""
        pooled = self.pool(self.encoder(self.first_dropout(x)))

        # Drop the length-1 time axis left by the adaptive pooling.
        n_samples, n_features, _ = pooled.shape
        features = pooled.reshape(n_samples, n_features)

        logits = self.classifier(features)
        return logits, features


def get_model():
    """Create a new ``StakeModel`` using its default constructor arguments."""
    return StakeModel()

In [3]:
# Collect, per fold, the sample keys and labels, together with the list of
# differential channel pairs and the channel-name -> row-index lookup.
data_dict, label_dict, diff_list, use_ch_dict = get_data(data_dir)

In [None]:
# For each fold: run the saved network on the fold's samples, then pick the
# blending weight `alpha` (NN share; LightGBM gets 1 - alpha) with the best accuracy.
seed_everything(42)
best_alpha_list = []

for FOLD in range(3):
    dataset = SkateDataset(FOLD, data_dict[FOLD], label_dict[FOLD], diff_list, use_ch_dict)
    # No shuffling: predictions must stay in the dataset's sample order.
    dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=False)

    model = get_model()
    model.load_state_dict(torch.load(f'model/model{FOLD}.pth', map_location=device))
    model = model.eval().to(device)

    y_pred_list, feature_list, label_list = [], [], []
    for X, label in tqdm(dataloader):
        X = X.to(device, non_blocking=True).float()
        label = label.to(device, non_blocking=True).long()

        with torch.no_grad():
            y_pred, feature = model(X)

        y_pred_list.append(nn.functional.softmax(y_pred, dim=1))
        feature_list.append(feature)
        label_list.append(label)

    y_pred_list = torch.vstack(y_pred_list)
    feature_list = torch.vstack(feature_list)
    label_list = torch.hstack(label_list).detach().cpu().numpy()

    nn_pred = y_pred_list.detach().cpu().numpy()
    lgb_pred = np.load(f'output/lgb_oof_fold{FOLD}.npy', allow_pickle=True)

    # Candidate (alpha, accuracy) pairs. The leading (-0.1, 0) entry is the
    # fallback that wins only when no alpha scores above zero; max() keeps the
    # first of equally good entries, i.e. the smallest alpha.
    candidates = [(-0.1, 0)]
    for alpha in np.arange(0, 1.05, 0.05):
        alpha = round(alpha, 3)
        blended = alpha * nn_pred + (1 - alpha) * lgb_pred
        accuracy = accuracy_score(label_list, np.argmax(blended, axis=1))
        candidates.append((alpha, accuracy))

    best_alpha, base_score = max(candidates, key=lambda pair: pair[1])
    print(f'best score: {best_alpha} -> {base_score:.5f}')
    best_alpha_list.append(best_alpha)

np.save('output/weight', np.array(best_alpha_list))
