In [None]:
import numpy as np
from scipy.special import binom as binom_fun
from scipy.special import factorial as sp_factorial

import matplotlib.pyplot as plt
from scipy.integrate import quad

import theano
import theano.tensor as T
from theano.graph.op import Op
from theano.graph.basic import Apply
from theano.compile.io import In

#import quadpy

%matplotlib nbagg

In [None]:
def p_nu(NU,gamma,delta,nu_max):
    """Evaluate the rate distribution expression p(NU).

    Builds, using ``T`` (theano.tensor) operations, the product

        gamma / (nu_max * sqrt(-pi * log(NU/nu_max)))
          * exp(-delta**2 / 2)
          * (NU/nu_max)**(gamma**2 - 1)
          * cosh(gamma * delta * sqrt(-2 * log(NU/nu_max)))

    Parameters
    ----------
    NU : rate at which the expression is evaluated
    gamma, delta : shape parameters of the expression
    nu_max : rate used to normalise ``NU``

    Returns
    -------
    The value (symbolic when the arguments are symbolic) of the product above.
    """
    # NU / nu_max and its logarithm appear in several factors.
    rate_ratio = NU / nu_max
    log_ratio = T.log( rate_ratio )

    prefactor = gamma / ( nu_max * T.sqrt( -np.pi * log_ratio ) )
    gauss_term = T.exp( - delta**2/2. )
    power_term = rate_ratio**(gamma**2 - 1)
    cosh_term = T.cosh( gamma * delta * T.sqrt( -2 * log_ratio ) )

    # Same left-to-right multiplication order as the one-line formula.
    return prefactor * gauss_term * power_term * cosh_term

def poisson_spikes(nu,N,T_total):
    """Poisson term for ``N`` spikes at rate ``nu`` over a duration ``T_total``.

    Computes ``(nu*T_total)**N / Gamma(N+1) * exp(-nu*T_total)`` with ``T``
    (theano.tensor) operations; ``Gamma(N+1)`` is used in place of ``N!``.
    """
    expected_count = nu*T_total
    count_weight = expected_count**N / T.gamma(N+1)
    return count_weight * T.exp(-expected_count)

In [None]:
# Measurement time: a fixed float for now instead of a symbolic T.dscalar('T').
T_total = 600.

# Symbolic parameters passed to p_nu.
gamma = T.dscalar('gamma')
delta = T.dscalar('delta')
nu_max = T.dscalar('nu_max')

# Symbolic rate variable.
nu = T.dscalar('nu')

# Number of action potentials: vector form (N_AP) and single-value form (N).
N_AP = T.dvector('N_AP')
N = T.dscalar('N')

In [99]:
class integrateOut(Op):
    """
    Integrate out a variable from an expression, computing
    the definite integral w.r.t. the variable specified
    !!! Only implemented in this for scalars !!!

    The integral is evaluated numerically with ``scipy.integrate.quad`` from
    0 up to the runtime value of the op's fourth input (``inputs[3]``).

    Parameters
    ----------
    f : symbolic expression to integrate
    t : symbolic variable of ``f`` that is integrated out
    vectorize : bool, optional (default ``False``)
        ``False``: the op output is declared as a double scalar and a single
        integral is computed.
        ``True``: the op output is declared as a double vector and one
        integral is computed per element of the first input.
    """

    def __init__(self,f,t,vectorize=False,*args,**kwargs):
        # ``vectorize`` defaults to the scalar mode so that
        # ``integrateOut(f, t)`` remains a valid call.
        super(integrateOut,self).__init__()
        self.f = f
        self.t = t
        self.vectorize = vectorize

    def make_node(self,*inputs):
        """Build the Apply node and store the inputs / Jacobian on the op.

        ``self.fvars`` holds the inputs; ``self.gradF`` holds the Jacobian of
        ``self.f`` w.r.t. those inputs, or ``None`` if it could not be built.
        """
        self.fvars=list(inputs)
        print(f'[make node]: inputs: {inputs}, fvars:{self.fvars}')
        # This will fail when taking the gradient... don't be concerned
        try:
            self.gradF = T.jacobian(self.f,self.fvars)
        except Exception:
            # Only ordinary errors are swallowed; KeyboardInterrupt and
            # SystemExit still propagate.
            self.gradF = None

        if self.vectorize:
            return Apply(self,self.fvars,[T.dvector().type()])
        else:
            return Apply(self,self.fvars,[T.dscalar().type()])

    def perform(self,node, inputs, output_storage):
        """Numerically integrate ``self.f`` over ``self.t`` and store the result.

        ``inputs[0]`` is read as the spike count(s) and ``inputs[3]`` as the
        upper integration limit; the lower limit is 0.
        """
        ## integrate the function from 0 to maximum firing rate nu_max
        N_AP = inputs[0]
        nu_max = inputs[3]
        if self.gradF is None:
            print(f'[perform (grad)]: inputs: {inputs}, fvars:{self.fvars}')
        else:
            print(f'[perform (fun)]: inputs: {inputs}, fvars:{self.fvars}')

        # create a function to evaluate the integrand; the integration
        # variable comes first because quad passes it as first argument
        f = theano.function([self.t]+self.fvars,self.f)

        if self.vectorize:
            output = np.zeros_like(N_AP,dtype='float64')
            for i,N in enumerate(N_AP):
                args = inputs[:]
                args[0] = np.array([N])   # necessary to be 1-dim vector to satisfy function-blueprint
                print(f'args: {args}')
                output[i] = quad(f,0,nu_max,args=tuple(args))[0]
            output_storage[0][0] = output
        else:
            output_storage[0][0] = np.array(quad(f,0,nu_max,args=tuple(inputs))[0],dtype='float64')

    def grad(self,inputs,output_grads):
        """Return one gradient term per input.

        Each term integrates the matching entry of ``self.gradF`` with a new
        ``integrateOut`` op, multiplies it by ``output_grads[0]`` and sums
        over the output elements.

        Raises
        ------
        NotImplementedError
            If the Jacobian of the integrand could not be built in
            ``make_node`` (``self.gradF`` is ``None``).
        """
        if self.gradF is None:
            raise NotImplementedError(
                'integrateOut.grad: Jacobian of the integrand is not available')

        # Re-express the stored Jacobian in terms of the inputs of this call.
        giv = {}
        for v,v_new in zip(self.fvars[1:],list(inputs)[1:]):
            giv[v] = v_new
        print(f'giv: {giv}')
        print(self.gradF)
        # Chain rule: contributions of the output elements are summed, not
        # averaged. For a scalar output the sum is the value itself.
        # NOTE(review): for a vector-valued first input (vectorize=True) the
        # term returned here is still reduced to a scalar - confirm intended.
        return [T.sum(integrateOut(theano.clone(g,replace=giv),self.t,self.vectorize)(*inputs)*output_grads[0]) \
            for g in self.gradF]

In [None]:
# Global theano settings for this experiment.
theano.config.optimizer = 'fast_run'
theano.config.exception_verbosity = 'high'
theano.config.on_unused_input = 'warn'
theano.config.mode = 'FAST_RUN'

# Vectorised variant: a single op integrates nu out for the whole N_AP vector.
p_N_AP = integrateOut(
    p_nu(nu,gamma,delta,nu_max)*poisson_spikes(nu,N_AP,T_total),
    nu,
    vectorize=True,
)(N_AP,gamma,delta,nu_max)

func_p = theano.function(
    [N_AP,gamma,delta,nu_max],
    p_N_AP,
)
func_vals = func_p([0,3,5,10],1.2,4.8,30.)
print(func_vals)

# Jacobian w.r.t. the three parameters; N_AP is treated as constant.
pGrad = T.jacobian(
    p_N_AP,
    [gamma,delta,nu_max],
    consider_constant=[N_AP],
)
funcGrad = theano.function(
    [N_AP,gamma,delta,nu_max],
    pGrad,
    mode='DebugMode',
    on_unused_input='warn',
)

[make node]: inputs: (N_AP, gamma, delta, nu_max), fvars:[N_AP, gamma, delta, nu_max]
[perform (fun)]: inputs: [array([ 0,  3,  5, 10]), array(1.2), array(4.8), array(30.)], fvars:[N_AP, gamma, delta, nu_max]




args: [array([0]), array(1.2), array(4.8), array(30.)]
args: [array([3]), array(1.2), array(4.8), array(30.)]
args: [array([5]), array(1.2), array(4.8), array(30.)]
args: [array([10]), array(1.2), array(4.8), array(30.)]
[perform (fun)]: inputs: [array([ 0,  3,  5, 10]), array(1.2), array(4.8), array(30.)], fvars:[N_AP, gamma, delta, nu_max]
args: [array([0]), array(1.2), array(4.8), array(30.)]
args: [array([3]), array(1.2), array(4.8), array(30.)]
args: [array([5]), array(1.2), array(4.8), array(30.)]
args: [array([10]), array(1.2), array(4.8), array(30.)]
[1.06234782e-09 3.62980190e-02 2.32467904e-02 1.21495837e-02]
giv: {gamma: gamma, delta: delta, nu_max: nu_max}
[for{cpu,scan_fn}.0, for{cpu,scan_fn}.1, for{cpu,scan_fn}.2, for{cpu,scan_fn}.3]
[make node]: inputs: (N_AP, gamma, delta, nu_max), fvars:[N_AP, gamma, delta, nu_max]
[make node]: inputs: (N_AP, gamma, delta, nu_max), fvars:[N_AP, gamma, delta, nu_max]
[make node]: inputs: (N_AP, gamma, delta, nu_max), fvars:[N_AP, gamma,

In [None]:
# Evaluate the compiled Jacobian (funcGrad) for spike counts 0, 3 and 5.
# NOTE(review): delta is 5.8 here but 4.8 in the func_p call - confirm intended.
grad_vals = funcGrad([0,3,5],1.2,5.8,30.)
print(grad_vals)

In [13]:
theano.config.optimizer = 'fast_run'

# Scan variant: one scalar integrateOut op is applied per element of N_AP.
p_N_AP_arr, updates = theano.scan(
    fn=lambda N, gamma, delta, nu_max: integrateOut(
        p_nu(nu,gamma,delta,nu_max)*poisson_spikes(nu,N,T_total),
        nu,
        vectorize=False,
    )(N,gamma,delta,nu_max),
    sequences=[N_AP],
    non_sequences=[gamma,delta,nu_max],
)

func_p = theano.function(
    [N_AP,gamma,delta,nu_max],
    p_N_AP_arr,
)
func_vals = func_p([0.,3.,5.,10.],1.2,4.8,30.)
print(func_vals)

# Jacobian of the scanned result w.r.t. the parameters; N_AP is held constant.
pGrad = T.jacobian(
    p_N_AP_arr,
    [gamma,delta,nu_max],
    consider_constant=[N_AP],
)
funcGrad = theano.function(
    [N_AP,gamma,delta,nu_max],
    pGrad,
    mode='FAST_RUN',
)

[1.06234782e-09 3.62980190e-02 2.32467904e-02 1.21495837e-02]


In [14]:
# Evaluate the compiled Jacobian (funcGrad) for spike counts 0, 3 and 5.
# NOTE(review): delta is 5.8 here but 4.8 in the func_p call - confirm intended.
grad_vals = funcGrad([0.,3.,5.],1.2,5.8,30.)
print(grad_vals)

[array([2.62735034e-09, 1.10484408e-01, 6.93019521e-02]), array([-6.10352823e-10, -2.07417269e-02, -1.38951575e-02]), array([5.51702792e-12, 6.86296925e-09, 1.05980551e-04])]


In [16]:
# Seeded generator handed to verify_grad.
rng = np.random.RandomState(42)
def p_N_AP_fun(N_AP,gamma,delta,nu_max):
    """Symbolic result of integrating nu out for every entry of ``N_AP``.

    Applies a scalar ``integrateOut`` op per element of ``N_AP`` via
    ``theano.scan`` and returns the scanned output.
    """
    p_N_AP_arr, updates = theano.scan(
        # integrateOut requires its ``vectorize`` argument; each scan step
        # handles a single scalar N, so the scalar mode is selected.
        fn = lambda N, gamma, delta, nu_max : integrateOut(p_nu(nu,gamma,delta,nu_max)*poisson_spikes(nu,N,T_total),nu,vectorize=False)(N,gamma,delta,nu_max),
        sequences=[N_AP],
        non_sequences=[gamma,delta,nu_max],
    )

    return p_N_AP_arr

# Check the gradient of p_N_AP_fun at one test point.
theano.gradient.verify_grad(p_N_AP_fun,[[0.,3.,5.],1.5,4.8,30.],rng=rng,mode='DebugMode')

In [None]:
# Scalar variant: integrate nu out for a single spike count N.
# integrateOut requires its ``vectorize`` argument; N is a scalar here.
p_N_AP = integrateOut(p_nu(nu,gamma,delta,nu_max)*poisson_spikes(nu,N,T_total),nu,vectorize=False)(N,gamma,delta,nu_max)
# The graph depends on the scalar N (not on the vector N_AP), so N has to be
# the compiled input; the function is then evaluated once per spike count.
func_p = theano.function([N,gamma,delta,nu_max],p_N_AP)
func_vals = np.array([func_p(n_spikes,1.2,4.8,30.) for n_spikes in (0,3,5,10)])
print(func_vals)


# Jacobian w.r.t. the three parameters; N is treated as constant.
pGrad = T.jacobian(p_N_AP,[gamma,delta,nu_max],consider_constant=[N])
funcGrad = theano.function([N,gamma,delta,nu_max],pGrad) ### somehow in here, copies are made, which dont fit ..
grad_vals = funcGrad(0,1.2,4.8,30.)
print(grad_vals)

In [239]:
# Scalar variant: integrate nu out for a single spike count N.
# integrateOut requires its ``vectorize`` argument; N is a scalar here.
p_N_AP = integrateOut(p_nu(nu,gamma,delta,nu_max)*poisson_spikes(nu,N,T_total),nu,vectorize=False)(N,gamma,delta,nu_max)
# Jacobian w.r.t. the three parameters; N is treated as constant.
pGrad = T.jacobian(p_N_AP,[gamma,delta,nu_max],consider_constant=[N])
funcGrad = theano.function([N,gamma,delta,nu_max],pGrad) ### somehow in here, copies are made, which dont fit ..
grad_vals = funcGrad(0,1.2,4.8,30.)
print(grad_vals)

[make node (fun)]: inputs: (N, gamma, delta, nu_max), fvars:[N, gamma, delta, nu_max]
calc grad with inputs [N, gamma, delta, nu_max]
gradF: [Elemwise{add,no_inplace}.0, Elemwise{add,no_inplace}.0, Elemwise{add,no_inplace}.0, Elemwise{add,no_inplace}.0]
[make node (fun)]: inputs: (N, gamma, delta, nu_max), fvars:[N, gamma, delta, nu_max]
[make node (fun)]: inputs: (N, gamma, delta, nu_max), fvars:[N, gamma, delta, nu_max]
[make node (fun)]: inputs: (N, gamma, delta, nu_max), fvars:[N, gamma, delta, nu_max]
[make node (fun)]: inputs: (N, gamma, delta, nu_max), fvars:[N, gamma, delta, nu_max]
[perform (fun)]: inputs: [array(0.), array(1.2), array(4.8), array(30.)], fvars:[N, gamma, delta, nu_max]
(array(0.), array(1.2), array(4.8), array(30.))
[perform (fun)]: inputs: [array(0.), array(1.2), array(4.8), array(30.)], fvars:[N, gamma, delta, nu_max]
(array(0.), array(1.2), array(4.8), array(30.))
[perform (fun)]: inputs: [array(0.), array(1.2), array(4.8), array(30.)], fvars:[N, gamma, del