In [None]:
import math

import numpy as np
kNoSpike = 100  # Sentinel spike time meaning "this neuron did not fire".
decay_rate = 1e-4  # TODO: placeholder value, still to be tuned.
decay_params = {'rate': decay_rate, 'rate_inverse': 1 / decay_rate}
layer_size = [784, 512, 10]
# One (fan_out, fan_in) weight matrix per consecutive layer pair, initialised to zeros.
weights = [np.zeros([n_out, n_in]) for n_in, n_out in zip(layer_size[:-1], layer_size[1:])]
fire_threshold = 1.0
# Euler's number (the C constant M_E). A placeholder such as 1e5 would make
# kMinLambertArg -1e-5 instead of -1/e and reject valid negative arguments.
M_E = np.e
# Minimum argument for the main branch of the Lambert W function (-1/e).
kMinLambertArg = -1.0 / M_E
# Maximum argument for which gsl_sf_lambert_W0 produces a valid result.
kMaxLambertArg = 1.7976131e+308



In [None]:
kNearBranchCutoff = -0.3235
kE = 2.718281828459045  # Euler's number (same value as math.e).

# Linear pieces approximating W0 between -kNearBranchCutoff and ~3, as
# (upper bound, anchor x, value at anchor, slope). Each piece is anchored at
# the previous piece's upper bound (the first one at 0.3).
_W0_LINEAR_PIECES = (
    (0.6, 0.3, 0.23675531078855933, 0.5493610866617109),
    (0.8999999999999999, 0.6, 0.4015636367870726, 0.4275644294878729),
    (1.2, 0.8999999999999999, 0.5298329656334344, 0.3524368357714513),
    (1.5, 1.2, 0.6355640163648698, 0.30099113800452154),
    (1.8, 1.5, 0.7258613577662263, 0.2633490154764343),
    (2.0999999999999996, 1.8, 0.8048660624091566, 0.2345089875713013),
    (2.4, 2.0999999999999996, 0.8752187586805469, 0.2116494532726034),
    (2.6999999999999997, 2.4, 0.938713594662328, 0.19305046534383152),
    (2.9999999999999996, 2.6999999999999997, 0.9966287342654774, 0.17760053566187495),
)


def LambertW0InitialGuess(x):
  """Cheap starting estimate of the principal Lambert W branch W0(x).

  Used to seed the Fritsch iteration in ``LambertW0``; on its own it is only
  a rough approximation.

  Args:
    x: Real argument, expected to be >= -1/e. Below -1/kE the square-root
      term is negative and ``math.sqrt`` raises ValueError.

  Returns:
    Float approximation of W0(x).
  """
  # Square-root expansion around the branch point (-1/e, -1). The cutoff
  # value itself belongs here: with a strict '<' it failed both this test and
  # the open interval below, and fell into the linear piece anchored at 0.3
  # (guess ~ -0.11 instead of W0 ~ -0.57).
  if x <= kNearBranchCutoff:
    return -1.0 + math.sqrt(2.0 * (1.0 + kE * x))
  # Fourth-order Taylor series of W0 around 0, on (kNearBranchCutoff, -kNearBranchCutoff).
  if x < -kNearBranchCutoff:
    return x * (1 + x * (-1 + x * (3.0 / 2.0 - 8.0 / 3.0 * x)))
  # Series of piecewise linear approximations up to x ~ 3.
  for upper, anchor, value, slope in _W0_LINEAR_PIECES:
    if x < upper:
      return value + (x - anchor) * slope
  # Asymptotic expansion for large x: W0(x) ~ L1 - L2 + L2 / L1,
  # with L1 = log(x) and L2 = log(log(x)).
  log_x = math.log(x)
  log_log_x = math.log(log_x)
  return log_x - log_log_x + log_log_x / log_x

In [None]:
kReciprocalE = 0.36787944117  # 1/e truncated to 11 decimal places.
kDesiredAbsoluteDifference = 1e-3
kNumMaxIters = 10


def LambertW0(x):
  """Evaluate the principal branch of the Lambert W function, W0(x).

  Solves w * exp(w) == x for w >= -1 with Fritsch's iteration, seeded by
  ``LambertW0InitialGuess``.

  Args:
    x: Real argument; W0 is real only for x >= -1/e.

  Returns:
    Tuple ``(w, converged)``. For x < -1/e this is ``(None, False)``.
    Otherwise ``w`` is the estimate of W0(x) and ``converged`` tells whether
    the correction term dropped below ``kDesiredAbsoluteDifference`` within
    ``kNumMaxIters`` iterations.
  """
  # Compare against the exact 1/e: the previous 'x <= -kReciprocalE' test
  # also rejected x == -kReciprocalE and made the branch-point case unreachable.
  if x < -1.0 / math.e:
    return None, False
  if x == 0.0:
    return 0, True
  # Branch point W0(-1/e) == -1. Because kReciprocalE is truncated, the range
  # [-1/e, -kReciprocalE] lies within ~1.5e-12 of the branch point. There W0
  # differs from -1 by only ~3e-6, well under the tolerance, and iterating
  # would divide by (1 + w) ~ 0.
  if x <= -kReciprocalE:
    return -1.0, True

  # Current guess.
  w_n = LambertW0InitialGuess(x)
  have_convergence = False

  # Fritsch iteration.
  for _ in range(kNumMaxIters):
    z_n = math.log(x / w_n) - w_n
    q_n = 2.0 * (1.0 + w_n) * (1.0 + w_n + 2.0 / 3.0 * z_n)
    e_n = (z_n / (1.0 + w_n)) * ((q_n - z_n) / (q_n - 2.0 * z_n))
    w_n *= (1.0 + e_n)
    # Test the z_n already computed instead of re-evaluating the residual:
    # the log is the expensive part of each step.
    if abs(z_n) < kDesiredAbsoluteDifference:
      have_convergence = True
      break
  return w_n, have_convergence

In [None]:
def ExponentiateSortedValidSpikes(activations, sorted_indices, decay_rate):
  """Precompute exp(decay_rate * t) for the inputs that spiked.

  Args:
    activations: Array of input spike times; a value that is not strictly
      below ``kNoSpike`` marks an input that did not spike.
    sorted_indices: Indices into ``activations`` in ascending spike-time
      order (the caller passes ``np.argsort(activations)``).
    decay_rate: Rate constant in the exponent.

  Returns:
    Array shaped and typed like ``activations``. Entries visited before the
    first non-spiking input hold the exponentials; all others hold kNoSpike.
  """
  exponentials = np.full_like(activations, kNoSpike)
  for idx in sorted_indices:
    spike_time = activations[idx]
    # Indices are visited in ascending time order, so the first time that is
    # not strictly below kNoSpike ends the run of valid spikes.
    if not spike_time < kNoSpike:
      break
    exponentials[idx] = np.exp(decay_rate * spike_time)
  return exponentials


In [None]:
def ActivateNeuronAlpha(weight, activation, exp_activation, sorted_indices, threshold):
  """Compute the first output spike time of one alpha-synapse neuron.

  Input spikes are integrated one at a time in ascending time order. After
  each one, the running sums A = sum(w_i * exp_i) and B = sum(w_i * exp_i * t_i)
  give a candidate threshold-crossing time via the principal Lambert W branch.
  The rate constants come from the module-level ``decay_params`` dict.

  Args:
    weight: Incoming weights of this neuron, indexed like ``activation``.
    activation: Input spike times; ``kNoSpike`` marks a silent input.
    exp_activation: Per-input exponentials, as produced by
      ``ExponentiateSortedValidSpikes``.
    sorted_indices: Indices of ``activation`` in ascending spike-time order.
    threshold: Firing threshold.

  Returns:
    Tuple ``(spike_time, A, B, W, causal_set)``:
    - ``spike_time`` is ``kNoSpike`` if the neuron does not fire.
    - ``A``/``B`` are the sums accumulated over the integrated inputs.
    - ``W`` is the last Lambert W value computed (0.0 if none).
    - ``causal_set`` is an array like ``activation`` with 1 for every input
      counted as a cause of the outcome.

  Raises:
    AssertionError: if ``LambertW0`` reports non-convergence.
  """
  rate = decay_params['rate']
  rate_inverse = decay_params['rate_inverse']
  A = 0.0
  B = 0.0
  W = 0.0
  spike_time = kNoSpike
  causal_set = np.zeros_like(activation)
  # Feed input spikes one by one.
  for spike_idx in sorted_indices:
    # Either the neuron already fired before this input arrives (later inputs
    # cannot change the first spike), or no spike is pending and this input is
    # silent. Stop integrating; the epilogue below handles both cases. The old
    # bare 'return spike_time' broke the 5-tuple contract.
    if spike_time <= activation[spike_idx]:
      break
    # Otherwise, integrate this spike. Reset the spike time in case an
    # inhibitory input cancels a potential spike.
    spike_time = kNoSpike
    causal_set[spike_idx] = 1

    w_exp_z = weight[spike_idx] * exp_activation[spike_idx]
    A += w_exp_z
    B += w_exp_z * activation[spike_idx]

    # The first derivative of the potential at the threshold crossing is A
    # times a never-negative value. If A <= 0, any crossing is in a
    # non-increasing region and is not a spike. Skipping A == 0 also avoids
    # the division by zero in B / A below.
    if A <= 0:
      continue
    b_over_a = B / A
    lambert_arg = -rate * threshold / A * np.exp(rate * b_over_a)
    if kMinLambertArg <= lambert_arg <= kMaxLambertArg:
      W, convergence = LambertW0(lambert_arg)
      assert convergence, "Error computing Lambert W on: %f" % (lambert_arg)
      spike_time = b_over_a - W * rate_inverse

      # For inhibitory weights, this might be a false alarm.
      # 'not >=' differs from '<': it also rejects NaN.
      if not (spike_time >= activation[spike_idx]):
        spike_time = kNoSpike
  # Either there is no spike, in which case all presynaptic neurons are to
  # blame, or there is a spike caused by the inputs integrated above.
  if spike_time == kNoSpike:
    causal_set[:] = 1
  return spike_time, A, B, W, causal_set
    

In [13]:
# Forward pass: propagate spike times layer by layer. `activation` holds the
# current layer's spike times: a single 1-D sample (no batch axis yet), with
# every input spiking at t = 0 for now.
activation = np.zeros(layer_size[0])
for layer in range(len(layer_size) - 1):
    n_in, n_out = layer_size[layer], layer_size[layer + 1]
    sorted_indices = np.argsort(activation)  # ascending spike time
    exp_activation = ExponentiateSortedValidSpikes(activation, sorted_indices, decay_rate)
    activation_next = np.zeros(n_out)
    A = np.zeros(n_out)
    B = np.zeros(n_out)
    W = np.zeros(n_out)
    causal_set = np.zeros([n_out, n_in])
    for n in range(n_out):
        activation_next[n], A[n], B[n], W[n], causal_set[n] = ActivateNeuronAlpha(
            weights[layer][n], activation, exp_activation, sorted_indices, fire_threshold)
    activation = activation_next

# Diagnostic for the last layer: the Lambert W argument built from each
# neuron's final A and B, mirroring the formula in ActivateNeuronAlpha.
# Neurons with A == 0 give inf/nan here, so numpy's warnings are silenced.
with np.errstate(all='ignore'):
    b_over_a = B / A
    lambert_arg = -decay_params['rate'] * fire_threshold / A * np.exp(decay_params['rate'] * b_over_a)
