### Defining the loss function.

We want our loss function to combine two aspects: firing correctness (whether the Hodgkin-Huxley (HH) neuron fires as intended) and energy consumption.

1. Firing term: let $F(t)$ be the observed firing rate and $F_{targ}(t)$ the target firing rate at time $t$, and let $T$ be the total simulation time. We measure "correctness" as the mean squared deviation $L_{fire} = \frac{1}{T}\int_0^T (F(t)-F_{targ}(t))^2\,dt$.

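2. Energy consumption: this term encourages the waveform to be as energy-efficient as possible. If we simplify the problem to a step-current input $I$, the only free parameters are the step's amplitude $I_{max}$ and duration $I_{dur}$, and we penalize $L_{energy} = I_{max} \times I_{dur}$. Strictly, this product is the injected charge; if power is taken as proportional to $I^2$ (as `calc_energy` below does), the energy of the step is $I_{max}^2 I_{dur}$.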
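A minimal sketch of both terms on a sampled time grid, assuming $F$ and $F_{targ}$ are already available as arrays on that grid. The helper names `trapezoid`, `firing_loss`, and `step_current` are hypothetical, the rates are placeholders rather than simulation output, and the trapezoid rule is swapped in for simplicity (the notebook code itself uses Simpson's rule).

```python
import numpy as np

def trapezoid(y, t):
    """Trapezoid-rule integral of samples y over the grid t."""
    y = np.asarray(y, dtype=float)
    t = np.asarray(t, dtype=float)
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(t)))

def firing_loss(t, F, F_targ):
    """L_fire = (1/T) * integral_0^T (F(t) - F_targ(t))^2 dt on a sampled grid."""
    T = t[-1] - t[0]
    return trapezoid((np.asarray(F) - np.asarray(F_targ)) ** 2, t) / T

def step_current(t, I_max, I_dur):
    """Step input: I_max for 0 <= t <= I_dur, zero afterwards."""
    t = np.asarray(t, dtype=float)
    return np.where((t >= 0) & (t <= I_dur), I_max, 0.0)

t = np.linspace(0.0, 50.0, 50001)   # 50 time units, 50001 samples (units illustrative)
F_targ = np.full_like(t, 20.0)      # placeholder target rate
F = F_targ + 1.0                    # placeholder observed rate, off by 1 everywhere

I = step_current(t, I_max=10.0, I_dur=2.0)
L_fire = firing_loss(t, F, F_targ)  # constant error of 1, so L_fire is 1
charge = trapezoid(I, t)            # close to I_max * I_dur = 20
energy = trapezoid(I ** 2, t)       # close to I_max**2 * I_dur = 200
```

The two energy-like quantities differ by a factor of $I_{max}$, so whether $L_{energy}$ should be the charge or the $\int I^2\,dt$ energy changes how strongly large amplitudes are penalized relative to long durations.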

Combining these terms, we obtain our Lagrangian:

$$\mathcal{L} = \alpha L_{fire} + \beta L_{energy}$$

where $\alpha$ and $\beta$ are weighting coefficients.
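To illustrate how $\alpha$ and $\beta$ trade the two terms off, here is a small sketch for the step case using $L_{energy} = I_{max} \times I_{dur}$. The function names and both candidate stimuli are hypothetical, and their $L_{fire}$ values are made-up placeholders, not results from the HH model.

```python
def objective(L_fire, I_max, I_dur, alpha=1.0, beta=1.0):
    """Combined loss alpha * L_fire + beta * L_energy, with L_energy = I_max * I_dur."""
    L_energy = I_max * I_dur
    return alpha * L_fire + beta * L_energy

# Two made-up candidate step stimuli (placeholder numbers, not simulation results):
# A hits the firing target exactly but uses a large step; B misses slightly but is cheap.
candidates = {
    "A": dict(L_fire=0.0, I_max=10.0, I_dur=2.0),
    "B": dict(L_fire=1.0, I_max=2.0, I_dur=1.0),
}

def best(alpha, beta):
    """Return the name of the lowest-loss candidate and all scores."""
    scores = {name: objective(alpha=alpha, beta=beta, **c) for name, c in candidates.items()}
    return min(scores, key=scores.get), scores

print(best(alpha=1.0, beta=0.01))  # small beta: firing accuracy dominates
print(best(alpha=1.0, beta=1.0))   # large beta: energy dominates
```

With $\beta = 0.01$ candidate A scores $0.2$ against B's $1.02$; with $\beta = 1$ the ranking flips ($20$ vs. $3$), so the weights effectively decide how much firing error we accept per unit of stimulus.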

In [1]:
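import numpy as np
# `simps` is deprecated and removed in recent SciPy releases; `simpson` replaces it
from scipy.integrate import simpson


# Step current (simplified case): amplitude I_max for 0 <= t <= I_duration, zero otherwise
def initial_stim(t, I_max, I_duration):
    return I_max if 0 <= t <= I_duration else 0.0

# Defining the cost function

# Energy of a waveform: integral of wvfrm(t)^2 over the time grid t (Simpson's rule)
def calc_energy(t, wvfrm):
    power = np.asarray(wvfrm) ** 2
    energy = simpson(power, x=t)
    return energy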


# Binary loss: "firing" means the membrane potential crosses the threshold at least once
# y_true will be generated by a NEURON simulation; the call will be added when the code is cleaned up

def lagrangian_bin(stim, y_true, t_span, dt, alpha, beta, thresh=-55.0):
    # bin_firing is not defined in this notebook yet: it should run the Hodgkin-Huxley
    # model on `stim` and return True if the voltage crosses `thresh` (y_true is unused for now)
    fired = bin_firing(stim, t_span, dt, thresh)

    # Energy term; assumes `stim` is sampled on the time points in `t_span`
    energy = calc_energy(t_span, stim)

    # Firing term: zero if the neuron fired, otherwise a large penalty
    penalty = 0.0 if fired else 1e6
    loss = alpha * penalty + beta * energy
    return loss

# Alternative loss: use NEURON to generate the target behavior, i.e. how many spikes we want to see.
# Train against this ground truth: the model should fire as many action potentials as NEURON does.
def ap_count(voltage, thresh=-55.0):
    # Count action potentials in a voltage trace (NumPy array)
    above = (np.asarray(voltage) >= thresh).astype(int)
    # np.diff(...) == 1 marks only upward (depolarizing) crossings, so each
    # action potential contributes exactly one crossing; no division by 2 is needed
    crossings = np.where(np.diff(above) == 1)[0]
    return len(crossings)


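Because `np.diff(...) == 1` keeps only upward threshold crossings, each action potential should yield exactly one crossing. The sketch below checks this on a made-up synthetic trace and shows that halving the rising-edge count would undercount; `count_spikes` is a hypothetical stand-alone name.

```python
import numpy as np

def count_spikes(v, thresh=-55.0):
    """Count upward threshold crossings (one per action potential)."""
    above = (np.asarray(v, dtype=float) >= thresh).astype(int)
    return int(np.count_nonzero(np.diff(above) == 1))

# Synthetic trace: resting at -65 mV with two square "spikes" to +30 mV
v = np.full(100, -65.0)
v[20:25] = 30.0
v[60:65] = 30.0

n = count_spikes(v)   # two rising edges, so two action potentials
halved = n // 2       # dividing the rising-edge count by 2 undercounts
fired = n > 0         # binary firing criterion, as in the threshold-based loss
```

One limitation of counting rising edges: a trace that already starts above threshold has no rising edge for that first spike, so it is not counted.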
In [3]:
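def elu_norm(voltage, thresh, alpha=0.001):
    # ELU centred on the threshold: linear (v - thresh) above it and
    # alpha * (exp(v - thresh) - 1) below it, so the map is continuous at thresh
    x = np.asarray(voltage, dtype=float) - thresh
    # np.minimum keeps exp() off large positive x in the branch np.where discards
    return np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0.0)))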


def lagrange_seq(I_array, y_true, y_pred, alpha=0.001, beta=1e3, gamma=1e2, thresh=-55.0):

    # Penalty for deviating from the target voltage trace
    # (alternative: compare action-potential counts directly)
    # ap_true = ap_count(y_true, thresh)
    # ap_pred = ap_count(y_pred, thresh)
    # mse_ap = (ap_pred - ap_true)**2
    normed_true = elu_norm(y_true, thresh, alpha)
    normed_pred = elu_norm(y_pred, thresh, alpha)
    mse_ap = np.mean(np.square(normed_true - normed_pred))

    # pred_ener = np.sum(np.square(input_current))
    # true_ener = np.sum(np.square(current_true))
    # mse_energy = np.mean(np.square(pred_ener - true_ener))

    # Weighted trace-matching term plus an L2-norm penalty on the input current
    loss = gamma * mse_ap + beta * np.linalg.norm(I_array)
    return loss

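The point of ELU-normalizing the traces is that disagreements below threshold are squashed while disagreements during spikes pass through linearly. A minimal sketch of an ELU centred on the threshold, $f(v) = v - \theta$ for $v > \theta$ and $\alpha(e^{v-\theta} - 1)$ otherwise, makes this concrete; `shifted_elu` is a hypothetical name, and $\alpha = 10^{-3}$, $\theta = -55$ mV mirror the notebook's defaults.

```python
import numpy as np

def shifted_elu(v, thresh=-55.0, alpha=1e-3):
    """ELU centred on thresh: (v - thresh) above it, alpha * (exp(v - thresh) - 1) below."""
    x = np.asarray(v, dtype=float) - thresh
    # np.minimum keeps exp() from being evaluated on large positive x in the unused branch
    return np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0.0)))

# The same 10 mV disagreement, once below threshold and once above it
d_sub = shifted_elu(-75.0) - shifted_elu(-65.0)    # on the order of 1e-8
d_supra = shifted_elu(-35.0) - shifted_elu(-45.0)  # passes through unchanged: 10
```

After squaring, the suprathreshold disagreement outweighs the subthreshold one by many orders of magnitude, so a mean-squared loss on these normalized traces is driven almost entirely by the spiking portions.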
In [None]:

# Should this be a binary cross-entropy loss instead?
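One way to explore that question is sketched below, under a loudly stated assumption: the notebook defines no firing probability, so this sketch invents one as a sigmoid of how far the peak voltage exceeds threshold, with a hypothetical `scale` parameter (in mV) controlling its sharpness. Unlike the fixed `1e6` penalty, binary cross-entropy then gives a graded penalty that shrinks as the trace approaches threshold. `bce_firing_loss` and both traces are hypothetical.

```python
import numpy as np

def bce_firing_loss(v, fired_target, thresh=-55.0, scale=2.0, eps=1e-12):
    """Binary cross-entropy between a fired/not-fired target and a soft firing probability.

    Assumption: p(fire) = sigmoid((max(v) - thresh) / scale).
    """
    margin = (np.max(np.asarray(v, dtype=float)) - thresh) / scale
    margin = np.clip(margin, -500.0, 500.0)  # keep exp() from overflowing
    p = 1.0 / (1.0 + np.exp(-margin))
    p = np.clip(p, eps, 1.0 - eps)           # avoid log(0)
    y = float(fired_target)
    return float(-(y * np.log(p) + (1.0 - y) * np.log(1.0 - p)))

spiking = np.array([-65.0, -40.0, 30.0, -70.0])  # made-up trace that crosses threshold
quiet = np.array([-70.0, -68.0, -65.0])          # made-up subthreshold trace
```

A quiet trace 10 mV below threshold costs about 5 when the target is "fired"; raising it toward threshold lowers that cost smoothly, which gives an optimizer a direction to move in.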