In [None]:
# -*- coding: utf-8 -*-

import numpy as np
from tqdm import tqdm
import csv
import os
import shutil
import pickle
import random
from sklearn.metrics import mean_squared_error

from numba.experimental import jitclass
from numba import f8, i8, jit


def lif(currents, th, time: int, lif_time: float = 1.0, rest=-65, ref=3, tc_decay=100):
    """Encode a current trace into a spike train with a simple LIF neuron.

    The membrane potential is integrated with forward-Euler steps of size
    ``lif_time``.  When it reaches ``th`` a spike is emitted, the spike time
    is stored and the potential is returned to ``rest``.  Integration is
    suspended until more than ``ref`` time units have passed since the last
    spike; because the last-spike time starts at 0, the first ``ref`` time
    units of every call are also treated as refractory.

    Args:
        currents: Indexable input current, one value per step.  Must hold at
            least ``int(time / lif_time)`` values, otherwise indexing raises
            ``IndexError``.
        th: Firing threshold of the membrane potential.
        time: Total duration to simulate (same unit as ``lif_time``).
        lif_time: Integration step.
        rest: Resting potential, also used as the reset value.
        ref: Refractory period (same unit as ``lif_time``).
        tc_decay: Membrane time constant.

    Returns:
        ``np.ndarray`` of length ``int(time / lif_time)`` holding 1.0 at the
        steps where the neuron fired and 0.0 elsewhere.
    """
    n_steps = int(time / lif_time)

    tlast = 0   # time of the most recent spike
    vpeak = 20  # value the potential is set to at the moment of a spike
    spikes = np.zeros(n_steps)
    v = rest

    for t in range(n_steps):
        now = lif_time * t

        # Euler update; the boolean factor zeroes the increment while the
        # neuron is still refractory.
        integrating = now > (tlast + ref)
        dv = integrating * (-v + rest + currents[t]) / tc_decay
        v = v + lif_time * dv

        # Threshold crossing: remember the spike time and jump to the peak.
        # The masked arithmetic is kept as-is so results stay bit-identical.
        crossed = v >= th
        tlast = tlast + (now - tlast) * crossed
        v = v + (vpeak - v) * crossed

        # Re-test after the jump: a spike only counts if the peak itself is
        # at or above the threshold.
        fired = v >= th
        spikes[t] = fired * 1
        v = v + (rest - v) * fired  # reset to the resting potential

    return spikes


def lif3(currents, th, time: int, lif_time: float = 1.0, rest=-65, ref=3, tc_decay=100):
    """Encode a current trace with a LIF neuron driven by the inverse amplitude.

    The input is first mapped to ``1000 / (|currents| + 1e-9)``, so small
    amplitudes produce a strong drive and large amplitudes a weak one.  The
    transformed trace is then integrated exactly like a plain LIF neuron:
    forward-Euler steps of size ``lif_time``, a spike and reset to ``rest``
    when the potential reaches ``th``, and no integration until more than
    ``ref`` time units have passed since the last spike (the last-spike time
    starts at 0, so the first ``ref`` time units are refractory too).

    Args:
        currents: Array-like input, one value per step.  Must hold at least
            ``int(time / lif_time)`` values, otherwise indexing raises
            ``IndexError``.
        th: Firing threshold of the membrane potential.
        time: Total duration to simulate (same unit as ``lif_time``).
        lif_time: Integration step.
        rest: Resting potential, also used as the reset value.
        ref: Refractory period (same unit as ``lif_time``).
        tc_decay: Membrane time constant.

    Returns:
        ``np.ndarray`` of length ``int(time / lif_time)`` holding 1.0 at the
        steps where the neuron fired and 0.0 elsewhere.
    """
    n_steps = int(time / lif_time)

    tlast = 0   # time of the most recent spike
    vpeak = 20  # value the potential is set to at the moment of a spike
    spikes = np.zeros(n_steps)
    v = rest

    max_current = 1000  # scale of the inverted drive
    # The small offset keeps the division finite for zero-valued samples.
    drive = max_current / (np.abs(currents) + 0.000000001)

    for t in range(n_steps):
        now = lif_time * t

        # Euler update; the boolean factor zeroes the increment while the
        # neuron is still refractory.
        integrating = now > (tlast + ref)
        dv = integrating * (-v + rest + drive[t]) / tc_decay
        v = v + lif_time * dv

        # Threshold crossing: remember the spike time and jump to the peak.
        # The masked arithmetic is kept as-is so results stay bit-identical.
        crossed = v >= th
        tlast = tlast + (now - tlast) * crossed
        v = v + (vpeak - v) * crossed

        # Re-test after the jump: a spike only counts if the peak itself is
        # at or above the threshold.
        fired = v >= th
        spikes[t] = fired * 1
        v = v + (rest - v) * fired  # reset to the resting potential

    return spikes


def poisson1(currents, time: int, dt: float = 1.0):
    """Poisson-encode a signal, with the rate taken directly from its amplitude.

    For every step a Poisson count with mean ``|currents[i]| * dt / 1000`` is
    drawn from the global NumPy RNG; a spike is written when the count is
    greater than 1.

    Args:
        currents: Array-like signal, one value per step.  Samples beyond
            ``int(time / dt)`` are ignored; if it is shorter, the remaining
            steps stay silent.
        time: Total duration to encode (same unit as ``dt``).
        dt: Step size.  Presumably milliseconds with rates in Hz, given the
            division by 1000 -- TODO confirm.

    Returns:
        ``np.ndarray`` of length ``int(time / dt)`` with 1.0 at spike steps.
    """
    n_steps = int(time / dt)
    spikes = np.zeros(n_steps)
    spike_rates = np.abs(currents)  # no normalisation: rate == amplitude

    # zip() bounds the loop by the output length, so an over-long input can
    # no longer index past the end of ``spikes``.
    for i, rate in zip(range(n_steps), spike_rates):
        # NOTE(review): "> 1" requires at least two events in the bin; a
        # conventional Poisson encoder would test ">= 1".  Left unchanged
        # because it may be deliberate tuning.
        if np.random.poisson(rate * dt / 1000) > 1:
            spikes[i] = 1

    return spikes


def poisson2(currents, time: int, dt: float = 1.0):
    """Poisson-encode a signal with a rate inversely related to its amplitude.

    The per-step rate is ``100 / (|currents[i]| + 1e-8)``.  For every step a
    Poisson count with mean ``rate * dt / 1000`` is drawn from the global
    NumPy RNG; a spike is written when the count is greater than 0.5, i.e.
    when at least one event occurred.

    Args:
        currents: Array-like signal, one value per step.  Samples beyond
            ``int(time / dt)`` are ignored; if it is shorter, the remaining
            steps stay silent.
        time: Total duration to encode (same unit as ``dt``).
        dt: Step size.  Presumably milliseconds with rates in Hz, given the
            division by 1000 -- TODO confirm.

    Returns:
        ``np.ndarray`` of length ``int(time / dt)`` with 1.0 at spike steps.
    """
    n_steps = int(time / dt)
    spikes = np.zeros(n_steps)

    max_spike_rate = 100  # numerator of the inverse mapping
    # The small offset keeps the division finite for zero-valued samples.
    spike_rates = max_spike_rate / (np.abs(currents) + 0.00000001)

    # zip() bounds the loop by the output length, so an over-long input can
    # no longer index past the end of ``spikes``.
    for i, rate in zip(range(n_steps), spike_rates):
        if np.random.poisson(rate * dt / 1000) > 0.5:
            spikes[i] = 1

    return spikes


def poisson3(currents, time: int, dt: float = 1.0):
    """Poisson-encode a signal with an inverse-amplitude rate (second variant).

    The per-step rate is ``100000 / (|currents[i]| + 1e-6)``.  For every step
    a Poisson count with mean ``rate * dt / 100000`` is drawn from the global
    NumPy RNG; a spike is written when the count is greater than 1.

    Args:
        currents: Array-like signal, one value per step.  Samples beyond
            ``int(time / dt)`` are ignored; if it is shorter, the remaining
            steps stay silent.
        time: Total duration to encode (same unit as ``dt``).
        dt: Step size.

    Returns:
        ``np.ndarray`` of length ``int(time / dt)`` with 1.0 at spike steps.
    """
    n_steps = int(time / dt)
    spikes = np.zeros(n_steps)

    max_spike_rate = 100000  # numerator of the inverse mapping
    # The small offset keeps the division finite for zero-valued samples.
    spike_rates = max_spike_rate / (np.abs(currents) + 0.000001)

    # zip() bounds the loop by the output length, so an over-long input can
    # no longer index past the end of ``spikes``.
    for i, rate in zip(range(n_steps), spike_rates):
        # NOTE(review): "> 1" requires at least two events in the bin; a
        # conventional Poisson encoder would test ">= 1".  Left unchanged
        # because it may be deliberate tuning.
        if np.random.poisson(rate * dt / 100000) > 1:
            spikes[i] = 1

    return spikes


@jitclass([('N',i8),('dt',f8),('td',f8),('r',f8[:])])
class SingleExponentialSynapse:
    # Single-exponential synaptic filter: each call decays the stored
    # response and adds the incoming spikes.
    def __init__(self, N, dt=1e-4, td=5e-3):
        """
        Args:
            N (int)   : Number of synapses (length of the state vector).
            dt (float): Simulation time step.
            td (float): Synaptic decay time
        """
        self.N = N
        self.dt = dt
        self.td = td
        self.r = np.zeros(N)

    def initialize_states(self):
        """Clear the filtered response."""
        self.r = np.zeros(self.N)

    def func(self, spike):
        """Advance the filter by one step and return the updated response.

        The returned array is the stored state itself, not a copy.
        """
        decay = 1-self.dt/self.td
        self.r = self.r*decay + spike/self.td
        return self.r

# NOTE: ``tlast`` is declared as an array because ``func`` stores one
# last-spike time per neuron; a scalar ``f8`` field cannot hold that value.
@jitclass([('N',i8),('dt',f8),('tref',f8),('tc_m',f8),
           ('vrest',i8),('vreset',i8),('vthr',i8),('vpeak',i8),
           ('e_exc',i8),('e_inh',i8),
           ('v',f8[:]),('v_',f8[:]),('tlast',f8[:]),('tcount',i8)])
class ConductanceBasedLIF:
    def __init__(self, N, dt=1e-4, tref=5e-3, tc_m=1e-2,
                 vrest=-60, vreset=-60, vthr=-50, vpeak=20,
                 e_exc=0, e_inh=-100):
        """
        Conductance-based Leaky integrate-and-fire model.

        Args:
            N (int)       : Number of neurons.
            dt (float)    : Simulation time step in seconds.
            tref (float)  : Refractory time constant in seconds.
            tc_m (float)  : Membrane time constant in seconds.
            vrest (int)   : Resting membrane potential (mV).
            vreset (int)  : Reset membrane potential (mV).
            vthr (int)    : Threshold membrane potential (mV).
            vpeak (int)   : Peak membrane potential (mV).
            e_exc (int)   : equilibrium potential of excitatory synapses (mV).
            e_inh (int)   : equilibrium potential of inhibitory synapses (mV).

        The potentials are stored in integer (``i8``) fields of the class
        spec, so integer values should be passed for them.
        """
        self.N = N
        self.dt = dt
        self.tref = tref
        self.tc_m = tc_m
        self.vrest = vrest
        self.vreset = vreset
        self.vthr = vthr
        self.vpeak = vpeak

        self.e_exc = e_exc  # equilibrium potential of excitatory synapses
        self.e_inh = e_inh  # equilibrium potential of inhibitory synapses

        self.v = self.vreset*np.ones(N)
        self.v_ = np.zeros(N)      # potential including the spike peak (for recording)
        self.tlast = np.zeros(N)   # per-neuron time of the last spike
        self.tcount = 0            # number of steps simulated so far

    def initialize_states(self, random_state=False):
        """Reset potentials, last-spike times and the step counter.

        With ``random_state`` the potentials are drawn uniformly between
        ``vreset`` and ``vthr``; otherwise they are all set to ``vreset``.
        """
        if random_state:
            self.v = self.vreset + np.random.rand(self.N)*(self.vthr-self.vreset)
        else:
            self.v = self.vreset*np.ones(self.N)
        self.tlast = np.zeros(self.N)
        self.tcount = 0

    def func(self, g_exc, g_inh):
        """Advance all neurons by one step.

        Args:
            g_exc: Excitatory conductance (scalar or per-neuron array).
            g_inh: Inhibitory conductance (scalar or per-neuron array).

        Returns:
            Integer array with 1 for neurons that fired in this step.
        """
        I_synExc = g_exc*(self.e_exc - self.v)
        I_synInh = g_inh*(self.e_inh - self.v)
        dv = (self.vrest - self.v + I_synExc + I_synInh) / self.tc_m
        # The boolean mask freezes neurons that are still refractory.
        v = self.v + ((self.dt*self.tcount) > (self.tlast + self.tref))*dv*self.dt

        s = 1*(v>=self.vthr)  # 1 where the neuron fires, 0 otherwise
        self.tlast = self.tlast*(1-s) + self.dt*self.tcount*s  # update last-spike times
        v = v*(1-s) + self.vpeak*s  # firing neurons jump to the peak
        self.v_ = v                 # keep the trace that includes the peak
        self.v = v*(1-s) + self.vreset*s  # firing neurons are reset
        self.tcount += 1

        return s


# NOTE: ``vthr`` and ``tlast`` are declared as arrays because ``func`` stores
# one threshold and one last-spike time per neuron; scalar fields cannot hold
# those values.  Duplicated spec entries were removed.
@jitclass([('N',i8),('dt',f8),('tref',f8),('tc_m',f8),
           ('vrest',i8),('vreset',i8),('init_vthr',i8),('vpeak',i8),
           ('theta_plus',f8),('theta_max',i8),('tc_theta',f8),
           ('e_exc',i8),('e_inh',i8),
           ('v',f8[:]),('theta',f8[:]),('vthr',f8[:]),('v_',f8[:]),
           ('tlast',f8[:]),('tcount',i8)])
class DiehlAndCook2015LIF:
    def __init__(self, N, dt=1e-3, tref=5e-3, tc_m=1e-1,
                 vrest=-65, vreset=-65, init_vthr=-52, vpeak=20,
                 theta_plus=0.025, theta_max=35, tc_theta=1e4, e_exc=0, e_inh=-100):
        """
        Leaky integrate-and-fire model of Diehl and Cooks (2015)
        https://www.frontiersin.org/articles/10.3389/fncom.2015.00099/full

        The firing threshold of each neuron is ``init_vthr + theta``, where
        ``theta`` grows by ``theta_plus`` on every spike of that neuron and
        decays with time constant ``tc_theta``.

        Args:
            N (int)           : Number of neurons.
            dt (float)        : Simulation time step in seconds.
            tref (float)      : Refractory time constant in seconds.
            tc_m (float)      : Membrane time constant in seconds.
            vrest (int)       : Resting membrane potential (mV).
            vreset (int)      : Reset membrane potential (mV).
            init_vthr (int)   : Threshold membrane potential before adaptation (mV).
            vpeak (int)       : Peak membrane potential (mV).
            theta_plus (float): Threshold increment added per spike.
            theta_max (int)   : Upper bound of the adaptive term ``theta``.
            tc_theta (float)  : Decay time constant of ``theta``.
            e_exc (int)       : equilibrium potential of excitatory synapses (mV).
            e_inh (int)       : equilibrium potential of inhibitory synapses (mV).

        The potentials and ``theta_max`` are stored in integer (``i8``)
        fields of the class spec, so integer values should be passed for them.
        """
        self.N = N
        self.dt = dt
        self.tref = tref
        self.tc_m = tc_m
        self.vreset = vreset
        self.vrest = vrest
        self.init_vthr = init_vthr
        self.theta = np.zeros(N)
        self.theta_plus = theta_plus
        self.theta_max = theta_max
        self.tc_theta = tc_theta
        self.vpeak = vpeak

        self.e_exc = e_exc  # equilibrium potential of excitatory synapses
        self.e_inh = e_inh  # equilibrium potential of inhibitory synapses

        self.v = self.vreset*np.ones(N)
        self.vthr = self.init_vthr*np.ones(N)  # per-neuron adaptive threshold
        self.v_ = np.zeros(N)      # potential including the spike peak (for recording)
        self.tlast = np.zeros(N)   # per-neuron time of the last spike
        self.tcount = 0            # number of steps simulated so far

    def initialize_states(self, random_state=False):
        """Reset potentials, thresholds, ``theta``, spike times and the step counter.

        With ``random_state`` the potentials are drawn uniformly between
        ``vreset`` and the current thresholds; otherwise they are all set to
        ``vreset``.
        """
        if random_state:
            self.v = self.vreset + np.random.rand(self.N)*(self.vthr-self.vreset)
        else:
            self.v = self.vreset*np.ones(self.N)
        self.vthr = self.init_vthr*np.ones(self.N)
        self.theta = np.zeros(self.N)
        self.tlast = np.zeros(self.N)
        self.tcount = 0

    def func(self, g_exc, g_inh):
        """Advance all neurons by one step.

        Args:
            g_exc: Excitatory conductance (scalar or per-neuron array).
            g_inh: Inhibitory conductance (scalar or per-neuron array).

        Returns:
            Integer array with 1 for neurons that fired in this step.
        """
        I_synExc = g_exc*(self.e_exc - self.v)
        I_synInh = g_inh*(self.e_inh - self.v)
        dv = (self.vrest - self.v + I_synExc + I_synInh) / self.tc_m
        # The boolean mask freezes neurons that are still refractory.
        v = self.v + ((self.dt*self.tcount) > (self.tlast + self.tref))*dv*self.dt

        s = 1*(v>=self.vthr)  # 1 where the neuron fires, 0 otherwise
        # Threshold adaptation: decay, add the per-spike increment, then bound.
        theta = (1-self.dt/self.tc_theta)*self.theta + self.theta_plus*s
        self.theta = np.clip(theta, 0, self.theta_max)
        self.vthr = self.theta + self.init_vthr  # takes effect from the next step
        self.tlast = self.tlast*(1-s) + self.dt*self.tcount*s  # update last-spike times
        v = v*(1-s) + self.vpeak*s  # firing neurons jump to the peak
        self.v_ = v                 # keep the trace that includes the peak
        self.v = v*(1-s) + self.vreset*s  # firing neurons are reset
        self.tcount += 1

        return s

#@jitclass([('W',f8[:,:])])
class FullConnection:
    """Dense connection holding a weight matrix ``W``."""

    def __init__(self, N_in, N_out, initW=None):
        """
        Store ``initW`` as the weight matrix, or draw a random one.

        When ``initW`` is omitted the weights are sampled uniformly from
        [0, 0.1) with shape (N_out, N_in) using the global NumPy RNG.
        """
        self.W = 0.1*np.random.rand(N_out, N_in) if initW is None else initW

    def backward(self, x):
        """Multiply ``x`` by the transposed weight matrix."""
        transposed = self.W.T
        return np.dot(transposed, x)

    def func(self, x):
        """Multiply ``x`` by the weight matrix."""
        weights = self.W
        return np.dot(weights, x)


@jitclass([('N',i8),('nt_delay',i8),('state',f8[:,:])])
class DelayConnection:
    # Fixed-delay line: column j of ``state`` holds the input received
    # j + 1 calls ago.
    def __init__(self, N, delay, dt=1e-4):
        """
        Args:
            N (int)      : Number of delayed channels.
            delay (float): Delay time
            dt (float)   : Simulation time step; the delay spans
                           ``round(delay/dt)`` steps.
        """
        self.N = N
        self.nt_delay = round(delay/dt)  # number of delay steps
        self.state = np.zeros((N, self.nt_delay))

    def initialize_states(self):
        """Clear the delay buffer."""
        self.state = np.zeros((self.N, self.nt_delay))

    def func(self, x):
        """Push ``x`` into the delay line and return the value from ``nt_delay`` steps ago."""
        # Copy the oldest column BEFORE shifting.  A plain slice is a view
        # into ``state``; the shift below would overwrite it, so the caller
        # would receive a value that is one step too recent (and the raw
        # input itself when nt_delay == 1).
        out = self.state[:, -1].copy()

        self.state[:, 1:] = self.state[:, :-1]  # shift the buffer by one step
        self.state[:, 0] = x                    # newest input

        return out

# Fix the global NumPy RNG so random draws made through np.random are
# reproducible from run to run.
np.random.seed(0)

# Network hyperparameters, keyed by name.
para_dic = {
    'w_exc': 4.544910490280396,
    'w_inh': 1.3082028299035284,
    'lr1': 0.0312611326044722,
    'lr2': 0.03592142258355323,
    'Norm': 0.22856140170464153,
}
# Same values as a positional list: [w_exc, w_inh, lr1, lr2, Norm].
para_list = [para_dic[key] for key in ('w_exc', 'w_inh', 'lr1', 'lr2', 'Norm')]

k = 7


# Label assignment
def assign_labels(spikes, labels, unique_labels, rates=None, alpha=1.0):
    """
    Assign labels to the neurons based on highest average spiking activity.

    Args:
        spikes (n_samples, n_neurons) : A single layer's spiking activity.
        labels (n_samples,) : Data labels corresponding to input samples;
                              every entry is one of ``unique_labels``.
        unique_labels (sequence) : The distinct target labels.  Column ``i`` of
                              ``rates`` / ``proportions`` belongs to
                              ``unique_labels[i]``.
        rates (n_neurons, n_labels) : If passed, these represent spike rates
                                      from a previous ``assign_labels()`` call.
                                      The array is updated in place.
        alpha (float): Weight applied to the previous rates before the new
                       per-label averages are added.
    return: Class assignments, per-class spike proportions, and per-class firing rates.
    """
    # Accept plain lists too; comparing a list with a scalar would otherwise
    # match nothing and silently leave every rate at zero.
    labels = np.asarray(labels)
    n_neurons = spikes.shape[1]

    if rates is None:
        n_labels = len(unique_labels)
        rates = np.zeros((n_neurons, n_labels)).astype(np.float32)

    for i, label in enumerate(unique_labels):
        # Samples carrying this label.  The count is a plain int, so it
        # cannot wrap around for large data sets.
        indices = np.where(labels == label)[0]
        n_labeled = len(indices)

        if n_labeled > 0:
            # Mean spike count of each neuron for this label, accumulated on
            # top of the (alpha-weighted) rates from earlier calls.
            rates[:, i] = alpha*rates[:, i] + (np.sum(spikes[indices], axis=0)/n_labeled)

    sum_rate = np.sum(rates, axis=1)
    sum_rate[sum_rate == 0] = 1  # silent neurons: avoid dividing by zero
    # Share of each neuron's activity that falls on each label.
    proportions = rates / np.expand_dims(sum_rate, 1)  # (n_neurons, n_labels)
    proportions[np.isnan(proportions)] = 0

    # Every neuron gets the label it responds to most strongly.
    mapped_labels = np.argmax(proportions, axis=1)  # (n_neurons,)
    assignments = np.array(unique_labels)[mapped_labels]
    return assignments, proportions, rates

# Predict sample labels from the neuron labels produced by assign_labels
def prediction(spikes, assignments, unique_labels, tie_label=0):
    """
    Classify data with the label with highest average spiking activity over all neurons.

    Args:
        spikes  (n_samples, n_neurons) : A layer's spiking activity.
        assignments (n_neurons,) : Neuron label assignments.
        unique_labels (sequence) : The distinct target labels.
        tie_label (int): Label returned for a sample whose highest rate is
                         shared by more than one label (this includes samples
                         that evoked no spikes at all).
    return: Predictions (n_samples,) as ``uint8``.
    """
    assignments = np.asarray(assignments)
    n_samples = spikes.shape[0]
    n_labels = len(unique_labels)

    # Mean spike count per sample of the neurons assigned to each label.
    rates = np.zeros((n_samples, n_labels)).astype(np.float32)

    for i, label in enumerate(unique_labels):
        # Neurons assigned to this label.  The count is a plain int, so it
        # cannot wrap around when 256 or more neurons share a label.
        indices = np.where(assignments == label)[0]
        n_assigns = len(indices)

        if n_assigns > 0:
            rates[:, i] = np.sum(spikes[:, indices], axis=1) / n_assigns

    predicted_indices = np.argmax(rates, axis=1)

    # Ambiguous samples are sent to an extra slot placed after the real labels.
    max_rates = np.max(rates, axis=1)
    tied = np.sum(rates == max_rates[:, np.newaxis], axis=1) > 1
    predicted_indices[tied] = n_labels

    label_table = np.array(list(unique_labels) + [tie_label])
    predicted_labels = label_table[predicted_indices]

    # NOTE(review): the uint8 return type is kept for compatibility; labels
    # outside 0..255 would not survive this cast.
    return predicted_labels.astype(np.uint8)  # (n_samples, )


#################
####  Model  ####
#################

class DiehlAndCook2015Network:
    """Input -> excitatory -> inhibitory spiking network with online STDP.

    Calling the instance advances the whole network by one time step and
    returns the spikes of the excitatory layer.
    """
    #exc_neurons: exc_neurons_type
    #inh_neurons: inh_neurons_type
    #input_synapse:input_synapse_type
    #exc_synapse:exc_synapse_type
    #inh_synapse: inh_synapse_type
    #input_synaptictrace:  input_synaptictrace_type
    #exc_synaptictrace:exc_synaptictrace_type
    #input_conn:input_conn_type 
    #delay_input :delay_input_type 
    #delay_exc2inh:delay_exc2inh_type
    
    def __init__(self, n_in=4, n_neurons=100, wexc=2.25, winh=0.875,
                 dt=1e-3, wmin=0.0, wmax=5e-2, lr=(1e-2, 1e-4),
                 update_nt=100, norm = 0.1):
        """
        Network of Diehl and Cooks (2015)
        https://www.frontiersin.org/articles/10.3389/fncom.2015.00099/full

        Args:
            n_in: Number of input neurons. Matches the 1D size of the input data.
            n_neurons: Number of excitatory, inhibitory neurons.
            wexc: Strength of synapse weights from excitatory to inhibitory layer.
            winh: Strength of synapse weights from inhibitory to excitatory layer.
            dt: Simulation time step.
            lr: Pair of learning rates; the first scales the potentiation term
                and the second the depression term of the STDP update.
            wmin: Minimum allowed weight on input to excitatory synapses.
            wmax: Maximum allowed weight on input to excitatory synapses.
            update_nt: Number of recorded time steps between weight updates.
            norm: Value each neuron's summed absolute input weight is rescaled
                to before every weight update.
        """

        self.dt = dt
        self.lr_p, self.lr_m = lr
        self.wmax = wmax
        self.wmin = wmin
        # Echo the learning rates and weight bounds to stdout.
        print(self.lr_p)
        print(self.lr_m)
        print(self.wmin)
        print(self.wmax)


        # Neurons
        self.exc_neurons = DiehlAndCook2015LIF(n_neurons, dt=dt, tref=5e-3,
                                               tc_m=1e-1,
                                               vrest=-65, vreset=-65,
                                               init_vthr=-52,
                                               vpeak=20, theta_plus=0.05,
                                               theta_max=35,
                                               tc_theta=1e4,
                                               e_exc=0, e_inh=-100)

        self.inh_neurons = ConductanceBasedLIF(n_neurons, dt=dt, tref=2e-3,
                                               tc_m=1e-2,
                                               vrest=-60, vreset=-45,
                                               vthr=-40, vpeak=20,
                                               e_exc=0, e_inh=-85)
        # Synapses
        self.input_synapse = SingleExponentialSynapse(n_in, dt=dt, td=1e-3)
        self.exc_synapse = SingleExponentialSynapse(n_neurons, dt=dt, td=1e-3)
        self.inh_synapse = SingleExponentialSynapse(n_neurons, dt=dt, td=2e-3)

        # Slower filters of the same spike trains; their outputs are the
        # traces used by the STDP rule.
        self.input_synaptictrace = SingleExponentialSynapse(n_in, dt=dt,
                                                            td=2e-2)
        self.exc_synaptictrace = SingleExponentialSynapse(n_neurons, dt=dt,
                                                          td=2e-2)

        # Connections
        initW = 1e-3*np.random.rand(n_neurons, n_in)
        self.input_conn = FullConnection(n_in, n_neurons,
                                         initW=initW)
        # Excitatory -> inhibitory: one-to-one (scaled identity matrix).
        self.exc2inh_W = wexc*np.eye(n_neurons)
        # Inhibitory -> excitatory: every neuron except the paired one,
        # with the total strength winh shared across n_neurons-1 targets.
        self.inh2exc_W = (winh/(n_neurons-1))*(np.ones((n_neurons, n_neurons)) - np.eye(n_neurons))

        self.delay_input = DelayConnection(N=n_neurons, delay=5e-3, dt=dt)
        self.delay_exc2inh = DelayConnection(N=n_neurons, delay=2e-3, dt=dt)

        self.norm = norm
        self.g_inh = np.zeros(n_neurons)  # inhibitory input computed in the previous step
        self.tcount = 0                   # steps recorded since the last weight update
        self.update_nt = update_nt
        self.n_neurons = n_neurons
        self.n_in = n_in
        # Recording buffers for the STDP update: spikes (s_*) and traces (x_*).
        self.s_in_ = np.zeros((self.update_nt, n_in))
        self.s_exc_ = np.zeros((n_neurons, self.update_nt))
        self.x_in_ = np.zeros((self.update_nt, n_in))
        self.x_exc_ = np.zeros((n_neurons, self.update_nt))

    # Reset the recorded spikes and spike traces
    def reset_trace(self):
        """Clear the STDP recording buffers and restart the step counter."""
        self.s_in_ = np.zeros((self.update_nt, self.n_in))
        self.s_exc_ = np.zeros((self.n_neurons, self.update_nt))
        self.x_in_ = np.zeros((self.update_nt, self.n_in))
        self.x_exc_ = np.zeros((self.n_neurons, self.update_nt))
        self.tcount = 0

    # Initialise the dynamic state
    def initialize_states(self):
        """Reset both neuron layers, both delay lines and the three current synapses.

        The two STDP trace filters, ``g_inh`` and the input weights are not
        touched here.
        """
        self.exc_neurons.initialize_states()
        self.inh_neurons.initialize_states()
        self.delay_input.initialize_states()
        self.delay_exc2inh.initialize_states()
        self.input_synapse.initialize_states()
        self.exc_synapse.initialize_states()
        self.inh_synapse.initialize_states()

    def __call__(self, s_in, stdp=True):
        """Advance the network by one time step.

        Args:
            s_in: Input spike vector of length ``n_in`` for this step.
            stdp: When True the step is recorded, and once ``update_nt``
                steps have been recorded the input weights are updated.

        Returns:
            Spike vector of the excitatory layer for this step.
        """
        # Input layer
        c_in = self.input_synapse.func(s_in) # filtered input current (n_in values)
        x_in = self.input_synaptictrace.func(s_in)  # input spike trace for STDP (n_in values)
        g_in = self.input_conn.func(c_in)  # apply the input weights

        # Excitatory layer (inhibition comes from the previous step)
        s_exc = self.exc_neurons.func(self.delay_input.func(g_in), self.g_inh)
        c_exc = self.exc_synapse.func(s_exc) # filtered excitatory output (n_neurons values)
        g_exc = np.dot(self.exc2inh_W, c_exc) 
        x_exc = self.exc_synaptictrace.func(s_exc) # excitatory spike trace for STDP (n_neurons values)

        # Inhibitory layer
        s_inh = self.inh_neurons.func(self.delay_exc2inh.func(g_exc), 0)
        c_inh = self.inh_synapse.func(s_inh) # filtered inhibitory output (n_neurons values)
        self.g_inh = np.dot(self.inh2exc_W, c_inh)

        if stdp:
            # Record the spike trains and spike traces of this step
            self.s_in_[self.tcount] = s_in
            self.s_exc_[:, self.tcount] = s_exc
            self.x_in_[self.tcount] = x_in
            self.x_exc_[:, self.tcount] = x_exc
            self.tcount += 1

            # Online STDP
            if self.tcount == self.update_nt: # the recording buffers are full
                W = np.copy(self.input_conn.W)

                # Rescale so that the weights projecting onto each post neuron sum to self.norm
                W_abs_sum = np.expand_dims(np.sum(np.abs(W), axis=1), 1)
                W_abs_sum[W_abs_sum == 0] = 1.0
                W *= self.norm / W_abs_sum

                # STDP rule: potentiation pairs post spikes with the input trace,
                # depression pairs the post trace with input spikes
                dW = self.lr_p*(self.wmax - W)*np.dot(self.s_exc_, self.x_in_)
                dW -= self.lr_m*W*np.dot(self.x_exc_, self.s_in_)
                # Average over the window and bound the change of each weight
                clipped_dW = np.clip(dW / self.update_nt, -1e-3, 1e-3)
                self.input_conn.W = np.clip(W + clipped_dW,
                                            self.wmin, self.wmax)
                self.reset_trace() # clear the recorded spikes and traces
                

        return s_exc


# Presentation protocol: each sample is injected for t_inj seconds and is
# followed by t_blank seconds without input.
dt = 2e-4      # simulation time step (sec)
t_inj = 2.0    # stimulus duration (sec)
t_blank = 0.5  # blank duration (sec)
# The same durations expressed in simulation steps.
nt_inj, nt_blank = (round(span/dt) for span in (t_inj, t_blank))

n_neurons = 20               # number of excitatory / inhibitory neurons
unique_labels = [1, 2]       # class labels
correct_lables = [1, 2, 0]   # class labels followed by the value used for ties
n_epoch = 10                 # number of training epochs

n_train = 180  # number of training samples
n_val = 60     # number of validation samples
n_test = 100
update_nt = nt_inj  # interval (in steps) between STDP weight updates



# Location of the raw recordings; formatted with (label, label, sample index).
_DATA_PATH_TEMPLATE = r"C:\Users\yshou\yasapy\Vibrationdata\sample{0}\sample{1}_{2}.txt"


def _encode_sample(signal, duration, lif_time):
    """Encode one signal into a (n_steps, 6) spike raster, one encoder per channel."""
    # The encoders are evaluated in channel order so that the global NumPy
    # RNG is consumed in a fixed sequence (poisson1 -> poisson2 -> poisson3).
    channels = (
        lif(signal, -45, duration, lif_time),
        lif3(signal, -40, duration, lif_time),
        poisson1(signal, duration, lif_time),
        poisson2(signal, duration, lif_time),
        lif(signal, -58, duration, lif_time),
        poisson3(signal, duration, lif_time),
    )
    raster = np.zeros([int(duration / lif_time), len(channels)])
    for ch, spike_train in enumerate(channels):
        raster[:, ch] = np.asarray(spike_train, dtype=int)
    return raster


def _load_dataset(label_values, sample_indices, duration, lif_time):
    """Load and encode every recording; returns a list of (raster, label) pairs."""
    dataset = []
    for label in label_values:
        for index in sample_indices:
            # Column 0 is parsed once and shared by all six encoders (none
            # of them modifies its input).
            signal = 100 * np.loadtxt(_DATA_PATH_TEMPLATE.format(label, label, index),
                                      usecols=[0], dtype='float')
            dataset.append((_encode_sample(signal, duration, lif_time), label))
    return dataset


def _present_sample(net, input_spikes, blank_input, stdp):
    """Drive ``net`` with one sample followed by the blank period.

    Returns the number of spikes each excitatory neuron emitted while the
    sample was being injected.
    """
    spike_list = [net(input_spikes[t], stdp=stdp) for t in range(nt_inj)]
    counts = np.sum(np.array(spike_list), axis=0)

    # Blank input; weights are never updated during this period.
    for _ in range(nt_blank):
        net(blank_input, stdp=False)

    return counts


if __name__ == '__main__':
    duration = 2000  # ms
    lif_time = 0.2  # time step

    # Recordings 31-120 of each class are used for training, 1-30 for validation.
    data_set = _load_dataset([1, 2], range(31, 121), duration, lif_time)
    t_set = _load_dataset([1, 2], range(1, 31), duration, lif_time)

    # Shuffle reproducibly and dump the per-channel spike counts for inspection.
    random.seed(0)
    spikes_list = random.sample(data_set, len(data_set))
    with open('spikes_list_me.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        for spik in spikes_list:
            writer.writerow([np.sum(spik[0], axis=0)] + [spik[1]])
    random.seed(0)
    spikes_list_val = random.sample(t_set, len(t_set))

    count1 = np.sum(spikes_list[1][0] == 0)
    count2 = np.sum(spikes_list[2][0] == 1)
    count3 = np.sum(spikes_list[3][0] == 1)

    print(count1)
    print(count2)
    print(count3)

    labels = np.array([spikes_list[i][1] for i in range(n_train)])  # training labels
    test_labels = np.array([spikes_list_val[i][1] for i in range(n_val)])
    results_save_dir = "./redclassiffication_results/"  # directory for saved results
    os.makedirs(results_save_dir, exist_ok=True)

    network = DiehlAndCook2015Network(n_in=6, n_neurons=n_neurons,
                                      wexc=para_list[0], winh=para_list[1],
                                      dt=dt, wmin=0.0, wmax=10.0,
                                      lr=(para_list[2], para_list[3]),
                                      update_nt=update_nt, norm=para_list[4])

    network.initialize_states()
    accuracy_all = np.zeros(n_epoch)       # training accuracy per epoch
    accuracy_test_all = np.zeros(n_epoch)  # validation accuracy per epoch
    # uint16 counters: a neuron can fire more than 255 times per sample.
    spikes = np.zeros((n_train, n_neurons)).astype(np.uint16)
    print(spikes)
    blank_input = np.zeros(6)  # blank input
    print(spikes_list_val[5][1])
    rows = len(spikes_list_val[5][0])  # time steps per sample
    print(rows)

    rates = None  # label statistics accumulated by assign_labels across epochs

    for epoch in range(n_epoch):
        # Training pass: STDP is active while the samples are presented.
        for i in tqdm(range(n_train)):
            spikes[i] = _present_sample(network, spikes_list[i][0],
                                        blank_input, stdp=True)

        # Assign every neuron to a label
        assignments, proportions, rates = assign_labels(spikes, labels,
                                                        unique_labels, rates)
        if epoch == 0:
            print(assignments)
            print(proportions)
        print(rates)
        print("Assignments:\n", assignments)
        print(labels)

        # Sanity check of the firing activity
        sum_nspikes = np.sum(spikes, axis=1)
        mean_nspikes = np.mean(sum_nspikes).astype(np.float16)
        print("Ave. spikes:", mean_nspikes)
        print("Min. spikes:", sum_nspikes.min())
        print("Max. spikes:", sum_nspikes.max())

        # Predict the labels of the training samples
        predicted_labels = prediction(spikes, assignments, unique_labels)
        print('prediction:\n', predicted_labels)

        # Training accuracy
        accuracy = np.mean(np.where(labels == predicted_labels, 1, 0)).astype(np.float16)
        print("epoch :", epoch, " accuracy :", accuracy)
        print("wights", network.input_conn.W)
        if accuracy < 0.5:
            # Give up on this run when training is worse than chance.
            break
        accuracy_all[epoch] = accuracy

        # Learning-rate decay
        network.lr_p *= 0.25
        network.lr_m *= 0.25

        # Save the learned state of this epoch
        np.save(results_save_dir+"weight_epoch"+str(epoch)+".npy", network.input_conn.W)
        np.save(results_save_dir+"assignment_epoch"+str(epoch)+".npy", assignments)
        np.save(results_save_dir+"exc_neurons_epoch"+str(epoch)+".npy", network.exc_neurons.theta)

        # Validation pass on a fresh network.  It uses the same lateral
        # connection strengths as the trained network; previously the
        # constructor defaults were used here, so the two networks differed.
        network_test = DiehlAndCook2015Network(n_in=6, n_neurons=n_neurons,
                                               wexc=para_list[0], winh=para_list[1],
                                               dt=dt)
        network_test.initialize_states()

        network_test.input_conn.W = network.input_conn.W
        network_test.exc_neurons.theta = network.exc_neurons.theta
        network_test.exc_neurons.theta_plus = 0  # keep the thresholds from rising

        spikes_val = np.zeros((n_val, n_neurons)).astype(np.uint16)
        blank_input_test = np.zeros(6)  # blank input

        for i in tqdm(range(n_val)):
            spikes_val[i] = _present_sample(network_test, spikes_list_val[i][0],
                                            blank_input_test, stdp=False)

        # Predict the labels of the validation samples
        predicted_labels_val = prediction(spikes_val, assignments, unique_labels)
        print('Val prediction:\n', predicted_labels_val)

        # Validation accuracy
        accuracy_test = np.mean(np.where(test_labels == predicted_labels_val, 1, 0)).astype(np.float16)
        print('Val accuracy : ', accuracy_test)
        accuracy_test_all[epoch] = accuracy_test

    # Per-epoch summaries; epochs that never ran remain 0.
    print(accuracy_all)
    print(accuracy_test_all)






59923
76
92
0.0312611326044722
0.03592142258355323
0.0
10.0
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
1
10000


100%|██████████| 180/180 [01:17<00:00,  2.34it/s]


20
[[ 0  0  0 ...  0  0  0]
 [31 33 27 ... 19 13 14]
 [39 13  5 ... 10  1  2]
 ...
 [ 7  7  8 ...  7  7  7]
 [ 9  4  6 ...  7  4  7]
 [ 7  6  5 ...  6  6  5]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[39 13  5 ... 10  1  2]
 [39 17  7 ... 19  5  9]
 [28  6  6 ... 11  2  2]
 ...
 [ 8  3  4 ...  4  3  4]
 [ 6  3  5 ...  5  3  6]
 [ 9  4  6 ...  7  4  7]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:22<00:00,  2.61it/s]


60
[2 2 2 2 2 1 2 2 2 2 1 1 2 2 2 2 1 2 2 2]
4
[ 5 10 11 16]
[2 2 2 2 2 1 2 2 2 2 1 1 2 2 2 2 1 2 2 2]
16
[ 0  1  2  3  4  6  7  8  9 12 13 14 15 17 18 19]
[[ 2.25    5.375 ]
 [ 5.75    3.0625]
 [ 7.75    9.875 ]
 [ 4.75    7.0625]
 [ 7.5     6.0625]
 [10.5     6.375 ]
 [ 8.75    7.6875]
 [ 4.25    9.75  ]
 [ 5.75    8.25  ]
 [ 5.25    3.625 ]
 [ 6.25    7.4375]
 [ 6.     11.875 ]
 [ 8.5     7.125 ]
 [ 5.25   10.1875]
 [ 6.25    6.    ]
 [ 3.5     3.375 ]
 [ 9.5     6.0625]
 [ 6.5     2.875 ]
 [ 5.25   12.0625]
 [10.      5.875 ]
 [ 5.25   12.    ]
 [ 5.5    10.4375]
 [ 2.25    4.625 ]
 [ 5.75    4.75  ]
 [ 4.75    9.0625]
 [ 3.75   11.625 ]
 [ 9.      6.75  ]
 [ 7.      7.0625]
 [ 5.25    9.    ]
 [ 4.25    2.6875]
 [10.      5.9375]
 [ 9.75    4.3125]
 [ 5.25    8.125 ]
 [ 6.      3.0625]
 [ 2.75    4.125 ]
 [ 4.5     4.75  ]
 [ 6.5    12.3125]
 [ 7.      9.875 ]
 [ 6.      1.75  ]
 [ 6.5     3.5   ]
 [ 6.5    11.125 ]
 [ 4.5     8.125 ]
 [ 9.25    3.3125]
 [10.25    3.8125]
 [ 3.5  

100%|██████████| 180/180 [01:13<00:00,  2.43it/s]


20
[[12 14 14 ... 14 15 13]
 [11 11 11 ... 12 13 12]
 [ 7  5  7 ...  7  4  7]
 ...
 [13 17 18 ... 18 17 18]
 [ 8  4  4 ...  4  4  4]
 [ 7  7  7 ...  7  7  7]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[ 7  5  7 ...  7  4  7]
 [12  3  6 ...  7  3  7]
 [ 5  2  3 ...  3  2  3]
 ...
 [ 7  8  8 ...  8  8  8]
 [ 4  3  3 ...  3  3  3]
 [ 8  4  4 ...  4  4  4]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 15

100%|██████████| 60/60 [00:21<00:00,  2.75it/s]


60
[2 2 2 2 2 2 2 2 2 2 1 1 2 2 2 2 1 2 2 2]
3
[10 11 16]
[2 2 2 2 2 2 2 2 2 2 1 1 2 2 2 2 1 2 2 2]
17
[ 0  1  2  3  4  5  6  7  8  9 12 13 14 15 17 18 19]
[[ 4.6666665  7.       ]
 [ 5.3333335  2.1764705]
 [ 7.6666665 13.411765 ]
 [ 7.         9.941176 ]
 [ 7.         7.352941 ]
 [ 9.666667   6.2352943]
 [ 6.6666665  5.1764708]
 [ 7.6666665 11.       ]
 [ 6.3333335 10.       ]
 [ 6.3333335  4.9411764]
 [ 7.6666665  7.117647 ]
 [ 7.3333335 14.411765 ]
 [ 9.333333   7.117647 ]
 [ 7.        11.294118 ]
 [ 8.         6.1764708]
 [ 2.6666667  4.117647 ]
 [ 8.666667   3.3529413]
 [ 6.         3.4705882]
 [ 7.3333335 12.941176 ]
 [ 8.333333   5.       ]
 [ 8.333333  13.941176 ]
 [ 9.        11.       ]
 [ 4.3333335  6.0588236]
 [ 4.3333335  4.117647 ]
 [ 8.        11.058824 ]
 [ 7.6666665 15.411765 ]
 [ 8.666667   3.4117646]
 [ 9.         9.882353 ]
 [ 6.6666665 11.       ]
 [ 4.6666665  2.1176472]
 [ 8.         4.1764708]
 [ 7.         5.117647 ]
 [ 5.6666665 11.823529 ]
 [ 5.         1.235

100%|██████████| 180/180 [01:11<00:00,  2.51it/s]


20
[[14 16 17 ... 17 16 17]
 [18 21 21 ... 21 21 20]
 [ 5  3  3 ...  3  3  3]
 ...
 [14 14 14 ... 14 14 14]
 [ 3  3  3 ...  3  3  3]
 [ 8  8  8 ...  8  8  8]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[5 3 3 ... 3 3 3]
 [7 3 3 ... 3 3 3]
 [3 5 5 ... 5 5 5]
 ...
 [4 4 4 ... 4 4 4]
 [3 3 3 ... 3 3 3]
 [3 3 3 ... 3 3 3]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 151 153 162 163 165 167 168 170 171 17

100%|██████████| 60/60 [00:22<00:00,  2.69it/s]


60
[2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 1 2 2 2]
2
[10 16]
[2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 1 2 2 2]
18
[ 0  1  2  3  4  5  6  7  8  9 11 12 13 14 15 17 18 19]
[[ 1.5        7.6666665]
 [ 7.         2.8888888]
 [ 1.5       16.333334 ]
 [ 3.5        9.666667 ]
 [ 6.         5.7222223]
 [10.5        8.666667 ]
 [ 7.5        4.888889 ]
 [ 1.        10.722222 ]
 [ 1.         8.555555 ]
 [ 6.5        4.8333335]
 [ 6.         5.7222223]
 [ 2.5       16.277779 ]
 [ 6.5        6.611111 ]
 [ 2.        10.555555 ]
 [ 8.5        6.7222223]
 [ 2.5        6.7222223]
 [10.         2.9444444]
 [ 6.5        2.8333333]
 [ 4.        13.444445 ]
 [ 9.         5.8333335]
 [ 1.5       12.555555 ]
 [ 4.        13.555555 ]
 [ 2.        10.277778 ]
 [ 8.         3.8888888]
 [ 3.5       11.611111 ]
 [ 3.5       15.555555 ]
 [ 8.5        2.8333333]
 [ 8.5        6.8333335]
 [ 2.        10.5      ]
 [ 6.         2.9444444]
 [10.         5.1666665]
 [ 7.5        2.8888888]
 [ 4.        13.111111 ]
 [ 7.         0.944

100%|██████████| 180/180 [01:14<00:00,  2.42it/s]


20
[[15 15 15 ... 15 15 15]
 [15 15 15 ... 15 15 15]
 [ 3  3  3 ...  3  3  3]
 ...
 [14 14 14 ... 14 14 14]
 [ 3  3  3 ...  3  3  3]
 [ 8  8  8 ...  8  8  8]]
[  2   3   7  11  13  14  16  18  21  22  23  24  25  31  32  38  40  41
  43  45  46  47  48  49  50  51  52  54  55  56  58  60  62  63  64  69
  72  76  80  81  83  84  89  90  92  93  94  97 101 102 106 107 110 111
 113 114 117 118 120 121 122 123 124 127 129 132 133 134 136 139 141 142
 147 152 154 155 156 157 158 159 160 161 164 166 169 172 173 175 176 178]
[[3 3 3 ... 3 3 3]
 [4 4 4 ... 4 4 4]
 [3 3 3 ... 3 3 3]
 ...
 [5 5 6 ... 6 5 6]
 [3 3 3 ... 3 3 3]
 [3 3 3 ... 3 3 3]]
[  0   1   4   5   6   8   9  10  12  15  17  19  20  26  27  28  29  30
  33  34  35  36  37  39  42  44  53  57  59  61  65  66  67  68  70  71
  73  74  75  77  78  79  82  85  86  87  88  91  95  96  98  99 100 103
 104 105 108 109 112 115 116 119 125 126 128 130 131 135 137 138 140 143
 144 145 146 148 149 150 151 153 162 163 165 167 168 170 171 17

100%|██████████| 60/60 [00:22<00:00,  2.61it/s]

60
[2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 1 2 2 2]
2
[10 16]
[2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 1 2 2 2]
18
[ 0  1  2  3  4  5  6  7  8  9 11 12 13 14 15 17 18 19]
[[ 3.         8.944445 ]
 [ 9.         1.9444444]
 [ 0.5       16.222221 ]
 [ 1.         9.5      ]
 [ 6.5        5.7777777]
 [11.5        8.5      ]
 [10.         4.888889 ]
 [ 3.5       11.555555 ]
 [ 1.5        8.555555 ]
 [ 8.5        4.7777777]
 [ 7.5        6.6666665]
 [ 1.5       15.333333 ]
 [ 8.         6.6666665]
 [ 1.        11.277778 ]
 [10.5        5.7777777]
 [ 2.5        7.611111 ]
 [10.         2.8888888]
 [ 8.         2.8333333]
 [ 2.5       11.388889 ]
 [ 9.5        4.7777777]
 [ 0.5       12.5      ]
 [ 3.        13.444445 ]
 [ 0.5        8.5      ]
 [ 9.         2.8333333]
 [ 0.5       11.388889 ]
 [ 1.        15.333333 ]
 [11.         2.8333333]
 [10.5        5.7777777]
 [ 1.5       10.388889 ]
 [ 8.         2.9444444]
 [12.5        4.7222223]
 [13.         3.9444444]
 [ 1.5       12.333333 ]
 [ 8.5        0.944




In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Figure-wide matplotlib styling: Times New Roman at size 25, ticks drawn inward.
# (alternative font family that was tried: 'Hiragino Mincho ProN')
plt.rcParams.update({
    "font.family": 'Times New Roman',
    "font.size": 25,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
})

# Final plot: training and validation accuracy per epoch on the same axes.
epochs = range(n_epoch)
curves = (
    (accuracy_all, 'o', 'Training Accuracy'),
    (accuracy_test_all, 's', 'Validation Accuracy'),
)
for values, marker_style, curve_label in curves:
    plt.plot(epochs, values, marker=marker_style, label=curve_label)

plt.xticks(epochs)       # one tick per epoch index
plt.ylim(0, 1)           # fixed y-range of 0 to 1
plt.legend(fontsize=12)  # legend text smaller than the global font size of 25
plt.grid(True)
plt.show()