In [5]:
# -*- coding: utf-8 -*-

import numpy as np
from tqdm import tqdm
import csv
import os
import shutil
import pickle
import random
from sklearn.metrics import mean_squared_error

from numba.experimental import jitclass
from numba import f8, i8, jit


def lif(currents, th, time: int, lif_time: float = 1.0, rest=-65, ref=3, tc_decay=100):
    """ simple LIF neuron """
    time = int(time / lif_time)

    # initialize
    tlast = 0  # 最後に発火した時刻
    vpeak = 20  # 膜電位のピーク(最大値)
    spikes = np.zeros(time)
    v = rest  # 静止膜電位
    
    monitor = []  # monitor voltage

    # Core of LIF
    # 微分方程式をコーディングするときは，このように時間分解能dtで離散的に計算することが多い
    for t in range(time):
        dv = ((lif_time * t) > (tlast + ref)) * (-v + rest + currents[t]) / tc_decay  # 微小膜電位増加量
        v = v + lif_time * dv  # 膜電位を計算

        tlast = tlast + (lif_time * t - tlast) * (v >= th)  # 発火したら発火時刻を記録
        v = v + (vpeak - v) * (v >= th)  # 発火したら膜電位をピークへ

        monitor.append(v)

        spikes[t] = (v >= th) * 1  # スパイクをセット

        v = v + (rest - v) * (v >= th)  # 静止膜電位に戻す

    return spikes


def lif3(currents, th, time: int, lif_time: float = 1.0, rest=-65, ref=3, tc_decay=100):
    """ Reverse LIF neuron """
    time = int(time / lif_time)

    # initialize
    tlast = 0  # 最後に発火した時刻
    vpeak = 20  # 膜電位のピーク(最大値)
    spikes = np.zeros(time)
    v = rest  # 静止膜電位
    max_current = 1000  # 最大電流の調整パラメータ
    currents_new = max_current / (np.abs(currents) + 0.000000001) # 0除算を防ぐための小さい値

    monitor = []  
    for t in range(time):
        dv = ((lif_time * t) > (tlast + ref)) * (-v + rest + currents_new[t]) / tc_decay
          # 微小膜電位増加量
        v = v + lif_time * dv  # 膜電位を計算

        tlast = tlast + (lif_time * t - tlast) * (v >= th)  # 発火したら発火時刻を記録
        v = v + (vpeak - v) * (v >= th)  # 発火したら膜電位をピークへ

        monitor.append(v)

        spikes[t] = (v >= th) * 1  # スパイクをセット

        v = v + (rest - v) * (v >= th)  # 静止膜電位に戻す

    return spikes


def poisson1(currents, time: int, dt: float = 1.0):
    """ Poisson encoding """
    # スパイクレートの計算（正規化なし）
    time = int(time / dt)
    spikes = np.zeros(time)
    spike_rates = np.abs(currents)  # スパイクレートを直接振幅に依存させる

# スパイク列を生成
    for i, rate in enumerate(spike_rates):
        if np.random.poisson(rate * dt / 1000) > 1:
            spikes[i] = 1
    

    return spikes


def poisson2(currents, time: int, dt: float = 1.0):
    """ reverse poisson encoding """
    # スパイクレートの計算（正規化なし）
    time = int(time / dt)
    spikes = np.zeros(time)
    spike_rates = np.abs(currents)
    max_spike_rate = 500  # 最大スパイクレートの設定
    spike_rates = max_spike_rate / (spike_rates + 0.00000001)  # スパイクレートを直接振幅に依存させる

# スパイク列を生成
    for i, rate in enumerate(spike_rates):
        if np.random.poisson(rate * dt / 1000) > 0.2:
            spikes[i] = 1

    return spikes


def poisson3(currents, time: int, dt: float = 1.0):
    """ reverse poisson encoding """
    # スパイクレートの計算（正規化なし）
    time = int(time / dt)
    spikes = np.zeros(time)
    spike_rates = np.abs(currents)
    max_spike_rate = 100000  # 最大スパイクレートの設定
    spike_rates = max_spike_rate / (spike_rates + 0.000001)  # スパイクレートを直接振幅に依存させる

# スパイク列を生成
    for i, rate in enumerate(spike_rates):
        if np.random.poisson(rate * dt / 100000) > 0.5:
            spikes[i] = 1
    

    return spikes


@jitclass([('N',i8),('dt',f8),('td',f8),('r',f8[:])])
class SingleExponentialSynapse:
    def __init__(self, N, dt=1e-4, td=5e-3):
        """
        Args:
            td (float):Synaptic decay time
        """
        self.N = N
        self.dt = dt
        self.td = td
        self.r = np.zeros(N)

    def initialize_states(self):
        self.r = np.zeros(self.N)

    def func(self, spike):
        r = self.r*(1-self.dt/self.td) + spike/self.td
        self.r = r
        return r

    #def __call__(self, spike):
    #    r = self.r*(1-self.dt/self.td) + spike/self.td
    #    self.r = r
    #    return r

@jitclass([('N',i8),('dt',f8),('tref',f8),('tc_m',f8),('vrest',i8),('vreset',i8),('vthr',i8),('vpeak',i8),('e_exc',i8),('e_inh',i8),('v',f8[:]),('v_',f8[:]),('tlast',f8),('tcount',i8)])
class ConductanceBasedLIF:
    def __init__(self, N, dt=1e-4, tref=5e-3, tc_m=1e-2,
                 vrest=-60, vreset=-60, vthr=-50, vpeak=20,
                 e_exc=0, e_inh=-100):
        """
        Conductance-based Leaky integrate-and-fire model.

        Args:
            N (int)       : Number of neurons.
            dt (float)    : Simulation time step in seconds.
            tc_m (float)  : Membrane time constant in seconds.
            tref (float)  : Refractory time constant in seconds.
            vreset (float): Reset membrane potential (mV).
            vrest (float) : Resting membrane potential (mV).
            vthr (float)  : Threshold membrane potential (mV).
            vpeak (float) : Peak membrane potential (mV).
            e_exc (float) : equilibrium potential of excitatory synapses (mV).
            e_inh (float) : equilibrium potential of inhibitory synapses (mV).
        """
        self.N = N
        self.dt = dt
        self.tref = tref
        self.tc_m = tc_m
        self.vrest = vrest
        self.vreset = vreset
        self.vthr = vthr
        self.vpeak = vpeak

        self.e_exc = e_exc # 興奮性シナプスの平衡電位
        self.e_inh = e_inh # 抑制性シナプスの平衡電位

        self.v = self.vreset*np.ones(N)
        self.v_ = np.zeros(N)  #change
        self.tlast = 0
        self.tcount = 0

    def initialize_states(self, random_state=False):
        if random_state:
            self.v = self.vreset + np.random.rand(self.N)*(self.vthr-self.vreset)
        else:
            self.v = self.vreset*np.ones(self.N)
        self.tlast = 0
        self.tcount = 0

    def func(self, g_exc, g_inh):
        I_synExc = g_exc*(self.e_exc - self.v)
        I_synInh = g_inh*(self.e_inh - self.v)
        dv = (self.vrest - self.v + I_synExc + I_synInh) / self.tc_m #Voltage equation with refractory period
        v = self.v + ((self.dt*self.tcount) > (self.tlast + self.tref))*dv*self.dt

        s = 1*(v>=self.vthr) #発火時は1, その他は0の出力
        self.tlast = self.tlast*(1-s) + self.dt*self.tcount*s #最後の発火時の更新
        v = v*(1-s) + self.vpeak*s #閾値を超えると膜電位をvpeakにする
        self.v_ = v #発火時の電位も含めて記録するための変数
        self.v = v*(1-s) + self.vreset*s  #発火時に膜電位をリセット
        self.tcount += 1

        return s


@jitclass([('N',i8),('dt',f8),('tref',f8),('tc_m',f8),('vrest',i8),('vreset',i8),('vrest',i8),('vreset',i8),('init_vthr',i8),('vpeak',i8),('theta_plus',f8),('theta_max',i8),('tc_theta',f8),('vthr',i8),('vpeak',i8),('e_exc',i8),('e_inh',i8),('v',f8[:]),('theta',f8[:]),('v_',f8[:]),('tlast',f8),('tcount',i8)])
class DiehlAndCook2015LIF:
    def __init__(self, N, dt=1e-3, tref=5e-3, tc_m=1e-1,
                 vrest=-65, vreset=-65, init_vthr=-52, vpeak=20,
                 theta_plus=0.025, theta_max=35, tc_theta=1e4, e_exc=0, e_inh=-100):
        """
        Leaky integrate-and-fire model of Diehl and Cooks (2015)
        https://www.frontiersin.org/articles/10.3389/fncom.2015.00099/full

        Args:
            N (int)       : Number of neurons.
            dt (float)    : Simulation time step in seconds.
            tc_m (float)  : Membrane time constant in seconds.
            tref (float)  : Refractory time constant in seconds.
            vreset (float): Reset membrane potential (mV).
            vrest (float) : Resting membrane potential (mV).
            vthr (float)  : Threshold membrane potential (mV).
            vpeak (float) : Peak membrane potential (mV).
            e_exc (float) : equilibrium potential of excitatory synapses (mV).
            e_inh (float) : equilibrium potential of inhibitory synapses (mV).
        """
        self.N = N
        self.dt = dt
        self.tref = tref
        self.tc_m = tc_m
        self.vreset = vreset
        self.vrest = vrest
        self.init_vthr = init_vthr
        self.theta = np.zeros(N)
        self.theta_plus = theta_plus
        self.theta_max = theta_max
        self.tc_theta = tc_theta
        self.vpeak = vpeak

        self.e_exc = e_exc # 興奮性シナプスの平衡電位
        self.e_inh = e_inh # 抑制性シナプスの平衡電位

        self.v = self.vreset*np.ones(N)
        self.vthr = self.init_vthr
        self.v_ = np.zeros(N)  #change
        self.tlast = 0
        self.tcount = 0

    def initialize_states(self, random_state=False):
        if random_state:
            self.v = self.vreset + np.random.rand(self.N)*(self.vthr-self.vreset)
        else:
            self.v = self.vreset*np.ones(self.N)
        self.vthr = self.init_vthr
        self.theta = np.zeros(self.N)
        self.tlast = 0
        self.tcount = 0

    def func(self, g_exc, g_inh):
        I_synExc = g_exc*(self.e_exc - self.v)
        I_synInh = g_inh*(self.e_inh - self.v)
        dv = (self.vrest - self.v + I_synExc + I_synInh) / self.tc_m #Voltage equation with refractory period
        v = self.v + ((self.dt*self.tcount) > (self.tlast + self.tref))*dv*self.dt

        s = 1*(v>=self.vthr) #発火時は1, その他は0の出力
        theta = (1-self.dt/self.tc_theta)*self.theta + self.theta_plus*s
        self.theta = np.clip(theta, 0, self.theta_max)
        self.vthr = self.theta + self.init_vthr
        self.tlast = self.tlast*(1-s) + self.dt*self.tcount*s #最後の発火時の更新
        v = v*(1-s) + self.vpeak*s #閾値を超えると膜電位をvpeakにする
        self.v_ = v #発火時の電位も含めて記録するための変数
        self.v = v*(1-s) + self.vreset*s  #発火時に膜電位をリセット
        self.tcount += 1

        return s

#@jitclass([('W',f8[:,:])])
class FullConnection:
    def __init__(self, N_in, N_out, initW=None):
        """
        FullConnection
        """
        if initW is not None:
            self.W = initW
        else:
            self.W = 0.1*np.random.rand(N_out, N_in)

    def backward(self, x):
        return np.dot(self.W.T, x) #self.W.T @ x

    def func(self, x):
        return np.dot(self.W, x) #self.W @ x


@jitclass([('N',i8),('nt_delay',i8),('state',f8[:,:])])
class DelayConnection:
    def __init__(self, N, delay, dt=1e-4):
        """
        Args:
            delay (float): Delay time
        """
        self.N = N
        self.nt_delay = round(delay/dt) # 遅延のステップ数
        self.state = np.zeros((N, self.nt_delay))

    def initialize_states(self):
        self.state = np.zeros((self.N, self.nt_delay))

    def func(self, x):
        out = self.state[:, -1] # 出力

        self.state[:, 1:] = self.state[:, :-1] # 配列をずらす
        self.state[:, 0] = x # 入力

        return out

np.random.seed(seed=0)

para_dic= {'w_exc': 4.544910490280396, 'w_inh': 1.3082028299035284, 'lr1': 0.0312611326044722, 'lr2': 0.03592142258355323, 'Norm': 0.22856140170464153}
para_list=list(para_dic.values())

k=7


# ラベルの割り当て
def assign_labels(spikes, labels, unique_labels, rates=None, alpha=1.0):
    """
    Assign labels to the neurons based on highest average spiking activity.

    Args:
        spikes (n_samples, n_neurons) : A single layer's spiking activity.
        labels (n_samples,) : Data labels corresponding to input samples. 480labelで1,2,3,..のどれか
        n_labels (int)      : The number of target labels in the data.  
        rates (n_neurons, n_labels) : If passed, these represent spike rates
                                      from a previous ``assign_labels()`` call.
        alpha (float): Rate of decay of label assignments.
    return: Class assignments, per-class spike proportions, and per-class firing rates.
    """
    n_neurons = spikes.shape[1]
    print(n_neurons)
    print(spikes)

    if rates is None:
        n_labels = len(unique_labels)
        rates = np.zeros((n_neurons, n_labels)).astype(np.float32)

    # 時間の軸でスパイク数の和を取る
    for i, label in enumerate(unique_labels):
        # サンプル内の同じラベルの数を求める
        n_labeled = np.sum(labels == label).astype(np.int16)
        

        if n_labeled > 0:
            # label == iのサンプルのインデックスを取得
            indices = np.where(labels == label)[0]
            print(indices)
            print(spikes[indices])

            # label == iに対する各ニューロンごとの平均発火率を計算(前回の発火率との移動平均)
            rates[:, i] = alpha*rates[:, i] + (np.sum(spikes[indices], axis=0)/n_labeled)

    sum_rate = np.sum(rates, axis=1)
    sum_rate[sum_rate==0] = 1
    # クラスごとの発火頻度の割合を計算する
    proportions = rates / np.expand_dims(sum_rate, 1) # (n_neurons, n_labels)
    proportions[proportions != proportions] = 0  # Set NaNs to 0

    # 最も発火率が高いラベルを各ニューロンに割り当てる
    mapped_labels = np.argmax(proportions, axis=1).astype(np.uint8) # (n_neurons,)
    assignments = np.array(unique_labels)[mapped_labels]
    return assignments, proportions, rates

# assign_labelsで割り当てたラベルからサンプルのラベルの予測をする
def prediction(spikes, assignments, unique_labels):
    """
    Classify data with the label with highest average spiking activity over all neurons.

    Args:
        spikes  (n_samples, n_neurons) : A layer's spiking activity.
        assignments (n_neurons,) : Neuron label assignments.
        n_labels (int): The number of target labels in the data.
    return: Predictions (n_samples,)
    """

    n_samples = spikes.shape[0]
    print(n_samples)
    n_labels = len(unique_labels)

    # 各サンプルについて各ラベルの発火率を見る
    rates = np.zeros((n_samples, n_labels)).astype(np.float32)

    for i, label in enumerate(unique_labels):
        # 各ラベルが振り分けられたニューロンの数
        print(assignments)
        n_assigns = np.sum(assignments == label).astype(np.uint8)
        print(n_assigns)

        if n_assigns > 0:
            # 各ラベルのニューロンのインデックスを取得
            indices = np.where(assignments == label)[0]
            print(indices)

            # 各ラベルのニューロンのレイヤー全体における平均発火数を求める
            rates[:, i] = np.sum(spikes[:, indices], axis=1) / n_assigns
    print(rates)
    predicted_indices = np.argmax(rates, axis=1)
    max_rates = np.max(rates, axis=1)
    for i in range(n_samples):
        if np.sum(rates[i, :] == max_rates[i]) > 1:
            predicted_indices[i] = 2  # 複数のラベルが同じ最大値を持つ場合は2に設定
    print(predicted_indices)
    predicted_labels = np.array(correct_lables)[predicted_indices]

    # レイヤーの平均発火率が最も高いラベルを出力
    return predicted_labels.astype(np.uint8)# (n_samples, )


#################
####  Model  ####
#################

class DiehlAndCook2015Network:
    #exc_neurons: exc_neurons_type
    #inh_neurons: inh_neurons_type
    #input_synapse:input_synapse_type
    #exc_synapse:exc_synapse_type
    #inh_synapse: inh_synapse_type
    #input_synaptictrace:  input_synaptictrace_type
    #exc_synaptictrace:exc_synaptictrace_type
    #input_conn:input_conn_type 
    #delay_input :delay_input_type 
    #delay_exc2inh:delay_exc2inh_type
    
    def __init__(self, n_in=4, n_neurons=100, wexc=2.25, winh=0.875,
                 dt=1e-3, wmin=0.0, wmax=5e-2, lr=(1e-2, 1e-4),
                 update_nt=100, norm = 0.1):
        """
        Network of Diehl and Cooks (2015)
        https://www.frontiersin.org/articles/10.3389/fncom.2015.00099/full

        Args:
            n_in: Number of input neurons. Matches the 1D size of the input data.
            n_neurons: Number of excitatory, inhibitory neurons.
            wexc: Strength of synapse weights from excitatory to inhibitory layer.
            winh: Strength of synapse weights from inhibitory to excitatory layer.
            dt: Simulation time step.
            lr: Single or pair of learning rates for pre- and post-synaptic events, respectively.
            wmin: Minimum allowed weight on input to excitatory synapses.
            wmax: Maximum allowed weight on input to excitatory synapses.
            update_nt: Number of time steps of weight updates.
        """

        self.dt = dt
        self.lr_p, self.lr_m = lr
        self.wmax = wmax
        self.wmin = wmin
        print(self.lr_p)
        print(self.lr_m)
        print(self.wmin)
        print(self.wmax)


        # Neurons
        self.exc_neurons = DiehlAndCook2015LIF(n_neurons, dt=dt, tref=5e-3,
                                               tc_m=1e-1,
                                               vrest=-65, vreset=-65,
                                               init_vthr=-52,
                                               vpeak=20, theta_plus=0.05,
                                               theta_max=35,
                                               tc_theta=1e4,
                                               e_exc=0, e_inh=-100)

        self.inh_neurons = ConductanceBasedLIF(n_neurons, dt=dt, tref=2e-3,
                                               tc_m=1e-2,
                                               vrest=-60, vreset=-45,
                                               vthr=-40, vpeak=20,
                                               e_exc=0, e_inh=-85)
        # Synapses
        self.input_synapse = SingleExponentialSynapse(n_in, dt=dt, td=1e-3)
        self.exc_synapse = SingleExponentialSynapse(n_neurons, dt=dt, td=1e-3)
        self.inh_synapse = SingleExponentialSynapse(n_neurons, dt=dt, td=2e-3)

        self.input_synaptictrace = SingleExponentialSynapse(n_in, dt=dt,
                                                            td=2e-2)
        self.exc_synaptictrace = SingleExponentialSynapse(n_neurons, dt=dt,
                                                          td=2e-2)

        # Connections
        initW = 1e-3*np.random.rand(n_neurons, n_in)
        self.input_conn = FullConnection(n_in, n_neurons,
                                         initW=initW)
        self.exc2inh_W = wexc*np.eye(n_neurons)
        self.inh2exc_W = (winh/(n_neurons-1))*(np.ones((n_neurons, n_neurons)) - np.eye(n_neurons))

        self.delay_input = DelayConnection(N=n_neurons, delay=5e-3, dt=dt)
        self.delay_exc2inh = DelayConnection(N=n_neurons, delay=2e-3, dt=dt)

        self.norm = norm
        self.g_inh = np.zeros(n_neurons)
        self.tcount = 0
        self.update_nt = update_nt
        self.n_neurons = n_neurons
        self.n_in = n_in
        self.s_in_ = np.zeros((self.update_nt, n_in))
        self.s_exc_ = np.zeros((n_neurons, self.update_nt))
        self.x_in_ = np.zeros((self.update_nt, n_in))
        self.x_exc_ = np.zeros((n_neurons, self.update_nt))

    # スパイクトレースのリセット
    def reset_trace(self):
        self.s_in_ = np.zeros((self.update_nt, self.n_in))
        self.s_exc_ = np.zeros((self.n_neurons, self.update_nt))
        self.x_in_ = np.zeros((self.update_nt, self.n_in))
        self.x_exc_ = np.zeros((self.n_neurons, self.update_nt))
        self.tcount = 0

    # 状態の初期化
    def initialize_states(self):
        self.exc_neurons.initialize_states()
        self.inh_neurons.initialize_states()
        self.delay_input.initialize_states()
        self.delay_exc2inh.initialize_states()
        self.input_synapse.initialize_states()
        self.exc_synapse.initialize_states()
        self.inh_synapse.initialize_states()

    def __call__(self, s_in, stdp=True):
        # 入力層
        c_in = self.input_synapse.func(s_in) # シナプス１6こ
        x_in = self.input_synaptictrace.func(s_in)  # シナプス２　６こ
        g_in = self.input_conn.func(c_in)  # 重み畳み込み

        # 興奮性ニューロン層
        s_exc = self.exc_neurons.func(self.delay_input.func(g_in), self.g_inh)
        c_exc = self.exc_synapse.func(s_exc) # シナプス3 10こ
        g_exc = np.dot(self.exc2inh_W, c_exc) 
        x_exc = self.exc_synaptictrace.func(s_exc) # シナプス4 10こ

        # 抑制性ニューロン層
        s_inh = self.inh_neurons.func(self.delay_exc2inh.func(g_exc), 0)
        c_inh = self.inh_synapse.func(s_inh) # シナプス5 10こ
        self.g_inh = np.dot(self.inh2exc_W, c_inh)

        if stdp:
            # スパイク列とスパイクトレースを記録
            self.s_in_[self.tcount] = s_in
            self.s_exc_[:, self.tcount] = s_exc
            self.x_in_[self.tcount] = x_in
            self.x_exc_[:, self.tcount] = x_exc
            self.tcount += 1

            # Online STDP
            if self.tcount == self.update_nt: # 10000回溜まったら　１エポック１回
                W = np.copy(self.input_conn.W)

                # postに投射される重みが均一になるようにする
                W_abs_sum = np.expand_dims(np.sum(np.abs(W), axis=1), 1)
                W_abs_sum[W_abs_sum == 0] = 1.0
                W *= self.norm / W_abs_sum

                # STDP則
                dW = self.lr_p*(self.wmax - W)*np.dot(self.s_exc_, self.x_in_)
                dW -= self.lr_m*W*np.dot(self.x_exc_, self.s_in_)
                clipped_dW = np.clip(dW / self.update_nt, -1e-3, 1e-3)
                self.input_conn.W = np.clip(W + clipped_dW,
                                            self.wmin, self.wmax)
                self.reset_trace() # スパイク列とスパイクトレースをリセット
                

        return s_exc


# 350ms画像入力、150ms入力なしでリセットさせる(膜電位の閾値以外)
dt = 2e-4 # タイムステップ(sec)
t_inj = 2.0 # 刺激入力時間(sec)
t_blank = 0.5 # ブランク時間(sec)
nt_inj = round(t_inj/dt)
nt_blank = round(t_blank/dt)

n_neurons = 20 #興奮性/抑制性ニューロンの数
unique_labels = [1, 2]
correct_lables = [1,2,0] #ラベル数
n_epoch = 10 #エポック数

n_train = 180 # 訓練データの数
n_val = 60
n_test = 100
update_nt = nt_inj # STDP則による重みの更新間隔



if __name__ == '__main__':
    duration = 2000  # ms
    lif_time = 0.2  # time step




    data_set = []
    t_set = []


    for l in [1, 2]:
        for i in range(31, 121):
            virtual_list = np.zeros([10000, 6])  # スパイクデータ格納用配列を初期化

            # チャンネル0
            input_data1 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes1 = lif(input_data1, -45, duration, lif_time)
            virtual_list[:, 0] = np.asarray(spikes1, dtype=int)

            # チャンネル1
            input_data2 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes2 = lif3(input_data2, -50, duration, lif_time)
            virtual_list[:, 1] = np.asarray(spikes2, dtype=int)

            # チャンネル2
            input_data3 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes3 = poisson1(input_data3, duration, lif_time)
            virtual_list[:, 2] = np.asarray(spikes3, dtype=int)

            # チャンネル3
            input_data4 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes4 = poisson2(input_data4, duration, lif_time)
            virtual_list[:, 3] = np.asarray(spikes4, dtype=int)

            # チャンネル4
            input_data5 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes5 = lif(input_data5, -50, duration, lif_time)
            virtual_list[:, 4] = np.asarray(spikes5, dtype=int)

            # チャンネル5
            input_data6 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes6 = poisson3(input_data6, duration, lif_time)
            virtual_list[:, 5] = np.asarray(spikes6, dtype=int)
            if l == 1:
                data_set.append((virtual_list, l))
            elif l == 2:
                data_set.append((virtual_list, l))

    
    for l in [1, 2]:
        for i in range(1, 31):
            virtual_list = np.zeros([10000, 6])  # スパイクデータ格納用配列を初期化

            # チャンネル0
            input_data7 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes7 = lif(input_data7, -45, duration, lif_time)
            virtual_list[:, 0] = np.asarray(spikes7, dtype=int)

            # チャンネル1
            input_data8 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes8 = lif3(input_data8, -50, duration, lif_time)
            virtual_list[:, 1] = np.asarray(spikes8, dtype=int)

            # チャンネル2
            input_data9 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes9 = poisson1(input_data9,  duration, lif_time)
            virtual_list[:, 2] = np.asarray(spikes9, dtype=int)

            # チャンネル3
            input_data10 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes10 = poisson2(input_data10, duration, lif_time)
            virtual_list[:, 3] = np.asarray(spikes10, dtype=int)

            # チャンネル4
            input_data11 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes11 = lif(input_data11, -50, duration, lif_time)
            virtual_list[:, 4] = np.asarray(spikes11, dtype=int)

            # チャンネル5
            input_data12 = 70 * np.loadtxt(r"C:\Users\yshou\yasapy\Vibrationdata_new\sample{0}\sample{1}_{2}.txt".format(l, l, i), usecols=[0], dtype='float')
            spikes12 = poisson3(input_data12, duration, lif_time)
            virtual_list[:, 5] = np.asarray(spikes12, dtype=int)
            if l == 1:
                t_set.append((virtual_list, l))
            elif l == 2:
                t_set.append((virtual_list, l))
    
    random.seed(0)
    spikes_list = random.sample(data_set,len(data_set))
    with open('spikes_list_me.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        for spik in spikes_list:
            writer.writerow([np.sum(spik[0], axis=0)] + [spik[1]])
    random.seed(0)
    spikes_list_val = random.sample(t_set,len(t_set))

    count1 = np.sum(spikes_list[1][0] == 0)
    count2 = np.sum(spikes_list[2][0] == 1)
    count3 = np.sum(spikes_list[3][0] == 1)

    print(count1)
    print(count2)
    print(count3)


    labels = np.array([spikes_list[i][1] for i in range(n_train)]) # ラベルの配列
    test_labels = np.array([spikes_list_val[i][1] for i in range(n_val)])
    results_save_dir = "./redclassiffication_results/" # 結果を保存するディレクトリ
    os.makedirs(results_save_dir, exist_ok=True) # ディレクトリ作成
    
    network = DiehlAndCook2015Network(n_in=6, n_neurons=n_neurons,
                                      wexc=para_list[0], winh=para_list[1],
                                      dt=dt, wmin=0.0, wmax=10.0,
                                      lr=(para_list[2], para_list[3]),
                                      update_nt=update_nt,norm = para_list[4])
    
    network.initialize_states() 
    accuracy_all = np.zeros(n_epoch)
    accuracy_test_all = np.zeros(n_epoch) # 訓練精度を記録する変数
    spikes= np.zeros((n_train, n_neurons)).astype(np.uint8)
    print(spikes)
    blank_input = np.zeros(6) # ブランク入力
    print(spikes_list_val[5][1]) #(10000, 28)  
    rows = len(spikes_list_val[5][0])
    print(rows)

    """28はニューロンの数,10000はスパイクの記録"""

    for epoch in range(n_epoch):
        for i in tqdm(range(n_train)):
            #max_fr = init_max_fr
            while(True):
                # 入力スパイクをオンラインで生成
                input_spikes = spikes_list[i][0] #input_spikeは10000*6
                spike_list = [] # サンプルごとにスパイクを記録するリスト
                # 画像刺激の入力
                for t in range(nt_inj):
                    s_exc = network(input_spikes[t], stdp=True)
                    spike_list.append(s_exc)

                spikes[i] = np.sum(np.array(spike_list), axis=0) 

                # ブランク刺激の入力
                for _ in range(nt_blank):
                    _ = network(blank_input, stdp=False)

                num_spikes_exc = np.sum(np.array(spike_list)) # スパイク数を計算
                if num_spikes_exc >= 0: # スパイク数が5より大きければ次のサンプルへ
                    break
                #else: # スパイク数が5より小さければ入力発火率を上げて再度刺激
                #    max_fr += 16
        


        # ニューロンを各ラベルに割り当てる
        if epoch == 0:
            assignments, proportions, rates = assign_labels(spikes, labels,
                                                            unique_labels)
            print(assignments)
            print(proportions)
            print(rates)

        else:
            assignments, proportions, rates = assign_labels(spikes, labels,
                                                            unique_labels, rates)
            print(rates)
        print("Assignments:\n", assignments)
        print(labels)

        # スパイク数の確認(正常に発火しているか確認)
        sum_nspikes = np.sum(spikes, axis=1)
        mean_nspikes = np.mean(sum_nspikes).astype(np.float16)
        print("Ave. spikes:", mean_nspikes)
        print("Min. spikes:", sum_nspikes.min())
        print("Max. spikes:", sum_nspikes.max())

        # 入力サンプルのラベルを予測する
        predicted_labels = prediction(spikes, assignments, unique_labels)
        print('prediction:\n',predicted_labels)

        # 訓練精度を計算
        accuracy = np.mean(np.where(labels==predicted_labels, 1, 0)).astype(np.float16)
        print("epoch :", epoch, " accuracy :", accuracy)
        print("wights",network.input_conn.W)
        if accuracy < 0.5:
            #score_list.append(5.0)
            break
        accuracy_all[epoch] = accuracy

        #return accuracy_all[-1]
        # 学習率の減衰
        network.lr_p *= 0.25
        network.lr_m *= 0.25


        # 重みの保存(エポック毎)
        np.save(results_save_dir+"weight_epoch"+str(epoch)+".npy", network.input_conn.W)
        np.save(results_save_dir+"assignment_epoch"+str(epoch)+".npy", assignments)
        np.save(results_save_dir+"exc_neurons_epoch"+str(epoch)+".npy", network.exc_neurons.theta)


        #test phase
        network_test = DiehlAndCook2015Network(n_in=6, n_neurons=n_neurons, dt=dt)
        network_test.initialize_states()

        #network_test.input_conn.W = np.load(results_save_dir+"weight_epoch"+str(epoch)+".npy")
        #network_test.exc_neurons.theta = np.load(results_save_dir+"exc_neurons_epoch"+str(epoch)+".npy")
        network_test.input_conn.W = network.input_conn.W
        network_test.exc_neurons.theta = network.exc_neurons.theta
        network_test.exc_neurons.theta_plus = 0 # 閾値が上昇しないようにする
        
        spikes_val = np.zeros((n_val, n_neurons)).astype(np.uint8)
        blank_input_test = np.zeros(6) # ブランク入力


        for i in tqdm(range(n_val)):

            while(True):
                # 入力スパイクをオンラインで生成
                input_spikes_val = spikes_list_val[i][0]
                spike_list = [] # サンプルごとにスパイクを記録するリスト
                # 画像刺激の入力
                for t in range(nt_inj):
                    s_exc_test = network_test(input_spikes_val[t], stdp=False)
                    spike_list.append(s_exc_test)

                spikes_val[i] = np.sum(np.array(spike_list), axis=0) # スパイク数を記録

                # ブランク刺激の入力
                for _ in range(nt_blank):
                    _ = network_test(blank_input_test, stdp=False)

                num_spikes_exc = np.sum(np.array(spike_list)) # スパイク数を計算

                if num_spikes_exc >= 0: # スパイク数が5より大きければ次のサンプルへ
                    break
                #else: # スパイク数が5より小さければ入力発火率を上げて再度刺激
                #    max_fr += 16

        # 入力サンプルのラベルを予測する
        #assignments = np.load(results_save_dir+"assignment_epoch"+str(epoch)+".npy")
        assignments = assignments
        predicted_labels_val = prediction(spikes_val, assignments, unique_labels)
        print('Val prediction:\n',predicted_labels_val)

        # 訓練精度を計算
        accuracy_test = np.mean(np.where(test_labels==predicted_labels_val, 1, 0)).astype(np.float16)
        print('Val accuracy : ',accuracy_test)
        accuracy_test_all[epoch] = accuracy_test
    print(accuracy_all)
    print(accuracy_test)






59860
93
72
0.0312611326044722
0.03592142258355323
0.0
10.0
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
1
10000


100%|██████████| 180/180 [01:10<00:00,  2.54it/s]


20
[[ 0  0  0 ...  0  0  0]
 [47 20 21 ... 30 29 16]
 [32 43 40 ... 39 51 42]
 ...
 [18 12 17 ... 18 16 13]
 [10 10 10 ...  9 10 10]
 [18 14 14 ... 16 15 14]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[32 43 40 ... 39 51 42]
 [23 27 21 ... 23 35 28]
 [25 24 28 ... 26 35 29]
 ...
 [14 14 14 ... 13 14 14]
 [10 10 10 ...  9 10 10]
 [10 10 10 ...  9 10 10]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:20<00:00,  2.89it/s]


60
[2 1 1 1 1 1 1 1 1 1 2 2 1 1 1 2 1 1 1 1]
16
[ 1  2  3  4  5  6  7  8  9 12 13 14 16 17 18 19]
[2 1 1 1 1 1 1 1 1 1 2 2 1 1 1 2 1 1 1 1]
4
[ 0 10 11 15]
[[12.1875 18.    ]
 [11.6875 12.25  ]
 [10.875  15.25  ]
 [10.1875 14.5   ]
 [ 5.6875  5.5   ]
 [13.375  13.5   ]
 [10.9375 10.75  ]
 [ 9.75   14.    ]
 [ 9.5625 14.75  ]
 [ 7.5625  8.25  ]
 [15.875  15.    ]
 [ 8.875  13.    ]
 [ 9.1875  9.    ]
 [15.3125 20.25  ]
 [13.1875 13.    ]
 [11.9375 17.25  ]
 [11.4375 11.    ]
 [11.9375 12.    ]
 [12.375  16.25  ]
 [ 5.625   5.25  ]
 [15.0625 19.5   ]
 [ 9.5625 12.75  ]
 [ 9.     12.    ]
 [12.125  11.75  ]
 [ 7.375  13.    ]
 [ 9.125  16.75  ]
 [ 8.875   7.5   ]
 [13.     13.    ]
 [ 8.875  12.25  ]
 [11.1875 11.    ]
 [ 6.5     6.    ]
 [ 9.375   9.    ]
 [12.     15.5   ]
 [ 7.375   7.    ]
 [12.3125 18.25  ]
 [10.     10.    ]
 [ 7.375  14.25  ]
 [ 5.75   12.    ]
 [ 9.9375  9.    ]
 [10.25   10.25  ]
 [ 9.3125 14.    ]
 [12.9375 19.    ]
 [ 9.5     9.    ]
 [15.6875 14.75  ]
 [10.375

100%|██████████| 180/180 [01:10<00:00,  2.56it/s]


20
[[18 13 15 ... 18 15 12]
 [17 10 14 ... 17 12 11]
 [13 13 13 ... 13 13 13]
 ...
 [16 16 16 ... 12 16 17]
 [12 12 12 ... 11 12 12]
 [16 15 16 ... 10 16 15]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[13 13 13 ... 13 13 13]
 [ 8  8  7 ...  6  8  8]
 [11 11 11 ... 11 11 11]
 ...
 [ 9 10  9 ... 10  9  9]
 [10 10 10 ...  7 10 10]
 [12 12 12 ... 11 12 12]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:21<00:00,  2.78it/s]


60
[2 1 2 1 1 1 1 1 1 1 2 2 1 1 1 2 1 2 2 1]
13
[ 1  3  4  5  6  7  8  9 12 13 14 16 19]
[2 1 2 1 1 1 1 1 1 1 2 2 1 1 1 2 1 2 2 1]
7
[ 0  2 10 11 15 17 18]
[[10.307693  14.142858 ]
 [ 9.615385   9.571428 ]
 [12.        13.714286 ]
 [11.076923  13.857142 ]
 [ 6.         5.857143 ]
 [11.769231  10.857142 ]
 [ 8.846154   7.142857 ]
 [12.230769  14.428572 ]
 [10.        12.857142 ]
 [ 5.8461537  5.714286 ]
 [11.         8.857142 ]
 [ 9.769231  12.571428 ]
 [ 8.307693   7.       ]
 [15.230769  17.285715 ]
 [10.384615   9.       ]
 [12.307693  15.       ]
 [ 9.846154   9.       ]
 [10.         8.857142 ]
 [13.076923  14.142858 ]
 [ 4.6923075  4.       ]
 [12.769231  14.571428 ]
 [10.461538  13.428572 ]
 [ 9.692307  12.285714 ]
 [ 9.384615   8.       ]
 [10.384615  13.       ]
 [10.384615  14.285714 ]
 [ 7.3846154  6.714286 ]
 [12.923077  12.       ]
 [10.615385  13.142858 ]
 [ 8.230769   6.714286 ]
 [ 7.8461537  7.       ]
 [ 9.538462   9.       ]
 [13.153846  15.285714 ]
 [ 8.         7.   

100%|██████████| 180/180 [01:10<00:00,  2.57it/s]


20
[[17 15 17 ...  8 17 15]
 [15 13 15 ...  9 15 13]
 [11 11 10 ... 11 10 10]
 ...
 [17  9 17 ...  9 10  9]
 [14 15 12 ... 15 15 15]
 [19  6 18 ...  6  5  6]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[11 11 10 ... 11 10 10]
 [ 3  3  3 ...  3  3  3]
 [10 10 10 ...  8 10 10]
 ...
 [19 25 15 ... 25 21 25]
 [13 16 12 ... 16 14 16]
 [14 15 12 ... 15 15 15]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:22<00:00,  2.67it/s]


60
[2 1 2 1 1 1 1 1 1 1 2 1 1 1 1 2 1 1 1 1]
16
[ 1  3  4  5  6  7  8  9 11 12 13 14 16 17 18 19]
[2 1 2 1 1 1 1 1 1 1 2 1 1 1 1 2 1 1 1 1]
4
[ 0  2 10 15]
[[ 2.75   13.75  ]
 [ 9.8125 10.75  ]
 [ 8.8125 16.25  ]
 [ 7.     14.    ]
 [12.4375 12.    ]
 [21.625  19.25  ]
 [20.8125 17.    ]
 [ 9.6875 16.    ]
 [ 1.75   11.5   ]
 [13.6875 10.5   ]
 [23.25   18.75  ]
 [ 6.9375 15.25  ]
 [20.625  15.75  ]
 [ 9.75   19.    ]
 [18.1875 16.    ]
 [ 7.9375 16.    ]
 [16.3125 13.25  ]
 [22.25   16.75  ]
 [11.3125 18.25  ]
 [17.125  14.25  ]
 [ 5.0625 14.5   ]
 [ 2.625  10.75  ]
 [ 3.5625 11.5   ]
 [22.4375 18.5   ]
 [ 2.1875  9.75  ]
 [ 7.0625 17.5   ]
 [14.1875 11.25  ]
 [23.     18.    ]
 [ 6.8125 13.75  ]
 [18.375  15.25  ]
 [17.25   15.25  ]
 [16.     13.5   ]
 [ 8.5    14.25  ]
 [14.1875 11.75  ]
 [ 8.8125 15.5   ]
 [15.8125 10.75  ]
 [ 6.     12.75  ]
 [ 2.6875 11.    ]
 [21.1875 18.25  ]
 [19.     15.5   ]
 [ 2.875  12.    ]
 [ 9.5    18.5   ]
 [17.1875 15.25  ]
 [25.875  18.75  ]
 [ 6.312

100%|██████████| 180/180 [01:10<00:00,  2.57it/s]


20
[[20  7 18 ...  7  5  7]
 [18  4 17 ...  4  5  4]
 [19 22 14 ... 22 20 22]
 ...
 [11  7 15 ...  7  7  7]
 [14 15  8 ... 15 15 15]
 [ 7  4 12 ...  4  4  4]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[19 22 14 ... 22 20 22]
 [15 20 12 ... 21 16 20]
 [18 24 14 ... 24 22 24]
 ...
 [24 34 16 ... 34 31 34]
 [17 21 11 ... 22 17 21]
 [14 15  8 ... 15 15 15]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:21<00:00,  2.73it/s]


60
[2 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
17
[ 1  3  4  5  6  7  8  9 11 12 13 14 15 16 17 18 19]
[2 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
3
[ 0  2 10]
[[ 4.1764708 13.       ]
 [13.117647  12.       ]
 [ 7.647059  15.       ]
 [ 9.117647  15.       ]
 [13.235294   8.       ]
 [28.235294  15.333333 ]
 [20.882353  12.333333 ]
 [10.882353  14.333333 ]
 [ 1.        10.       ]
 [13.294118   9.333333 ]
 [24.529411  15.333333 ]
 [ 6.2352943 13.       ]
 [26.058823  13.333333 ]
 [11.117647  19.333334 ]
 [21.117647  12.       ]
 [11.411765  17.666666 ]
 [21.529411  12.666667 ]
 [18.941177  10.333333 ]
 [12.176471  16.333334 ]
 [18.647058  14.666667 ]
 [ 6.882353  17.       ]
 [ 3.        11.666667 ]
 [ 5.882353  12.       ]
 [24.882353  16.       ]
 [ 3.764706  11.       ]
 [ 7.0588236 17.333334 ]
 [14.         7.3333335]
 [27.764706  13.       ]
 [ 9.058824  16.       ]
 [19.235294  10.       ]
 [23.941177  13.666667 ]
 [19.941177  13.333333 ]
 [ 8.058824  12.666667 ]
 [16.117647   9.   

100%|██████████| 180/180 [01:10<00:00,  2.57it/s]


20
[[13  7 14 ...  7  9  7]
 [10  4 14 ...  4  4  4]
 [21 27  9 ... 28 28 27]
 ...
 [ 8  6  8 ...  6  6  6]
 [17 19 11 ... 19 17 19]
 [ 6  6 11 ...  6  7  6]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[21 27  9 ... 28 28 27]
 [15 19  8 ... 19 18 19]
 [16 17  4 ... 17 16 17]
 ...
 [24 30 13 ... 30 28 30]
 [16 21  9 ... 21 19 21]
 [17 19 11 ... 19 17 19]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:21<00:00,  2.75it/s]


60
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
18
[ 0  1  3  4  5  6  7  8  9 11 12 13 14 15 16 17 18 19]
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
2
[ 2 10]
[[ 4.2777777 16.       ]
 [13.333333  10.       ]
 [ 7.9444447 14.5      ]
 [ 9.        19.       ]
 [14.5        5.5      ]
 [26.61111    7.       ]
 [21.777779   9.       ]
 [10.5       13.       ]
 [ 1.0555556 12.5      ]
 [13.055555   8.       ]
 [27.11111   11.       ]
 [ 5.8333335 16.5      ]
 [23.5       11.       ]
 [11.222222  20.       ]
 [19.5        5.       ]
 [ 7.2777777 17.5      ]
 [20.277779   8.       ]
 [19.722221   7.       ]
 [12.111111  18.       ]
 [18.333334   9.       ]
 [ 7.2777777 17.5      ]
 [ 2.3888888 14.       ]
 [ 4.1666665 14.       ]
 [30.444445  12.       ]
 [ 3.8888888 14.5      ]
 [ 8.333333  20.       ]
 [14.444445   3.5      ]
 [29.333334   9.       ]
 [ 8.444445  17.5      ]
 [19.055555   7.       ]
 [24.555555  11.       ]
 [20.         8.5      ]
 [ 6.1666665 14.5      ]
 [16.         4.   

100%|██████████| 180/180 [01:34<00:00,  1.90it/s]


20
[[13  9 15 ...  8 10  9]
 [ 8  4 11 ...  4  4  4]
 [23 28 10 ... 26 25 28]
 ...
 [ 9  6  9 ...  6  6  6]
 [18 20 11 ... 21 19 21]
 [ 9 10 17 ... 10 10 10]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[23 28 10 ... 26 25 28]
 [16 20  7 ... 20 17 20]
 [18 22  6 ... 22 19 22]
 ...
 [22 28 11 ... 28 26 28]
 [14 16  8 ... 16 15 16]
 [18 20 11 ... 21 19 21]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:21<00:00,  2.74it/s]


60
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
18
[ 0  1  3  4  5  6  7  8  9 11 12 13 14 15 16 17 18 19]
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
2
[ 2 10]
[[ 2.6666667 15.       ]
 [12.666667  10.       ]
 [ 8.        14.       ]
 [ 8.944445  19.       ]
 [16.11111    7.5      ]
 [26.88889    7.5      ]
 [20.88889    9.       ]
 [10.388889  13.5      ]
 [ 1.0555556 13.       ]
 [13.166667   7.5      ]
 [27.555555  12.       ]
 [ 5.0555553 16.       ]
 [23.666666   9.5      ]
 [11.277778  19.5      ]
 [21.         6.       ]
 [ 7.6666665 18.5      ]
 [20.5        7.5      ]
 [19.777779   8.       ]
 [12.111111  17.       ]
 [17.166666   7.       ]
 [ 7.9444447 18.       ]
 [ 3.1111112 14.5      ]
 [ 4.1666665 14.       ]
 [26.166666  10.5      ]
 [ 3.8888888 14.       ]
 [ 8.444445  20.       ]
 [13.166667   2.5      ]
 [29.222221   8.       ]
 [10.777778  18.5      ]
 [19.055555   7.       ]
 [23.222221  11.       ]
 [19.166666   7.5      ]
 [ 6.611111  14.5      ]
 [15.         4.   

100%|██████████| 180/180 [01:08<00:00,  2.62it/s]


20
[[12 13 16 ... 13 11 13]
 [ 8  4 12 ...  4  4  4]
 [23 28 11 ... 28 27 28]
 ...
 [ 8  5  8 ...  6  5  5]
 [17 20 10 ... 20 18 20]
 [ 7  8 12 ...  8  8  8]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[23 28 11 ... 28 27 28]
 [19 26 10 ... 26 22 26]
 [18 20  6 ... 20 19 20]
 ...
 [24 27 13 ... 28 26 27]
 [14 17  8 ... 17 15 17]
 [17 20 10 ... 20 18 20]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:20<00:00,  2.86it/s]


60
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
18
[ 0  1  3  4  5  6  7  8  9 11 12 13 14 15 16 17 18 19]
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
2
[ 2 10]
[[ 2.6666667 15.       ]
 [12.666667  10.       ]
 [ 8.        14.       ]
 [ 8.944445  19.       ]
 [16.11111    7.5      ]
 [26.88889    7.5      ]
 [20.944445   9.       ]
 [11.277778  13.5      ]
 [ 1.0555556 13.       ]
 [13.166667   8.       ]
 [27.944445  11.       ]
 [ 5.111111  16.       ]
 [23.722221   9.5      ]
 [11.333333  19.5      ]
 [21.         6.       ]
 [ 7.5555553 17.5      ]
 [21.88889    9.       ]
 [18.88889    8.       ]
 [12.055555  17.       ]
 [15.333333   7.       ]
 [ 7.888889  18.       ]
 [ 3.1111112 14.5      ]
 [ 4.1666665 14.       ]
 [25.833334  12.       ]
 [ 3.8888888 14.       ]
 [ 8.5       20.       ]
 [13.166667   2.5      ]
 [29.277779   8.       ]
 [ 9.944445  17.5      ]
 [19.777779   7.       ]
 [22.166666  11.       ]
 [20.         8.5      ]
 [ 7.2222223 14.5      ]
 [15.944445   4.   

100%|██████████| 180/180 [01:08<00:00,  2.62it/s]


20
[[14 13 16 ... 13 12 13]
 [ 8  4 12 ...  4  4  4]
 [25 32 14 ... 33 29 32]
 ...
 [ 8  5  8 ...  6  5  5]
 [15 18  8 ... 18 17 18]
 [ 7  7 12 ...  7  8  7]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[25 32 14 ... 33 29 32]
 [18 24  9 ... 24 21 24]
 [18 20  7 ... 20 19 20]
 ...
 [25 29 14 ... 29 28 29]
 [14 17  8 ... 17 15 17]
 [15 18  8 ... 18 17 18]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:21<00:00,  2.75it/s]


60
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
18
[ 0  1  3  4  5  6  7  8  9 11 12 13 14 15 16 17 18 19]
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
2
[ 2 10]
[[ 2.6666667 15.       ]
 [13.333333  10.       ]
 [ 8.        14.       ]
 [ 8.944445  19.       ]
 [16.11111    7.5      ]
 [26.88889    7.5      ]
 [20.944445   9.       ]
 [11.944445  14.5      ]
 [ 1.0555556 13.       ]
 [13.166667   8.       ]
 [27.944445  11.       ]
 [ 5.1666665 16.       ]
 [23.722221   9.5      ]
 [11.333333  18.5      ]
 [21.722221   6.       ]
 [ 8.5       19.       ]
 [21.         7.       ]
 [19.61111    8.       ]
 [12.722222  17.5      ]
 [15.333333   7.       ]
 [ 7.888889  18.       ]
 [ 3.1111112 14.5      ]
 [ 4.1666665 14.       ]
 [25.833334  12.       ]
 [ 3.8888888 14.       ]
 [ 8.5       20.       ]
 [13.166667   2.5      ]
 [31.         9.       ]
 [ 9.944445  17.5      ]
 [19.777779   7.       ]
 [19.555555  10.       ]
 [20.055555   8.5      ]
 [ 7.2222223 14.5      ]
 [15.944445   4.   

100%|██████████| 180/180 [01:14<00:00,  2.43it/s]


20
[[14 13 16 ... 13 12 13]
 [ 8  4 12 ...  4  4  4]
 [25 32 14 ... 32 29 33]
 ...
 [ 8  5  8 ...  6  5  5]
 [16 20  9 ... 20 18 20]
 [ 8  9 13 ...  9  9  9]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[25 32 14 ... 32 29 33]
 [18 23  9 ... 23 21 23]
 [19 22  8 ... 22 20 22]
 ...
 [22 28 12 ... 29 27 28]
 [14 16  8 ... 16 15 16]
 [16 20  9 ... 20 18 20]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:22<00:00,  2.64it/s]


60
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
18
[ 0  1  3  4  5  6  7  8  9 11 12 13 14 15 16 17 18 19]
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
2
[ 2 10]
[[ 2.6666667 15.       ]
 [13.333333  10.       ]
 [ 8.        14.       ]
 [ 8.944445  19.       ]
 [16.11111    7.5      ]
 [26.88889    7.5      ]
 [20.944445   9.       ]
 [11.944445  14.5      ]
 [ 1.0555556 13.       ]
 [13.166667   8.       ]
 [27.944445  11.       ]
 [ 5.1666665 16.       ]
 [23.722221   9.5      ]
 [11.333333  17.5      ]
 [21.722221   6.       ]
 [ 8.5       19.       ]
 [21.         7.       ]
 [19.61111    8.       ]
 [12.722222  17.5      ]
 [16.222221   7.       ]
 [ 7.888889  18.       ]
 [ 3.1111112 14.5      ]
 [ 4.1666665 14.       ]
 [25.833334  12.       ]
 [ 3.8888888 14.       ]
 [ 8.5       20.       ]
 [13.166667   2.5      ]
 [28.38889    8.       ]
 [ 9.944445  17.5      ]
 [19.777779   7.       ]
 [19.555555  10.       ]
 [20.055555   8.5      ]
 [ 7.2222223 14.5      ]
 [15.888889   3.5  

100%|██████████| 180/180 [01:36<00:00,  1.86it/s]


20
[[13 13 16 ... 13 11 13]
 [ 8  4 12 ...  4  4  4]
 [25 31 14 ... 31 28 32]
 ...
 [ 8  5  8 ...  6  5  5]
 [16 20  9 ... 20 18 20]
 [ 8  9 13 ...  9  9  9]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[25 31 14 ... 31 28 32]
 [19 26 10 ... 26 22 26]
 [19 22  8 ... 22 19 22]
 ...
 [23 29 12 ... 29 28 29]
 [14 16  8 ... 16 15 16]
 [16 20  9 ... 20 18 20]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:21<00:00,  2.83it/s]

60
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
18
[ 0  1  3  4  5  6  7  8  9 11 12 13 14 15 16 17 18 19]
[1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1]
2
[ 2 10]
[[ 2.6666667 15.       ]
 [13.333333  10.       ]
 [ 8.        14.       ]
 [ 8.944445  19.       ]
 [16.11111    7.5      ]
 [26.88889    7.5      ]
 [20.944445   9.       ]
 [11.944445  14.5      ]
 [ 1.0555556 13.       ]
 [13.111111   8.       ]
 [27.944445  11.       ]
 [ 5.1666665 16.       ]
 [23.722221   9.5      ]
 [11.333333  17.5      ]
 [21.722221   6.       ]
 [ 8.5       19.       ]
 [21.         7.       ]
 [19.61111    8.       ]
 [12.722222  17.5      ]
 [16.222221   7.       ]
 [ 7.888889  18.       ]
 [ 3.1111112 14.5      ]
 [ 4.1666665 14.       ]
 [25.833334  12.       ]
 [ 3.8888888 14.       ]
 [ 8.5       20.       ]
 [13.166667   2.5      ]
 [28.38889    8.       ]
 [ 9.944445  17.5      ]
 [19.777779   7.       ]
 [19.666666  10.       ]
 [20.055555   8.5      ]
 [ 7.2222223 14.5      ]
 [15.888889   3.5  




In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams["font.family"] = 'Times New Roman'
#plt.rcParams["font.family"] = 'Hiragino Mincho ProN'
plt.rcParams["font.size"] = 25

plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# 最終プロット
plt.plot(range(n_epoch), accuracy_all, marker='o', label='Training Accuracy')
plt.plot(range(n_epoch), accuracy_test_all, marker='s', label='Validation Accuracy')
plt.xticks(range(n_epoch))
plt.ylim(0, 1)
plt.legend(fontsize=12)
plt.grid(True)
plt.show()