<a href="https://colab.research.google.com/github/tonyflow90/dl_ea02/blob/main/LIFModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
# These imports will be used in the notebook
from __future__ import print_function

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [8]:
# A basic LIF neuron
class LIFNeuron(object):
    
    def __init__(self, u_rest=0.0, u_thresh=1.0, tau_rest=4.0, r=1.0, tau=10.0):
        
        # Membrane resting potential in mV
        self.u_rest = u_rest
        # Membrane threshold potential in mV
        self.u_thresh = u_thresh
        # Duration of the resting period in ms
        self.tau_rest = tau_rest
        # Membrane resistance in Ohm
        self.r = r
        # Membrane time constant in ms
        self.tau = tau
        
        # Instantiate a graph for this neuron
        self.graph = tf.Graph()
        
        # Build the graph
        with self.graph.as_default():
        
            # Variables and placeholders
            self.get_vars_and_ph()
            
            # Operations
            self.input = self.get_input_op()
            self.potential = self.get_potential_op()
            # Note that input is a prerequisite of potential, so it will
            # always be evaluated when potential is
            
    # Variables and placeholders
    def get_vars_and_ph(self):

        # The current membrane potential
        self.u = tf.Variable(self.u_rest, dtype=tf.float32, name='u')
        # The duration left in the resting period (0 most of the time except after a neuron spike)
        self.t_rest = tf.Variable(0.0, dtype=tf.float32, name='t_rest')
        # Input current
        self.i_app = tf.placeholder(dtype=tf.float32, name='i_app')
        # The chosen time interval for the stimulation in ms
        self.dt = tf.placeholder(dtype=tf.float32, name='dt')

    # Evaluate input current
    def get_input_op(self):
        
        return self.i_app
        
    # Neuron behaviour during integration phase (below threshold)
    def get_integrating_op(self):

        # Get input current
        i_op = self.get_input_op()

        # Update membrane potential
        du_op = tf.divide(tf.subtract(tf.multiply(self.r, i_op), self.u), self.tau) 
        u_op = self.u.assign_add(du_op * self.dt)
        # Refractory period is 0
        t_rest_op = self.t_rest.assign(0.0)
        
        return u_op, t_rest_op

    # Neuron behaviour during firing phase (above threshold)    
    def get_firing_op(self):                  

        # Reset membrane potential
        u_op = self.u.assign(self.u_rest)
        # Refractory period starts now
        t_rest_op = self.t_rest.assign(self.tau_rest)

        return u_op, t_rest_op

    # Neuron behaviour during resting phase (t_rest > 0)
    def get_resting_op(self):

        # Membrane potential stays at u_rest
        u_op = self.u.assign(self.u_rest)
        # Refractory period is decreased by dt
        t_rest_op = self.t_rest.assign_sub(self.dt)
        
        return u_op, t_rest_op

    def get_potential_op(self):
        
        return tf.case(
            [
                (self.t_rest > 0.0, self.get_resting_op),
                (self.u > self.u_thresh, self.get_firing_op),
            ],
            default=self.get_integrating_op
        )

In [9]:
# Simulation with square input currents

# Duration of the simulation in ms
T = 200
# Duration of each time step in ms
dt = 1
# Number of iterations = T/dt
steps = int(T / dt)
# Output variables
I = []
U = []

neuron = LIFNeuron()
    
with tf.Session(graph=neuron.graph) as sess:

    sess.run(tf.global_variables_initializer())    

    for step in range(steps):
        
        t = step * dt
        # Set input current in mA
        if t > 10 and t < 30:
            i_app = 0.5
        elif t > 50 and t < 100:
            i_app = 1.2
        elif t > 120 and t < 180:
            i_app = 1.5
        else:
            i_app = 0.0

        feed = { neuron.i_app: i_app, neuron.dt: dt}
        
        u = sess.run(neuron.potential, feed_dict=feed)

        I.append(i_app)
        U.append(u)

AttributeError: ignored