In [None]:
import os
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.abspath('__file__')), '..'))
import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import spectrogram

from sklearn.linear_model import LinearRegression
from train_code.generator.raw_generator import transpose_raw1d,transpose_raw2d
from npy2trials import load_data
from sklearn.preprocessing import StandardScaler
from scipy.signal import butter, filtfilt, periodogram
from predict_by_another import predict,preprocess
from analyse1 import analyse1
# Electrode montage used throughout the notebook; plot grids below are 5 x 3,
# so at most 15 channels fit.
ch_list = [
    'FC5', 'FC1', 'FC2', 'FC6',
    'C3', 'C1', 'Cz', 'C2', 'C4',
    'CP5', 'CP1', 'CP2', 'CP6',
]
ch_size = len(ch_list)
# Windowing parameters forwarded to predict() (presumably in samples — TODO confirm).
block_size, step = 750, 250
# Sampling rate handed to load_data / preprocess / spectrogram (presumably Hz).
fs = 500

In [None]:
# Input recording and decoder model locations. Set the EEG_EVAL_DIR environment
# variable to point at another machine's folder instead of editing this cell;
# when unset, the original hard-coded location is used. Paths are joined with
# "/" on purpose because later cells split data_path on "/".
base_dir = os.environ.get("EEG_EVAL_DIR", "C:/Users/gomar/Dropbox/輸送").rstrip("/")
data_path = base_dir + "/test2.npy"
d1_model_path = base_dir + "/models/dec1/model-85.h5"
d2_model_path = base_dir + "/models/dec2/model-23.h5"

In [None]:
full_data = np.load(data_path)
stim_data, predictclass_list, trueclass_list = load_data(full_data, fs)

# Per-trial prediction sequences may differ in length; truncate all of them to
# the shortest one (presumably so analyse1 receives equal-length sequences).
# Original note: "TODO:15d1s2" — meaning unclear, confirm with the author.
if len(predictclass_list) == 0:
    raise ValueError(f"no prediction sequences returned by load_data for {data_path}")
minp = min(len(item) for item in predictclass_list)
predictclass_list = [item[:minp] for item in predictclass_list]
stim_data.shape

# データプレビュー

In [None]:
# Split trials by their true class and preprocess each one.
# Label convention (same as plot_epochs): 1 = left hand, 2 = right hand.
# Only the first ch_size rows of each trial are kept (presumably the EEG
# channels listed in ch_list — TODO confirm what the remaining rows hold).
left_data = []
right_data = []
unexpected = []
for i in range(stim_data.shape[0]):
    label = trueclass_list[i]
    if label == 1:
        left_data.append(preprocess(stim_data[i, :ch_size, :], fs))
    elif label == 2:
        right_data.append(preprocess(stim_data[i, :ch_size, :], fs))
    else:
        # Previously any label other than 1 was silently treated as "right".
        unexpected.append(i)
if unexpected:
    print(f"skipped {len(unexpected)} trial(s) with labels other than 1/2: {unexpected}")
left_data = np.array(left_data)
right_data = np.array(right_data)

def plot_spec(key: str, data, ft=None):
    """Draw one spectrogram panel per channel and return the spectrogram data.

    Parameters
    ----------
    key : str
        Prefix of the figure title ("<key> Spectrogram").
    data : ndarray
        If ``ft`` is None, indexed as ``data[trial, channel, :]``: every 1-D
        slice goes through ``scipy.signal.spectrogram`` at the global ``fs``
        and the results are summed over trials. Otherwise ``data`` is used
        as-is, indexed ``data[channel, freq, time]``.
    ft : (frequencies, times), optional
        Axes matching a precomputed ``data``.

    Returns
    -------
    (frequencies, times, specs)
        ``specs[channel]`` is the array plotted (before the dB conversion).
    """
    fig = plt.figure(figsize=(20, 12))
    plt.subplots_adjust(wspace=0.4, hspace=0.8)

    if ft is not None:
        frequencies, times = ft
        specs = data
    else:
        per_trial = []
        for trial in data:
            channel_specs = []
            for sig in trial:
                frequencies, times, sxx = spectrogram(sig, fs)
                channel_specs.append(sxx)
            per_trial.append(channel_specs)
        # Element-wise sum over the trial axis.
        specs = np.array(per_trial).sum(axis=0)

    # Spectrogram visualisation: 5 x 3 grid, one panel per channel, in dB.
    n_rows, n_cols = 5, 3
    for ch in range(ch_size):
        plt.subplot(n_rows, n_cols, ch + 1)
        power_db = 10 * np.log10(specs[ch, :, :])
        plt.pcolormesh(times, frequencies, power_db, vmax=16)
        plt.colorbar()
        plt.title(ch_list[ch])
        plt.ylim(0, 50)
        plt.grid()
    fig.suptitle(key + ' Spectrogram')
    plt.show()
    return frequencies, times, specs
# Per-class spectrograms (plot_spec returns them summed over trials).
frequencies, times, sl = plot_spec("Left", left_data)
_, _, sr = plot_spec("Right", right_data)

# Right-minus-left difference of the trial-AVERAGED spectrograms. Because
# plot_spec sums over trials, subtracting the raw sums made any imbalance in
# left/right trial counts show up as a spurious power difference.
n_left, n_right = len(left_data), len(right_data)
specs = sr / n_right - sl / n_left

row = 5
col = 3
fig = plt.figure(figsize=(10, 6))
plt.subplots_adjust(wspace=0.4, hspace=0.6)
# Spectrogram difference visualisation, one panel per channel (linear scale).
for i in range(ch_size):
    plt.subplot(row, col, i + 1)
    plt.pcolormesh(times, frequencies, specs[i, :, :], shading='gouraud')
    plt.colorbar()
    plt.title(ch_list[i])
    plt.ylim(0, 25)
    plt.grid()
fig.suptitle("Diff Spectrogram")
plt.show()

# 評価

In [None]:
# Evaluate the recorded predictions against the true labels. The third value
# (per-trial match sequences) is read by analyse2() through this global name.
data_metrics, data_detailed_metrics, mn_predictclass_list = analyse1(
    trueclass_list,
    predictclass_list,
)

In [None]:
def plot_lg(x, y, color):
    """Overlay an ordinary-least-squares fit of ``y`` on ``x`` on the current axes.

    Parameters
    ----------
    x : array-like, 2-D (n_samples, n_features)
        Regressors in the shape sklearn's ``LinearRegression.fit`` expects;
        callers pass a column vector of trial indices.
    y : array-like
        Targets, one per row of ``x``.
    color : matplotlib color
        Line colour for the fitted values.
    """
    fitted = LinearRegression().fit(x, y).predict(x)
    plt.plot(x, fitted, color=color, linewidth=0.5)
def plot_epochs(x, y, title):
    """Scatter a per-trial statistic by class and overlay linear trend lines.

    Trials are split using the global ``trueclass_list`` (1 = left hand,
    2 = right hand). A black line is the fit over all trials; each class gets
    its own fit drawn in that class's marker colour. A class with no trials is
    skipped (previously it raised IndexError).

    Parameters
    ----------
    x : ndarray, shape (n_trials, 1)
        Trial indices as a column vector (used for the all-trials fit).
    y : array-like, 1-D
        One value per trial, indexed by trial number.
    title : str
        Figure title.
    """
    y = np.asarray(y)
    class_fits = []
    for label, marker, name in ((1, "o", "left hand"), (2, "^", "right hand")):
        idx = np.array(
            [i for i, t in enumerate(trueclass_list) if t == label], dtype=int
        ).reshape(-1, 1)
        if idx.size == 0:
            continue
        vals = y[idx].reshape(-1, 1)
        points = plt.scatter(idx, vals, marker=marker, label=name)
        class_fits.append((idx, vals, points))
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    # Draw first so the scatter face colours are resolved before reading them.
    plt.draw()
    plot_lg(x, y, '#000000')
    for idx, vals, points in class_fits:
        plot_lg(idx, vals, points.get_facecolor())
    plt.title(title)
    plt.show()

def analyse2(mn_list=None):
    """Plot prediction/true-class agreement over time and per trial.

    Parameters
    ----------
    mn_list : sequence of per-trial sequences, optional
        Per-trial match values (the third value returned by ``analyse1``;
        plotted as "match 1or0", so presumably 0/1 per decision point).
        Defaults to the global ``mn_predictclass_list``, which is what the
        original zero-argument call used.

    Produces three figures: the mean +/- std across trials at each decision
    point (the std is also printed), then per-trial mean and per-trial std
    against trial index via ``plot_epochs``.
    """
    if mn_list is None:
        mn_list = mn_predictclass_list
    # Truncate every trial to the SHORTEST one. The old code used the first
    # trial's length, so a later, shorter trial produced a ragged list that
    # np.mean / np.std cannot reduce.
    n_points = min(len(m) for m in mn_list)
    matches = np.array([m[:n_points] for m in mn_list])

    # Mean/std across trials at each decision point.
    point_mean = matches.mean(axis=0)
    point_std = matches.std(axis=0)
    plt.errorbar(range(n_points), point_mean, yerr=point_std,
                 capsize=5, ecolor='orange')
    plt.title("match 1or0")
    plt.show()
    print(point_std)

    # Per-trial summaries against trial index.
    trial_idx = np.arange(len(matches)).reshape(-1, 1)
    plot_epochs(trial_idx, matches.mean(axis=-1), "Length of time matched (mean)")
    plot_epochs(trial_idx, matches.std(axis=-1), "Length of time matched (std)")

analyse2()

# 以下モデル評価
以下のセルでは `predictclass_list` がモデルの予測結果で上書きされる（元データの予測値は失われる）。

## デコーダー1

In [None]:
# Re-evaluate using decoder-1 predictions in place of the recorded ones.
predictclass_list = predict(d1_model_path, stim_data, fs, ch_size, block_size, step)
# analyse2() plots the global mn_predictclass_list, so rebind it to decoder 1's
# match list. Discarding it (as before) made analyse2() re-plot the
# recorded-data results under the decoder-1 heading.
d1_metrics, d1_detailed_metrics, mn_predictclass_list = analyse1(trueclass_list, predictclass_list)
analyse2()

# ログ書き込み

In [None]:
import csv

# Evaluation summaries are appended to CSVs in LOG_DIR. DATA_ROOT is stripped
# from data_path to build the identifying columns of each row.
LOG_DIR = "C:/MLA_Saves_Bk/evals/"
DATA_ROOT = "C:/MLA_Saves_Bk/"


def _path_id_cols(path, root=DATA_ROOT):
    """Return the first three '/'-separated components of ``path`` below ``root``.

    Backslashes are converted to '/' BEFORE the root prefix is removed, so
    backslash-separated Windows paths are handled too. The old code stripped
    the prefix first, and such paths kept their 'C:' component.
    NOTE(review): for a path outside ``root`` the columns are just its first
    three components (e.g. 'C:', 'Users', ...) — confirm this is intended.
    Raises IndexError if fewer than three components remain (as before).
    """
    norm = path.replace("\\", "/")
    if norm.startswith(root):
        norm = norm[len(root):]
    parts = norm.split("/")
    return [parts[0], parts[1], parts[2]]


def _append_row(log_path, cols):
    """Append ``cols`` as one '\\n'-terminated CSV row to ``log_path``."""
    with open(log_path, 'a') as f:
        csv.writer(f, lineterminator='\n').writerow(cols)


id_cols = _path_id_cols(data_path)
for d0, d1, apname in zip([data_metrics, data_detailed_metrics],
                          [d1_metrics, d1_detailed_metrics],
                          ["", "_detailed"]):
    # d0 = recorded-data metrics, d1 = decoder-1 metrics (both from analyse1).
    # For each, [0] and [1] are the "normal" and "fixed" variants; the inner
    # index selects the metric, where index 0 is accuracy.
    acc_cols = id_cols + [d0[0][0], d0[1][0], d1[0][0], d1[1][0]]
    _append_row(LOG_DIR + "output_acc" + apname + ".csv", acc_cols)

    # Metrics 1 and 2; a "/" column separates the groups for each metric.
    ex_cols = list(id_cols)
    for i in range(1, 3):
        ex_cols += [d0[0][i], d0[1][i], d1[0][i], d1[1][i], "/"]
    _append_row(LOG_DIR + "output_ex" + apname + ".csv", ex_cols)