# In [None]:
from numba import njit
import numpy as np
from numpy import exp, zeros
import matplotlib.pyplot as plt
from scipy.integrate import odeint

def neuron_model(t, dt, p):
    """Simulate a four-compartment neuron model with forward-Euler integration.

    Compartment indices: 0 = soma, 1 = basal dendrite, 2 = proximal apical
    dendrite, 3 = distal apical dendrite. The basal and proximal dendrites are
    coupled to the soma; the distal dendrite is coupled to the proximal one.

    Not decorated with ``numba.njit``: ``p`` is a plain dict that mixes
    callables and floats, which numba's nopython mode cannot type.

    Parameters
    ----------
    t : float
        Total simulated time (same unit as ``dt``).
    dt : float
        Integration step.
    p : dict
        Model parameters. Must contain:

        * gating callables, each called with the length-4 voltage array and
          returning a scalar or a length-4 array: ``alpha_m``, ``beta_m``,
          ``alpha_h``, ``beta_h``, ``alpha_n``, ``beta_n``,
          ``m_Ca_infinito``, ``tau_m_Ca``, ``h_Ca_infinito``, ``tau_h_Ca``,
          ``alpha_m_NaP``, ``beta_m_NaP``, ``h_NaP_infinito``, ``tau_h_NaP``,
          ``m_KS_infinito``, ``tau_m_KS``, ``h_KS_infinito``, ``tau_h_KS``,
          ``alpha_c_Ca``, ``beta_c_Ca``;
        * scalars: ``V_rest``, ``C_m``, ``g_Na``, ``g_K``, ``g_leak``,
          ``g_Ca``, ``g_NaP``, ``g_KS``, ``g_C``, ``E_Na``, ``E_K``,
          ``E_leak``, ``E_Ca``, ``E_C``, ``R_a``, ``A_dend_basal``,
          ``A_dend_proximal``, ``A_dend_distal``, ``Ca_rest``, ``phi_soma``,
          ``phi_dendrites``, ``Vol_soma``, ``Vol_basal``, ``Vol_proximal``,
          ``Vol_distal``, ``tau_Ca_soma``, ``tau_Ca_basal``,
          ``tau_Ca_proximal``, ``tau_Ca_distal``;
        * optionally ``F`` (Faraday constant, default 96485.3329 C/mol).

    Returns
    -------
    tuple of four 1-D numpy arrays
        Membrane-potential traces of the soma, basal, proximal and distal
        compartments; one sample per step, ``round(t / dt)`` samples each.
    """
    # Faraday constant (C/mol); p may override it.
    faraday = p.get('F', 96485.3329)

    n_comp = 4

    # Gating variables, one entry per compartment.
    m = np.zeros(n_comp)
    h = np.zeros(n_comp)
    n = np.zeros(n_comp)
    m_Ca = np.zeros(n_comp)
    h_Ca = np.ones(n_comp)
    m_NaP = np.zeros(n_comp)
    h_NaP = np.ones(n_comp)
    m_KS = np.zeros(n_comp)
    h_KS = np.ones(n_comp)
    c_Ca = np.zeros(n_comp)

    # State: membrane potential (all compartments start at rest) and calcium.
    V = np.full(n_comp, p['V_rest'], dtype=float)
    Ca = np.zeros(n_comp)

    # Per-compartment calcium parameters: the soma has its own phi, the three
    # dendrites share phi_dendrites.
    ca_phi = np.array([p['phi_soma'], p['phi_dendrites'],
                       p['phi_dendrites'], p['phi_dendrites']])
    ca_vol = np.array([p['Vol_soma'], p['Vol_basal'],
                       p['Vol_proximal'], p['Vol_distal']])
    ca_tau = np.array([p['tau_Ca_soma'], p['tau_Ca_basal'],
                       p['tau_Ca_proximal'], p['tau_Ca_distal']])

    # Axial resistances of the three dendritic links (loop-invariant).
    r_basal = p['R_a'] * p['A_dend_basal']
    r_proximal = p['R_a'] * p['A_dend_proximal']
    r_distal = p['R_a'] * p['A_dend_distal']

    # round() so that e.g. t=0.3, dt=0.1 gives 3 steps rather than 2.
    n_steps = int(round(t / dt))
    traces = np.zeros((n_comp, n_steps))

    for i in range(n_steps):
        # Gating kinetics, all evaluated at the voltage from the start of the step.
        m += (p['alpha_m'](V) * (1 - m) - p['beta_m'](V) * m) * dt
        h += (p['alpha_h'](V) * (1 - h) - p['beta_h'](V) * h) * dt
        n += (p['alpha_n'](V) * (1 - n) - p['beta_n'](V) * n) * dt
        m_Ca += (p['m_Ca_infinito'](V) - m_Ca) / p['tau_m_Ca'](V) * dt
        h_Ca += (p['h_Ca_infinito'](V) - h_Ca) / p['tau_h_Ca'](V) * dt
        m_NaP += (p['alpha_m_NaP'](V) * (1 - m_NaP) - p['beta_m_NaP'](V) * m_NaP) * dt
        h_NaP += (p['h_NaP_infinito'](V) - h_NaP) / p['tau_h_NaP'](V) * dt
        m_KS += (p['m_KS_infinito'](V) - m_KS) / p['tau_m_KS'](V) * dt
        h_KS += (p['h_KS_infinito'](V) - h_KS) / p['tau_h_KS'](V) * dt
        # Gate c of the I_C current.
        c_Ca += (p['alpha_c_Ca'](V) * (1 - c_Ca) - p['beta_c_Ca'](V) * c_Ca) * dt

        # Ionic currents, per compartment.
        I_Na = p['g_Na'] * m**3 * h * (V - p['E_Na'])
        I_K = p['g_K'] * n**4 * (V - p['E_K'])
        I_leak = p['g_leak'] * (V - p['E_leak'])
        I_Ca = p['g_Ca'] * m_Ca**2 * h_Ca * (V - p['E_Ca'])
        I_NaP = p['g_NaP'] * m_NaP * h_NaP * (V - p['E_Na'])
        I_KS = p['g_KS'] * m_KS * h_KS * (V - p['E_K'])
        I_C = p['g_C'] * c_Ca**2 * (V - p['E_C'])
        I_ion = I_Na + I_K + I_leak + I_Ca + I_NaP + I_KS + I_C

        # Axial currents, each defined as flowing from a child compartment into
        # its parent: the parent gains it, the child loses it.
        I_basal = (V[1] - V[0]) / r_basal        # basal -> soma
        I_proximal = (V[2] - V[0]) / r_proximal  # proximal -> soma
        I_distal = (V[3] - V[2]) / r_distal      # distal -> proximal
        I_axial = np.array([
            I_basal + I_proximal,
            -I_basal,
            I_distal - I_proximal,
            -I_distal,
        ])

        # Simultaneous update of all membrane potentials.
        V += (I_axial - I_ion) / p['C_m'] * dt

        # Calcium: each compartment driven by its own I_Ca, relaxing to Ca_rest.
        Ca += (-ca_phi / (faraday * ca_vol) * I_Ca
               + (p['Ca_rest'] - Ca) / ca_tau) * dt

        traces[:, i] = V

    return traces[0], traces[1], traces[2], traces[3]


### Current functions (first kinetic set, alpha/beta form) ###

# Na
# J_Na: sodium current
@njit
def f_J_Na(V, m, h, g_Na, E_Na):
    """Na current: g_Na * m**3 * h * (V - E_Na)."""
    driving_force = V - E_Na
    return g_Na * m**3 * h * driving_force

@njit
def m_inf_Na(V):
    """Steady-state Na activation: alpha / (alpha + beta)."""
    a = alpha_Na_m(V)
    b = beta_Na_m(V)
    return a / (a + b)

@njit
def tau_m_Na(V):
    """Na activation time constant: 1 / (alpha + beta)."""
    total_rate = alpha_Na_m(V) + beta_Na_m(V)
    return 1 / total_rate

@njit
def alpha_Na_m(V):
    """Opening rate of the Na activation gate.

    The expression is 0/0 at exactly V = -28 (removable singularity).
    """
    x = V + 28
    return -0.2816 * x / (exp(-x / 9.3) - 1)

@njit
def beta_Na_m(V):
    """Closing rate of the Na activation gate (0/0 at exactly V = -1)."""
    x = V + 1
    return 0.2464 * x / (exp(x / 6) - 1)

@njit
def h_inf_Na(V):
    """Steady-state Na inactivation: alpha / (alpha + beta)."""
    a = alpha_Na_h(V)
    b = beta_Na_h(V)
    return a / (a + b)

@njit
def tau_h_Na(V):
    """Na inactivation time constant: 1 / (alpha + beta)."""
    total_rate = alpha_Na_h(V) + beta_Na_h(V)
    return 1 / total_rate

@njit
def alpha_Na_h(V):
    """Opening rate of the Na inactivation gate (exponential in V)."""
    decay = exp(-(V + 43.1) / 20)
    return 0.098 * decay

@njit
def beta_Na_h(V):
    """Closing rate of the Na inactivation gate (logistic in V)."""
    x = V + 13.1
    return 1.4 / (1 + exp(-x / 10))

# J_NaP: NaP sodium current
@njit
def f_J_NaP(V, m, h, g_NaP, E_Na):
    """NaP current: g_NaP * m * h * (V - E_Na)."""
    driving_force = V - E_Na
    return g_NaP * m * h * driving_force

@njit
def m_inf_NaP(V):
    """Steady-state NaP activation: alpha / (alpha + beta)."""
    a = alpha_NaP_m(V)
    b = beta_NaP_m(V)
    return a / (a + b)

@njit
def tau_m_NaP(V):
    """NaP activation time constant: 1 / (alpha + beta)."""
    total_rate = alpha_NaP_m(V) + beta_NaP_m(V)
    return 1 / total_rate

@njit
def alpha_NaP_m(V):
    """Opening rate of the NaP activation gate (0/0 at exactly V = -12)."""
    x = V + 12
    return -0.2816 * x / (exp(-x / 9.3) - 1)

@njit
def beta_NaP_m(V):
    """Closing rate of the NaP activation gate (0/0 at exactly V = 15)."""
    x = V - 15
    return 0.2464 * x / (exp(x / 6) - 1)

@njit
def h_inf_NaP(V):
    """Steady-state NaP inactivation: alpha / (alpha + beta)."""
    a = alpha_NaP_h(V)
    b = beta_NaP_h(V)
    return a / (a + b)

@njit
def tau_h_NaP(V):
    """NaP inactivation time constant: 1 / (alpha + beta)."""
    total_rate = alpha_NaP_h(V) + beta_NaP_h(V)
    return 1 / total_rate

@njit
def alpha_NaP_h(V):
    """Opening rate of the NaP inactivation gate (exponential in V)."""
    decay = exp(-(V + 42.8477) / 4.0248)
    return 2.8e-5 * decay

@njit
def beta_NaP_h(V):
    """Closing rate of the NaP inactivation gate (logistic in V)."""
    x = V - 413.9284
    return 0.02 / (1 + exp(-x / 148.2589))

# K
# DR: DR potassium current
@njit
def f_J_DR(V, n, g_DR, E_K):
    """DR potassium current: g_DR * n**4 * (V - E_K)."""
    return g_DR * n**4 * (V - E_K)

@njit
def n_inf_DR(V):
    """Steady-state DR activation: alpha / (alpha + beta)."""
    return alpha_DR_n(V)/(alpha_DR_n(V) + beta_DR_n(V))

@njit
def tau_n_DR(V):
    """DR activation time constant: 1 / (alpha + beta)."""
    return 1/(alpha_DR_n(V) + beta_DR_n(V))

@njit
def alpha_DR_n(V):
    """Opening rate of the DR activation gate n.

    Written as -k*x / (exp(-x/10) - 1) with x = V + 34.9, the same form as
    alpha_Na_m, so the rate is positive for every V. Without the leading minus
    sign the numerator and denominator always have opposite signs, making the
    rate negative and pushing n_inf_DR outside [0, 1]. The expression is 0/0
    at exactly V = -34.9 (removable singularity, limit 0.16).
    """
    return -0.016*(V + 34.9)/(-1 + exp(-(V + 34.9)/10))

@njit
def beta_DR_n(V):
    """Closing rate of the DR activation gate n (exponential in V)."""
    return 0.25*exp(-(V + 60)/80)

# KS: KS potassium current
@njit
def f_J_KS(V, m, g_KS, E_K):
    """KS potassium current: g_KS * m * (V - E_K)."""
    driving_force = V - E_K
    return g_KS * m * driving_force

@njit
def m_inf_KS(V):
    """Steady-state KS activation: alpha / (alpha + beta)."""
    a = alpha_KS_m(V)
    b = beta_KS_m(V)
    return a / (a + b)

@njit
def tau_m_KS(V):
    """KS activation time constant: 1 / (alpha + beta)."""
    total_rate = alpha_KS_m(V) + beta_KS_m(V)
    return 1 / total_rate

@njit
def alpha_KS_m(V):
    """Opening rate of the KS activation gate (logistic in V)."""
    x = V + 45
    return 0.02 / (1 + exp(-x / 12.4))

@njit
def beta_KS_m(V):
    """Closing rate of the KS activation gate (0/0 at exactly V = -50)."""
    y = -50 - V
    return 0.022 * y / (exp(y / 12.4) - 1)

# Ca: calcium current
@njit
def f_J_Ca(V, m, h, g_Ca, E_Ca):
    """Calcium current: g_Ca * m**2 * h * (V - E_Ca)."""
    return g_Ca*m**2*h*(V - E_Ca)

@njit
def m_inf_Ca(V):
    """Steady-state Ca activation: alpha / (alpha + beta)."""
    return alpha_Ca_m(V)/(alpha_Ca_m(V) + beta_Ca_m(V))

@njit
def tau_m_Ca(V):
    """Ca activation time constant: 1 / (alpha + beta)."""
    return 1/(alpha_Ca_m(V) + beta_Ca_m(V))

@njit
def alpha_Ca_m(V):
    """Opening rate of the Ca activation gate m.

    Written as -k*x / (exp(-x/9) - 1) with x = V + 30, the same form as
    alpha_Na_m, so the rate is positive for every V. Without the leading minus
    sign the numerator and denominator always have opposite signs, making the
    rate negative and pushing m_inf_Ca outside [0, 1]. The expression is 0/0
    at exactly V = -30 (removable singularity, limit 1.44e-3).
    """
    return -1.6e-4*(V + 30)/(-1 + exp(-(V + 30)/9))

@njit
def beta_Ca_m(V):
    """Closing rate of the Ca activation gate m (exponential in V)."""
    return 0.02*exp(-(V + 30)/30)

@njit
def h_inf_Ca(V):
    """Steady-state Ca inactivation: alpha / (alpha + beta)."""
    return alpha_Ca_h(V)/(alpha_Ca_h(V) + beta_Ca_h(V))

@njit
def tau_h_Ca(V):
    """Ca inactivation time constant: 1 / (alpha + beta)."""
    return 1/(alpha_Ca_h(V) + beta_Ca_h(V))

@njit
def alpha_Ca_h(V):
    """Opening rate of the Ca inactivation gate h (exponential in V)."""
    return 0.007*exp(-(V + 85)/67)

@njit
def beta_Ca_h(V):
    """Closing rate of the Ca inactivation gate h (logistic in V)."""
    return 0.03/(1 + exp(-(V + 17)/3))

# Synaptic delay
@njit
def D(x):
    """Step function: 0.01 when x > 2, otherwise 0."""
    if x > 2:
        return 0.01
    return 0

# Model parameters for neuron_model: gating-rate callables keyed by name.
# NOTE(review): neuron_model also reads keys that are absent here, among them
# 'V_rest', 'C_m', the conductances and reversal potentials ('g_Na', 'E_Na',
# 'g_leak', ...), 'h_Ca_infinito', 'tau_h_Ca', 'h_NaP_infinito', 'tau_h_NaP',
# 'm_KS_infinito', 'tau_m_KS', 'h_KS_infinito', 'tau_h_KS', 'alpha_c_Ca',
# 'beta_c_Ca', and the axial/calcium parameters ('R_a', 'Vol_soma', ...), so the
# call below cannot succeed as written. Conversely, 'alpha_h_Ca', 'beta_h_Ca',
# 'alpha_m_KS' and 'beta_m_KS' are never read by neuron_model.
p = {
    'alpha_m': alpha_Na_m,
    'beta_m': beta_Na_m,
    'alpha_h': alpha_Na_h,
    'beta_h': beta_Na_h,
    'alpha_n': alpha_DR_n,
    'beta_n': beta_DR_n,
    'm_Ca_infinito': m_inf_Ca,
    'tau_m_Ca': tau_m_Ca,
    'alpha_h_Ca': alpha_Ca_h,
    'beta_h_Ca': beta_Ca_h,
    'alpha_m_NaP': alpha_NaP_m,
    'beta_m_NaP': beta_NaP_m,
    'alpha_m_KS': alpha_KS_m,
    'beta_m_KS': beta_KS_m,
}

# Simulation parameters
t = 10.0  # Simulation duration (ms)
dt = 0.01  # Time step (ms)

# Run the model
V_soma, V_dend_basal, V_dend_proximal, V_dend_distal = neuron_model(t, dt, p)


# Activation and inactivation functions (second kinetic set, closed form).
# Each definition below rebinds a module-level name that was previously bound
# to the alpha/beta-based version defined earlier in this file.
@njit
def m_inf_Na(V):
    """Logistic steady-state Na activation (midpoint V = -40, slope 10)."""
    u = (V + 40) / 10
    return 1 / (1 + np.exp(-u))

@njit
def h_inf_Na(V):
    """Logistic steady-state Na inactivation (midpoint V = -65, slope 10)."""
    u = (V + 65) / 10
    return 1 / (1 + np.exp(u))

@njit
def tau_m_Na(V):
    """Na activation time constant 1 / (e^u + e^-u), u = (V + 40) / 10."""
    u = (V + 40) / 10
    return 1 / (np.exp(u) + np.exp(-u))

@njit
def tau_h_Na(V):
    """Na inactivation time constant 1 / (e^u + e^-u), u = (V + 65) / 10."""
    u = (V + 65) / 10
    return 1 / (np.exp(u) + np.exp(-u))

@njit
def m_inf_NaP(V):
    """Logistic steady-state NaP activation (midpoint V = -45, slope 7)."""
    u = (V + 45) / 7
    return 1 / (1 + np.exp(-u))

@njit
def tau_m_NaP(V):
    """NaP activation time constant: the constant 1, independent of V."""
    return 1

@njit
def n_inf_DR(V):
    """Logistic steady-state DR activation n (midpoint V = -35, slope 10)."""
    u = (V + 35) / 10
    return 1 / (1 + np.exp(-u))

@njit
def tau_n_DR(V):
    """DR activation time constant 1 / (e^u + e^-u), u = (V + 35) / 10."""
    u = (V + 35) / 10
    return 1 / (np.exp(u) + np.exp(-u))

@njit
def m_inf_DR(V):
    """Logistic steady-state of m_DR (midpoint V = -30, slope 10)."""
    u = (V + 30) / 10
    return 1 / (1 + np.exp(-u))

@njit
def tau_m_DR(V):
    """Time constant of m_DR: 1 / (e^u + e^-u), u = (V + 30) / 10."""
    u = (V + 30) / 10
    return 1 / (np.exp(u) + np.exp(-u))

@njit
def m_inf_KS(V):
    """Logistic steady-state KS activation (midpoint V = -25, slope 10)."""
    u = (V + 25) / 10
    return 1 / (1 + np.exp(-u))

@njit
def tau_m_KS(V):
    """KS activation time constant 1 / (e^u + e^-u), u = (V + 25) / 10."""
    u = (V + 25) / 10
    return 1 / (np.exp(u) + np.exp(-u))

@njit
def m_inf_Ca(V):
    """Logistic steady-state Ca activation (midpoint V = -20, slope 5)."""
    u = (V + 20) / 5
    return 1 / (1 + np.exp(-u))

@njit
def h_inf_Ca(V):
    """Logistic steady-state Ca inactivation (midpoint V = -45, slope 5)."""
    u = (V + 45) / 5
    return 1 / (1 + np.exp(u))

@njit
def tau_m_Ca(V):
    """Ca activation time constant 1 / (e^u + e^-u), u = (V + 20) / 5."""
    u = (V + 20) / 5
    return 1 / (np.exp(u) + np.exp(-u))

@njit
def tau_h_Ca(V):
    """Ca inactivation time constant 1 / (e^u + e^-u), u = (V + 45) / 5."""
    u = (V + 45) / 5
    return 1 / (np.exp(u) + np.exp(-u))

# Ionic currents; conductances and reversal potentials are looked up in p.
@njit
def I_Na(V, m_Na, h_Na, p):
    """Na current: p['g_Na'] * m_Na**3 * h_Na * (V - p['E_Na'])."""
    drive = V - p['E_Na']
    return p['g_Na'] * m_Na**3 * h_Na * drive

@njit
def I_NaP(V, m_NaP, p):
    """NaP current: p['g_NaP'] * m_NaP * (V - p['E_NaP'])."""
    drive = V - p['E_NaP']
    return p['g_NaP'] * m_NaP * drive

@njit
def I_DR(V, n, p):
    """DR potassium current: p['g_DR'] * n**4 * (V - p['E_K'])."""
    drive = V - p['E_K']
    return p['g_DR'] * n**4 * drive

@njit
def I_KS(V, m_KS, p):
    """KS potassium current: p['g_KS'] * m_KS * (V - p['E_K'])."""
    drive = V - p['E_K']
    return p['g_KS'] * m_KS * drive

@njit
def I_Ca(V, m_Ca, h_Ca, p):
    """Calcium current: p['g_Ca'] * m_Ca**2 * h_Ca * (V - p['E_Ca'])."""
    drive = V - p['E_Ca']
    return p['g_Ca'] * m_Ca**2 * h_Ca * drive

# Membrane-potential derivative
@njit
def dVdt(V, m_Na, h_Na, m_NaP, n, m_DR, m_KS, m_Ca, h_Ca, p):
    """Return -(I_Na + I_NaP + I_DR + I_KS + I_Ca + p['I_ext']) / p['C_m'].

    ``m_DR`` is accepted but not used, and no leak current is included.
    NOTE(review): ``I_ext`` is summed with the same sign as the ionic
    currents, so a positive I_ext lowers V — confirm this is intended.
    """
    total = I_Na(V, m_Na, h_Na, p)
    total += I_NaP(V, m_NaP, p)
    total += I_DR(V, n, p)
    total += I_KS(V, m_KS, p)
    total += I_Ca(V, m_Ca, h_Ca, p)
    return -(total + p['I_ext']) / p['C_m']

# Derivative of a generic first-order state variable
@njit
def dxdt(x, x_inf, tau_x):
    """Return (x_inf - x) / tau_x: relaxation of x toward x_inf with time constant tau_x."""
    gap = x_inf - x
    return gap / tau_x

# Right-hand side of the full single-compartment model
@njit
def model_eqs(V, m_Na, h_Na, m_NaP, n, m_DR, m_KS, m_Ca, h_Ca, p):
    """Return the time derivatives of the nine state variables.

    Order: V, m_Na, h_Na, m_NaP, n, m_DR, m_KS, m_Ca, h_Ca. This is the order
    in which solve_model_equations indexes the result.

    For scalar states the result is a flat length-9 array. np.vstack on
    scalars produced a (9, 1) array, so every k[i] in the RK4 stepper became a
    length-1 array and silently turned the scalar state into arrays. numba's
    np.vstack also only accepts a tuple of arrays.
    """
    return np.array([
        dVdt(V, m_Na, h_Na, m_NaP, n, m_DR, m_KS, m_Ca, h_Ca, p),
        dxdt(m_Na, m_inf_Na(V), tau_m_Na(V)),
        dxdt(h_Na, h_inf_Na(V), tau_h_Na(V)),
        dxdt(m_NaP, m_inf_NaP(V), tau_m_NaP(V)),
        dxdt(n, n_inf_DR(V), tau_n_DR(V)),
        dxdt(m_DR, m_inf_DR(V), tau_m_DR(V)),
        dxdt(m_KS, m_inf_KS(V), tau_m_KS(V)),
        dxdt(m_Ca, m_inf_Ca(V), tau_m_Ca(V)),
        dxdt(h_Ca, h_inf_Ca(V), tau_h_Ca(V)),
    ])

# Solve the model equations with the classic 4th-order Runge-Kutta method
@njit
def solve_model_equations(dt=0.01, sample_rate=100, t_total=100.0, p=None):
    """Integrate the single-compartment model with classic 4th-order Runge-Kutta.

    All gating variables start at their steady-state values for ``p['V_rest']``.

    Parameters
    ----------
    dt : float
        Integration time step (ms).
    sample_rate : int
        Recording stride: V is stored once every ``sample_rate`` steps. It is
        a decimation factor, not a frequency. Must be >= 1.
    t_total : float
        Total simulated time (ms).
    p : mapping
        Model parameters. Required despite the ``None`` default. Read here for
        ``'V_rest'`` and passed through to ``model_eqs``.
        NOTE(review): under @njit this presumably must be a numba-typed
        mapping (e.g. numba.typed.Dict) rather than a Python dict — confirm.

    Returns
    -------
    numpy.ndarray
        1-D array of ``ceil(n_steps / sample_rate)`` membrane potentials, with
        ``n_steps = int(t_total / dt)``. Entry ``k`` is V after the
        ``(k * sample_rate + 1)``-th RK4 step.

    Raises
    ------
    ValueError
        If ``sample_rate`` < 1.
    """
    if sample_rate < 1:
        raise ValueError("sample_rate must be >= 1")

    n_steps = int(t_total / dt)

    # Initial state: rest potential, gates at their steady state.
    V_rest = p['V_rest']
    V = V_rest
    m_Na = m_inf_Na(V_rest)
    h_Na = h_inf_Na(V_rest)
    m_NaP = m_inf_NaP(V_rest)
    n = n_inf_DR(V_rest)
    m_DR = m_inf_DR(V_rest)
    m_KS = m_inf_KS(V_rest)
    m_Ca = m_inf_Ca(V_rest)
    h_Ca = h_inf_Ca(V_rest)

    # Exactly one slot per recorded sample. An n_steps-long buffer would leave
    # all but the first 1/sample_rate entries as spurious zeros.
    V_out = np.zeros((n_steps + sample_rate - 1) // sample_rate)

    # Classic 4th-order Runge-Kutta.
    for it in range(n_steps):
        k1 = dt * model_eqs(V, m_Na, h_Na, m_NaP, n, m_DR, m_KS, m_Ca, h_Ca, p)
        k2 = dt * model_eqs(V + 0.5 * k1[0], m_Na + 0.5 * k1[1], h_Na + 0.5 * k1[2], m_NaP + 0.5 * k1[3],
                            n + 0.5 * k1[4], m_DR + 0.5 * k1[5], m_KS + 0.5 * k1[6], m_Ca + 0.5 * k1[7],
                            h_Ca + 0.5 * k1[8], p)
        k3 = dt * model_eqs(V + 0.5 * k2[0], m_Na + 0.5 * k2[1], h_Na + 0.5 * k2[2], m_NaP + 0.5 * k2[3],
                            n + 0.5 * k2[4], m_DR + 0.5 * k2[5], m_KS + 0.5 * k2[6], m_Ca + 0.5 * k2[7],
                            h_Ca + 0.5 * k2[8], p)
        k4 = dt * model_eqs(V + k3[0], m_Na + k3[1], h_Na + k3[2], m_NaP + k3[3],
                            n + k3[4], m_DR + k3[5], m_KS + k3[6], m_Ca + k3[7],
                            h_Ca + k3[8], p)

        V     += (k1[0]  + 2*k2[0]  + 2*k3[0]  + k4[0]) / 6
        m_Na  += (k1[1]  + 2*k2[1]  + 2*k3[1]  + k4[1]) / 6
        h_Na  += (k1[2]  + 2*k2[2]  + 2*k3[2]  + k4[2]) / 6
        m_NaP += (k1[3]  + 2*k2[3]  + 2*k3[3]  + k4[3]) / 6
        n     += (k1[4]  + 2*k2[4]  + 2*k3[4]  + k4[4]) / 6
        m_DR  += (k1[5]  + 2*k2[5]  + 2*k3[5]  + k4[5]) / 6
        m_KS  += (k1[6]  + 2*k2[6]  + 2*k3[6]  + k4[6]) / 6
        m_Ca  += (k1[7]  + 2*k2[7]  + 2*k3[7]  + k4[7]) / 6
        h_Ca  += (k1[8]  + 2*k2[8]  + 2*k3[8]  + k4[8]) / 6

        # Record every sample_rate-th step.
        if it % sample_rate == 0:
            V_out[it // sample_rate] = V

    return V_out

# Solve the single-compartment model.
# NOTE(review): at this point `p` is the dict of gating callables built for
# neuron_model. It lacks 'V_rest', 'g_Na', 'E_Na', 'I_ext', 'C_m', etc., which
# this solver and the current functions read. solve_model_equations is also
# @njit-compiled, which presumably needs a numba-typed mapping rather than a
# Python dict — confirm before relying on this call.
V_out = solve_model_equations(dt=0.01, sample_rate=10, t_total=10.0, p=p)
