In [117]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torchsummary import summary
import os, time
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split


In [118]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device

'cuda:0'

In [119]:
brainAdress = ['Fp1','AF7','AF3','F1','F3','F5','F7','FT7','FC5',
'FC3','FC1','C1','C3','C5','T7','TP7','CP5','CP3','CP1','P1','P3',
'P5','P7','P9','PO7','PO3','O1','Iz','Oz','POz','Pz','CPz','Fpz','Fp2',
'AF8','AF4','AFz','Fz','F2','F4','F6','F8','FT8','FC6','FC4','FC2',
'FCz','Cz','C2','C4','C6','T8','TP8','CP6','CP4','CP2','P2','P4',
'P6','P8','P10','PO8','PO4','O2']

numpy dataの読み込み

In [120]:
day1_data = np.load('numpy_data/subject3_data1.npy')
day1_label = np.load('numpy_data/subject3_label1.npy')
day2_data = np.load('numpy_data/subject3_data2.npy')
day2_label = np.load('numpy_data/subject3_label2.npy')

データの結合

In [121]:
brain_data = np.vstack([day1_data,day2_data])
label_data = np.hstack([day1_label,day2_label])
# info_motorbrain = ['FC3','FC1','C1','C3','C5','CP3','CP1','CPz','FC4','FC2',
# 'FCz','Cz','C2','C4','C6','CP4','CP2',]
# info_motor = mne.create_info(ch_names=info_motorbrain, ch_types="eeg", sfreq=1024)
# info_motor.set_montage('standard_1020')

In [122]:
#運動野領域のデータ　sampling rate は1024->100 Hz タスクは4秒間
motor_brainAdress = [9,10,11,12,13,17,18,31,44,45,46,47,48,49,50,54,55]
data_numpy_task  = brain_data[:,:,100*6:100*10]

In [123]:
print("brain_data",np.shape(data_numpy_task))
print("label",np.shape(label_data))

brain_data (240, 64, 400)
label (240,)


In [124]:
#train testの分割
X_train, X_test, Y_train, Y_test = train_test_split(data_numpy_task,label_data, test_size=0.2,random_state=42)
print(np.shape(X_train))

(192, 64, 400)


In [125]:
x_train = torch.from_numpy(np.expand_dims(X_train, axis=1))
X_test = torch.from_numpy(np.expand_dims(X_test, axis=1))
label = torch.tensor(Y_train)
label2 = torch.tensor(Y_test)



In [126]:
print(x_train.shape, label.shape)
print(X_test.shape, label2.shape)

torch.Size([192, 1, 64, 400]) torch.Size([192])
torch.Size([48, 1, 64, 400]) torch.Size([48])


In [127]:
# Datasetを作成
Dataset = torch.utils.data.TensorDataset(x_train, label)
# Datasetを作成
tast_data = torch.utils.data.TensorDataset(X_test, label2)

In [128]:
BATCH_SIZE = 32
Learning_Rate = 0.001
EPOCHS = 500

In [129]:
trainloader = torch.utils.data.DataLoader(Dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 1)
testloader = torch.utils.data.DataLoader(tast_data, batch_size = BATCH_SIZE, shuffle = False, num_workers = 1)


In [130]:
class EEGNet(nn.Module):
    def __init__(self):
        super(EEGNet, self).__init__()

        self.F1 = 8
        self.F2 = 16
        self.D = 2
        
        # Conv2d(in,out,kernel,stride,padding,bias)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, self.F1, (1, 64), padding=(0, 32), bias=False),
            nn.BatchNorm2d(self.F1)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(self.F1, self.D*self.F1, (17, 1), groups=self.F1, bias=False),
            nn.BatchNorm2d(self.D*self.F1),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(0.5)
        )
        self.Conv3 = nn.Sequential(
            nn.Conv2d(self.D*self.F1, self.D*self.F1, (1, 16), padding=(0, 8), groups=self.D*self.F1, bias=False),
            nn.Conv2d(self.D*self.F1, self.F2, (1, 1), bias=False),
            nn.BatchNorm2d(self.F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(0.5)
        )
        
        self.classifier = nn.Linear(16*17, 1, bias=True)
        
    def forward(self, x):
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.Conv3(x)
        
        x = x.view(-1, 16*17)
        x = self.classifier(x)
        return x

In [131]:
class Model(object):
    def __init__(self, model=None, lr=0.001):
        super(Model, self).__init__()
        self.model = model
        self.losses = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        
    def fit(self, trainloader=None, validloader=None, epochs=1, monitor=None, only_print_finish_ep_num=False):
        doValid = False if validloader == None else True
        pre_ck_point = [float("inf"), 0.0, float("inf"), 0.0, 0] # loss, acc, val_loss, val_acc, epoch
        history = {"loss": [], "acc": [], "val_loss": [], "val_acc": []}
        for ep in range(1, epochs + 1):
            proc_start = time.time() # timer start
            if (not (ep % 10)) or (ep == 1):
                if not only_print_finish_ep_num:
                    print(f"Epoch {ep}/{epochs}")
            self.model.train()       # Train mode
            step = 1                 # Restart step
            for x_batch, y_batch in trainloader:
                x_batch, y_batch = x_batch.to(device, dtype=torch.float), y_batch.to(device)
                pred = self.model(x_batch)
                loss = self.losses(pred, y_batch)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                if (not (ep % 10)) or (ep == 1):
                    pbar = int(step * 30 / len(trainloader))
                    if not only_print_finish_ep_num:
                        print("\r{}/{} [{}{}]".format(
                            step, len(trainloader), ">" * pbar, " " * (30 - pbar)), 
                            end="")
                step += 1
            loss, acc = self.evaluate(trainloader)   # Loss & Accuracy
            val_loss, val_acc = self.evaluate(validloader) if doValid else (0, 0)   # if have validation dataset, evaluate validation
            history["loss"] = np.append(history["loss"], loss)
            history["acc"] = np.append(history["acc"], acc)
            history["val_loss"] = np.append(history["val_loss"], val_loss)
            history["val_acc"] = np.append(history["val_acc"], val_acc)
            # Update checkpoint
            if self.__updateCheckpoint(monitor, pre_ck_point, [loss, acc, val_loss, val_acc, ep]):
                save_file_name = f"checkpoint_model_ep-{ep}.pt"
                self.save(save_file_name)
                pre_ck_point = [loss, acc, val_loss, val_acc, ep]
                history['lastest_model_path'] = save_file_name
                
            if only_print_finish_ep_num and (ep % 50 == 0):
                print(f"{ep} ", end=" ")
        return history
    def evaluate(self, dataloader):
        total, acc = 0, 0
        self.model.eval()           # Eval mode
        for x_batch, y_batch in dataloader:
            x_batch, y_batch = x_batch.to(device, dtype=torch.float), y_batch.to(device)
            pred = self.model(x_batch)
            loss = self.losses(pred, y_batch).item()
            total += y_batch.shape[0]     # Number of data
            acc += (torch.sum(pred.argmax(dim=1)==y_batch)).item()     # Sum the prediction that's correct
        acc /= total     # Accuracy = correct prediction / number of data
        return (loss, acc)
    
    def predict(self, dataset):
        dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
        prediction = []
        truth = []
        self.model.eval()
        for x_batch, y_batch in dataloader:
            x_batch, y_batch = x_batch.to(device, dtype=torch.float), y_batch.to(device)
            pred = self.model(x_batch).cpu()
            prediction = np.append(prediction, pred.argmax(dim=1).numpy())
            truth = np.append(truth, y_batch.cpu().numpy())            
        return prediction, truth
    
    def save(self, filepath):
        torch.save(self.model, filepath)
        #@classmethod
    def load(cls, filepath):
        return cls(torch.load(filepath))
    
    def __updateCheckpoint(self, monitor, pre_ck_point, evaluation):
        if type(monitor) is int:
            return True if evaluation[4] % monitor == 0 else False
        elif type(monitor) is list:
            for _ in monitor:
                if not _ in ["loss", "acc", "val_loss", "val_acc"]:
                    raise Exception(f"\"{_}\" is not a valid monitor condition.")
                elif _ == "loss" and pre_ck_point[0] <= evaluation[0]:
                    return False # present epoch loss > history loss
                elif _ == "acc" and pre_ck_point[1] >= evaluation[1]:
                    return False # present epoch acc <= history acc
                elif _ == "val_loss" and pre_ck_point[2] <= evaluation[2]:
                    return False # present epoch val_loss > history val_loss
                elif _ == "val_acc" and pre_ck_point[3] >= evaluation[3]:
                    return False # present epoch val_acc < history val_acc        
        return True

In [132]:
eegnet = EEGNet().to(device)
summary(eegnet, (1, 64, 400))
model = Model(eegnet, lr=Learning_Rate)
history = model.fit(trainloader=trainloader, validloader=testloader, epochs=EPOCHS, monitor=["acc", "val_acc"])

Epoch 1/500


RuntimeError: shape '[-1, 272]' is invalid for input of size 294912

In [None]:
def base_path(path: str):
    return os.path.join(BASE_DIR, path)

def plot_acc_and_loss(history, figsize=(10,4), base_save_path=None):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    if base_save_path:
        st = fig.suptitle(base_save_path, fontsize="x-large")
    
    ax1.title.set_text("Acc")
    ax1.set_xlabel("Epochs")
    l1 = ax1.plot(history["acc"], color="red", label='train')
    l2 = ax1.plot(history["val_acc"], color="blue", label='test')
    
    ax2.title.set_text("Loss")
    ax2.set_ylabel("Epochs")
    l3 = ax2.plot(history["loss"], color="red", label='train')
    l4 = ax2.plot(history["val_loss"], color="blue", label='test')

    ax1.legend(loc="upper right")
    ax2.legend(loc="upper right")

    if base_save_path:
        plt.savefig(base_path(base_save_path))
    plt.show()

In [None]:
BASE_DIR = './' #@param {type:"string"}
plot_acc_and_loss(history=history, base_save_path="part_1.png")