In [27]:
from timm.scheduler import CosineLRScheduler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from scipy.signal import resample, butter, filtfilt
import numpy as np
from tqdm import tqdm
from termcolor import cprint
from sklearn.model_selection import KFold
import datetime
import os
import pytz
import gc


import os
import pickle

from src.preprocess import CosineScheduler
from src.preprocess import EarlyStopping
# from src.models import EEGNet

In [28]:
# EEGNetクラスの定義
class EEGNet(nn.Module):
    def __init__(self, num_classes, num_channels, samples):
        super(EEGNet, self).__init__()

        self.firstconv = nn.Sequential(
            nn.Conv2d(1, 16, (1, 51), stride=(1, 1), padding=(0, 25), bias=False),
            nn.BatchNorm2d(16)
        )

        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, (num_channels, 1), stride=(1, 1), groups=16, bias=False),
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.AvgPool2d((1, 4), stride=(1, 4)),
            nn.Dropout(0.25)
        )

        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, (1, 15), stride=(1, 1), padding=(0, 7), bias=False),
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.AvgPool2d((1, 8), stride=(1, 8)),
            nn.Dropout(0.25)
        )

        self.classify = nn.Sequential(
            nn.Linear(32 * (samples // 32), num_classes)
        )

    def forward(self, x):
        x = self.firstconv(x)
        x = self.depthwiseConv(x)
        x = self.separableConv(x)
        x = x.view(x.size(0), -1)
        x = self.classify(x)
        return x

# データの前処理関数
def resample_data(data, target_rate, current_rate):
    num_samples = int(data.shape[2] * target_rate / current_rate)
    resampled_data = resample(data, num_samples, axis=2)
    return resampled_data

def butter_bandpass(lowcut, highcut, fs, order=5):
    b, a = butter(order, [lowcut / (0.5 * fs), highcut / (0.5 * fs)], btype='band', output='ba')
    return b.astype(np.float32), a.astype(np.float32)  # 精度をfloat32に変更

def bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    filtered_data = filtfilt(b, a, data.astype(np.float32), axis=2)  # 精度をfloat32に変更
    return filtered_data

def standardize(data):
    mean = np.mean(data, axis=2, keepdims=True)
    std = np.std(data, axis=2, keepdims=True)
    standardized_data = (data - mean) / std
    return standardized_data

def preprocess_eeg_data(data, target_rate, current_rate, lowcut, highcut, fs):
    print("resample_data")
    data = resample_data(data, target_rate, current_rate)
    gc.collect()
    print("bandpass_filter")
    data = bandpass_filter(data, lowcut, highcut, fs)
    gc.collect()
    print("standardize")
    data = standardize(data)
    return data

def set_lr(lr, optimizer):
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

# cosine scheduler
class CosineScheduler:
    def __init__(self, epochs, lr, warmup_length=5):
        """
        Arguments
        ---------
        epochs : int
            学習のエポック数．
        lr : float
            学習率．
        warmup_length : int
            warmupを適用するエポック数．
        """
        self.epochs = epochs
        self.lr = lr
        self.warmup = warmup_length

    def __call__(self, epoch):
        """
        Arguments
        ---------
        epoch : int
            現在のエポック数．
        """
        progress = (epoch - self.warmup) / (self.epochs - self.warmup)
        progress = np.clip(progress, 0.0, 1.0)
        lr = self.lr * 0.5 * (1. + np.cos(np.pi * progress))

        if self.warmup:
            lr = lr * min(1., (epoch+1) / self.warmup)

        return lr

class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_acc_max = np.NINF
        self.delta = delta

    def __call__(self, val_acc, model):

        score = val_acc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_acc, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_acc, model)
            self.counter = 0

    def save_checkpoint(self, val_acc, model):
        if self.verbose:
            print(f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_acc_max = val_acc

In [29]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# dataset_filename = "train_val-1000.pt"
dataset_filename = "train_val.pt"
dataset_filename = "train_X.pt"
# dataset_filename = "combined.pt"
folder_path = 'data/'
# folder_path = 'data/'

In [30]:
# def process_in_batches(X, batch_size, preprocess_function, *args):
#     num_samples = X.shape[0]
#     for start_idx in range(0, num_samples, batch_size):
#         end_idx = min(start_idx + batch_size, num_samples)
#         yield preprocess_function(X[start_idx:end_idx], *args), start_idx, end_idx

In [31]:
# import os
# import datetime
# import pytz
# import torch
# import numpy as np
# from scipy.signal import resample, butter, filtfilt

# # 上記で提供された前処理関数をここに挿入

# def process_in_batches(X, batch_size, preprocess_function, *args):
#     num_samples = X.shape[0]
#     for start_idx in range(0, num_samples, batch_size):
#         end_idx = min(start_idx + batch_size, num_samples)
#         yield preprocess_function(X[start_idx:end_idx], *args), start_idx, end_idx

# # ベースとなるフォルダ名
# base_folder_name = "outputs"

# # 日本のタイムゾーンを取得
# jst = pytz.timezone('Asia/Tokyo')

# # 現在の日付と時間を取得（日本のローカルタイム）
# now = datetime.datetime.now(jst)

# # フォルダ名を生成（例：2023-06-20/23-59）
# save_folder_name = os.path.join(base_folder_name, now.strftime("%Y-%m-%d/%H-%M"))

# # フォルダを作成（存在しない場合）
# os.makedirs(save_folder_name, exist_ok=True)

# # 元データを読み出す
# data = torch.load(folder_path + dataset_filename)
# X, y = data

# batch_size = 256  # メモリに収まるサイズに応じて調整
# preprocessed_X = None

# for preprocessed_batch, start_idx, end_idx in tqdm(process_in_batches(X.numpy(), batch_size, preprocess_eeg_data, target_rate, current_rate, lowcut, highcut, fs), desc="Preprocess"):
#     if preprocessed_X is None:
#         preprocessed_X = preprocessed_batch
#     else:
#         preprocessed_X = np.concatenate([preprocessed_X, preprocessed_batch], axis=0)

# # PyTorchテンソルに変換し、保存
# X_tensor = torch.tensor(preprocessed_X, dtype=torch.float32).unsqueeze(1)  # (samples, 1, channels, timepoints)
# y_tensor = torch.tensor(y, dtype=torch.long)

# # 処理後のデータを保存
# torch.save((X_tensor, y_tensor), os.path.join(save_folder_name, 'preprocessed_data.pt'))

# print(X_tensor.shape)
# print(y_tensor.shape)

torch.Size([82160, 1, 271, 128])
torch.Size([82160])

In [32]:
# パラメータ
target_rate = 128
num_channels = 271  # チャンネル数
num_timepoints = target_rate  # タイムポイント数
fs = 128
current_rate = 281
num_classes = 1854  # クラス数
lr = 0.001
num_epochs = 50
batch_size=128
n_splits = 5

# デバイスの設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [33]:
import optuna

def objective(trial):

    print("Start objective", trial.number)
    

    lowcut = trial.suggest_uniform('lowcut', 0.01, fs / 2 - 1)
    highcut = trial.suggest_uniform('highcut', lowcut + 1, fs / 2)
    
    print(f"lowcut: {lowcut}, highcut: {highcut}")

    # ベースとなるフォルダ名
    base_folder_name = "outputs"

    # 日本のタイムゾーンを取得
    jst = pytz.timezone('Asia/Tokyo')

    # 現在の日付と時間を取得（日本のローカルタイム）
    now = datetime.datetime.now(jst)

    # フォルダ名を生成（例：06-20/23-59）
    save_folder_name = os.path.join(base_folder_name, now.strftime("%Y-%m-%d/%H-%M"))

    # フォルダを作成（存在しない場合）
    os.makedirs(save_folder_name, exist_ok=True)

    # 元データを読み出す
    # data = torch.load(folder_path+dataset_filename)
    # X, y = data
    # del data

    X = torch.load(os.path.join(folder_path, 'train_X.pt'))
    y = torch.load(os.path.join(folder_path, 'train_y.pt'))

    # 前処理を行うコード（ここに適切な前処理を行うコードを書く）
    preprocessed_data = preprocess_eeg_data(X.numpy(), target_rate, current_rate, lowcut, highcut, fs)
    del X

    # PyTorchテンソルに変換
    X = torch.tensor(preprocessed_data, dtype=torch.float32).unsqueeze(1)  # (samples, 1, channels, timepoints)
    y = torch.tensor(y, dtype=torch.long)

    print(X.shape)
    print(y.shape)

    dataset = TensorDataset(X, y)

    accuracy = Accuracy(
        task="multiclass", num_classes=num_classes, top_k=10
    ).to(device)

    max_val_acc = 0
    max_val_acc_list = []


    # KFoldによる交差検証の設定
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=3711)

    # 交差検証のループ
    for fold, (train_index, val_index) in enumerate(kf.split(dataset)):
        print(f"Fold {fold + 1}")

        # トレーニングデータと検証データに分割
        train_data = torch.utils.data.Subset(dataset, train_index)
        val_data = torch.utils.data.Subset(dataset, val_index)
        print(train_data.dataset[0])

        # データローダの定義
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

        # モデルの定義
        model = EEGNet(num_classes=num_classes, num_channels=num_channels, samples=target_rate).to(device)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        # optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.02)
        scheduler = CosineLRScheduler(optimizer, t_initial=100, lr_min=1e-6,
                                    warmup_t=3, warmup_lr_init=1e-6, warmup_prefix=True)
        early_stopping = EarlyStopping(patience=2, verbose=True)

        max_val_acc = 0  # 各Foldの最大検証精度を追跡するために、ループの内部で初期化

        # トレーニングループ
        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}")

            train_loss, train_acc, val_loss, val_acc = [], [], [], []

            model.train()
            for X, y in tqdm(train_loader, desc="Train"):
                X, y = X.to(device), y.to(device)

                y_pred = model(X)

                loss = F.cross_entropy(y_pred, y)
                train_loss.append(loss.item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                acc = accuracy(y_pred, y)
                train_acc.append(acc.item())

            model.eval()
            for X, y in tqdm(val_loader, desc="Validation"):
                X, y = X.to(device), y.to(device)

                with torch.no_grad():
                    y_pred = model(X)

                val_loss.append(F.cross_entropy(y_pred, y).item())
                val_acc.append(accuracy(y_pred, y).item())

            scheduler.step(np.mean(val_loss))  # 学習率を調整

            print(f"Epoch {epoch+1}/{num_epochs} | train loss: {np.mean(train_loss):.3f} | train acc: {np.mean(train_acc):.3f} | val loss: {np.mean(val_loss):.3f} | val acc: {np.mean(val_acc):.3f}")

            # 各Foldの最初のエポックでモデルを保存
            if epoch == 0:
                torch.save(model.state_dict(), os.path.join(save_folder_name, f"model_best_fold{fold+1}.pt"))
                print(f"Initial model for Fold {fold+1} saved.")

            val_acc_mean = np.mean(val_acc)
            if val_acc_mean > max_val_acc:
                cprint("New best.", "cyan")
                # Fold番号をファイル名に含める
                torch.save(model.state_dict(), os.path.join(save_folder_name, f"model_best_fold{fold+1}.pt"))
                max_val_acc = np.mean(val_acc)


            early_stopping(val_acc_mean, model)
            if early_stopping.early_stop:
                print("Early stopping. Max Acc = ", max_val_acc)
                break
        
        max_val_acc_list.append(max_val_acc)
        print("Max Acc = ", max_val_acc)

    mean_acc = np.mean(max_val_acc_list)
    print("Acc mean = ", mean_acc)

    return mean_acc

In [34]:
study = optuna.create_study(
    load_if_exists=True,
    direction="maximize",
    storage="sqlite:///db.sqlite3",  # Specify the storage URL here.
    study_name="bandpass_filter_study_5cv",
)
# study.enqueue_trial({'lowcut': 0.5, 'highcut': 50.0}, user_attrs={"memo": "baseline"})
# study.enqueue_trial({'lowcut': 0.1, 'highcut': 50.0}, user_attrs={"memo": "baseline+"})
# study.enqueue_trial({'lowcut': 0.5, 'highcut': 60.0}, user_attrs={"memo": "baseline+"})
# study.enqueue_trial({'lowcut': 0.1, 'highcut': 60.0}, user_attrs={"memo": "baseline+"})
# study.enqueue_trial({'lowcut': 0.5, 'highcut': 63.0}, user_attrs={"memo": "no filter"})
# study.enqueue_trial({'lowcut': 0.01, 'highcut': 60.0}, user_attrs={"memo": "no filter"})
# study.enqueue_trial({'lowcut': 0.01, 'highcut': 63.0}, user_attrs={"memo": "no filter"})
study.optimize(objective, n_trials=200)
print(study.best_trial)

[I 2024-07-17 22:17:55,156] Using an existing study with name 'bandpass_filter_study_5cv' instead of creating a new one.
  lowcut = trial.suggest_uniform('lowcut', 0.01, fs / 2 - 1)
  highcut = trial.suggest_uniform('highcut', lowcut + 1, fs / 2)


Start objective 3
lowcut: 58.907602980607194, highcut: 62.28420506503475
resample_data
bandpass_filter
standardize


  x = um.multiply(x, x, out=x)
  y = torch.tensor(y, dtype=torch.long)


torch.Size([65728, 1, 271, 128])
torch.Size([65728])
Fold 1
(tensor([[[0., -0., 0.,  ..., -0., -0., -0.],
         [0., -0., 0.,  ..., -0., -0., -0.],
         [0., -0., -0.,  ..., -0., -0., -0.],
         ...,
         [-0., 0., -0.,  ..., 0., 0., 0.],
         [-0., 0., -0.,  ..., 0., 0., 0.],
         [-0., 0., -0.,  ..., 0., 0., 0.]]]), tensor(1759))
Epoch 1/50


Train:  72%|███████▏  | 295/411 [00:28<00:11, 10.41it/s]
[W 2024-07-17 22:22:06,176] Trial 3 failed with parameters: {'lowcut': 58.907602980607194, 'highcut': 62.28420506503475} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "C:\Users\tshigata\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\optuna\study\_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\tshigata\AppData\Local\Temp\ipykernel_17332\2222125907.py", line 103, in objective
    acc = accuracy(y_pred, y)
  File "C:\Users\tshigata\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\tshigata\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\l

KeyboardInterrupt: 