# In [None]:
import numpy as np
import warnings
    
def calcium_dynamics(S=None, varargin=None, nargout=3):
    """Simulate calcium dynamics driven by spiking activity.

    Python port of the MATLAB function
    ``[CB, C, F] = calcium_dynamics(S, cal_params, prot_type, over_samp, ext_mult)``
    (2017 - Adam Charles). The model is chosen by ``cal_params.sat_type``:

    * ``'single'`` - one binding reaction, forward-Euler integration.
    * ``'double'`` - two binding reactions sharing the same calcium pool.
    * ``'Ca_DE'``  - total calcium is integrated; the bound-calcium trace is
      obtained by convolving it with a double-exponential kernel built by
      ``mk_doub_exp_ker``.

    Parameters
    ----------
    S : array_like
        Kx(nt) array of the spiking activity at each time-step and for each
        neuron. A 1-D input is treated as a single neuron (1 x nt).
    varargin : sequence, optional
        Optional arguments in MATLAB order. A missing entry or ``None`` means
        "use the default":

        [0] cal_params - parameters for the calcium simulation; completed by
            ``check_cal_params``.
        [1] prot_type  - indicator protein name (default ``'GCaMP6f'``).
        [2] over_samp  - integer temporal oversampling factor (default 1;
            an empty value also means 1).
        [3] ext_mult   - multiplier applied to ``cal_params.ext_rate``
            (default 1).
    nargout : {1, 2, 3}, optional
        Emulates MATLAB's ``nargout`` and selects which outputs are returned.

    Fields read from ``cal_params``. The defaults are those listed in the
    original header; they are presumably filled in by check_cal_params, so
    confirm there.

        ext_rate  - extrusion rate (default 1800)
        ca_bind   - calcium binding constant (default 110)
        ca_rest   - resting-state calcium concentration (default 50e-9)
        ind_con   - indicator concentration (default 200e-6)
        ca_dis    - calcium dissociation constant (default 290e-9)
        ca_sat    - optional saturation parameter; applied only when
                    0 <= ca_sat < 1 (default 1)
        sat_type  - 'single', 'double' or 'Ca_DE' (default 'double')
        dt        - simulation time step (default 1/30)
        ca_amp    - transient amplitude for the 'Ca_DE' kernel
        t_on      - rising time constant for the 'Ca_DE' kernel
        t_off     - falling time constant for the 'Ca_DE' kernel
        a_bind    - binding rate(s) (default 3.5)
        a_ubind   - unbinding rate(s) (default 7)

    Returns
    -------
    list
        ``[CB]``, ``[CB, C]`` or ``[CB, C, F]`` for nargout 1, 2, 3. CB is
        the bound calcium and C the total calcium, both Kx(nt) at the
        original (non-oversampled) rate. F is the fluorescence from
        ``sat_nonlin``.

    Raises
    ------
    ValueError
        If ``S`` is missing, ``nargout`` is not 1-3, ``over_samp`` < 1, or
        ``sat_type`` is not a known model.
    """
    from types import SimpleNamespace

    if S is None:
        raise ValueError('S (spike array) is required')
    if nargout not in (1, 2, 3):
        raise ValueError('nargout must be 1, 2 or 3')

    ##########################################################################
    ## Input parsing

    args = [] if varargin is None else list(varargin)

    def _arg(pos, default):
        # Equivalent of MATLAB's `if nargin < pos+2, x = default; else x = varargin{pos+1}`.
        value = args[pos] if pos < len(args) else None
        return default if value is None else value

    # NOTE(review): MATLAB passes an empty struct here. An empty namespace is
    # the closest attribute-style equivalent; confirm check_cal_params accepts it.
    cal_params = _arg(0, SimpleNamespace())
    prot_type = _arg(1, 'GCaMP6f')
    over_samp = _arg(2, 1)
    if np.size(over_samp) == 0:  # MATLAB callers may pass [] to mean "default"
        over_samp = 1
    over_samp = int(np.ravel(over_samp)[0])
    if over_samp < 1:
        raise ValueError('over_samp must be a positive integer')
    ext_mult = _arg(3, 1)

    cal_params = check_cal_params(cal_params, prot_type)

    # Extract necessary params
    ext_rate = ext_mult * cal_params.ext_rate
    ca_bind = cal_params.ca_bind
    ca_rest = cal_params.ca_rest
    ind_con = cal_params.ind_con
    ca_dis = cal_params.ca_dis
    ca_sat = cal_params.ca_sat
    sat_type = cal_params.sat_type
    dt = cal_params.dt
    a_bind = cal_params.a_bind
    a_ubind = cal_params.a_ubind
    ca_amp = cal_params.ca_amp
    t_on = cal_params.t_on
    t_off = cal_params.t_off

    ##########################################################################
    ## Simulate the calcium iteratively

    S = np.atleast_2d(np.asarray(S, dtype=float))
    if over_samp > 1:
        # Same as MATLAB reshape([S; zeros((over_samp-1)*K, nt)], K, []):
        # each original frame is followed by (over_samp - 1) spike-free sub-steps.
        S_up = np.zeros((S.shape[0], S.shape[1] * over_samp))
        S_up[:, ::over_samp] = S
        S = S_up
    n_t = S.shape[1]

    C = np.zeros(S.shape, dtype=np.float32)  # MATLAB allocates this as 'single'
    C[:, 0] = np.maximum(ca_rest, S[:, 0])    # element-wise max, as in MATLAB max(a, b)

    # The saturation cap is loop-invariant; None disables it.
    cap = ca_dis * ca_sat / (1 - ca_sat) if 0 <= ca_sat < 1 else None

    def rate_denom(c_prev):
        # Denominator that divides the extrusion + spike term in every model.
        return 1 + ca_bind + (ind_con * ca_dis) / (c_prev + ca_dis) ** 2

    def saturate(kk):
        if cap is not None:
            C[:, kk] = np.minimum(C[:, kk], cap)

    def first_two(rate):
        # A scalar rate is used for both reactions; otherwise take the first two.
        r = np.ravel(rate)
        return np.array([r[0], r[0]]) if r.size == 1 else r[:2]

    if sat_type == 'single':
        downsample = True
        a = np.ravel(a_bind)[0]
        b = np.ravel(a_ubind)[0]
        CB = np.zeros_like(C)  # bound calcium starts at 0
        for kk in range(1, n_t):
            c_prev = C[:, kk - 1]
            cb_prev = CB[:, kk - 1]
            C[:, kk] = (c_prev + dt * b * cb_prev
                        + (-dt * ext_rate * (c_prev - cb_prev - ca_rest) + S[:, kk])
                        / rate_denom(c_prev))
            saturate(kk)
            # Uses the previous-step C, as in the original.
            CB[:, kk] = cb_prev + dt * (-b * cb_prev
                                        + a * (c_prev - cb_prev) * (ind_con - cb_prev))
    elif sat_type == 'Ca_DE':
        downsample = False  # this branch downsamples itself, before trimming CB
        for kk in range(1, n_t):
            c_prev = C[:, kk - 1]
            C[:, kk] = c_prev + (-dt * ext_rate * (c_prev - ca_rest) + S[:, kk]) / rate_denom(c_prev)
            saturate(kk)
        # NOTE(review): assumes mk_doub_exp_ker returns a vector-shaped kernel.
        h_ca = np.ravel(mk_doub_exp_ker(t_on, t_off, ca_amp, dt)).astype(np.float32)
        CB = np.stack([
            np.convolve(row - ca_rest, h_ca, mode='full')[::over_samp] + ca_rest
            for row in C
        ]).astype(C.dtype)
        C = C[:, ::over_samp]
        CB = CB[:, :C.shape[1]]  # drop the convolution tail
    elif sat_type == 'double':
        downsample = True
        a = first_two(a_bind)
        b = first_two(a_ubind)
        CB1 = np.zeros_like(C)
        CB2 = np.zeros_like(C)
        for kk in range(1, n_t):
            c_prev = C[:, kk - 1]
            cb1 = CB1[:, kk - 1]
            cb2 = CB2[:, kk - 1]
            C[:, kk] = (c_prev + dt * (b[0] * cb1 + b[1] * cb2)
                        + (-dt * ext_rate * (c_prev - cb1 - cb2 - ca_rest) + S[:, kk])
                        / rate_denom(c_prev))
            saturate(kk)
            # Both reactions see the same previous-step pools (CB1 is not updated first).
            unbound_ca = c_prev - cb1 - cb2
            free_ind = ind_con - cb1 - cb2
            CB1[:, kk] = cb1 + dt * (-b[0] * cb1 + a[0] * unbound_ca * free_ind)
            CB2[:, kk] = cb2 + dt * (-b[1] * cb2 + a[1] * unbound_ca * free_ind)
        CB = CB1 + CB2
    else:
        raise ValueError('Unknown model!')

    ##########################################################################
    ## Output parsing

    if downsample:
        C = C[:, ::over_samp]
        CB = CB[:, ::over_samp]

    if nargout == 1:
        return [CB]
    if nargout == 2:
        return [CB, C]

    # MATLAB adjusts CB only before computing F; the returned CB is unchanged.
    cb_for_f = CB
    if sat_type == 'single':
        cb_for_f = CB + ca_rest + (b / a) * CB / (ind_con - CB)
    F = sat_nonlin(cb_for_f, prot_type)
    return [CB, C, F]
    
    
# Hill-curve parameters per indicator: (F0, amplitude, Kd, Hill exponent).
# For 'gcamp6s' the original source also notes an alternative fit from
# Dana et al. 2019: amplitude 53.8, Kd 147e-9, exponent 2.45.
_SAT_NONLIN_PARAMS = {
    'gcamp6': (1, 25.2, 2.9e-07, 2.7),
    'gcamp6f': (1, 25.2, 2.9e-07, 2.7),
    'gcamp6s': (1, 27.2, 1.47e-07, 2.45),
    'gcamp3': (2, 12, 2.87e-07, 2.52),
    'ogb1': (1, 14, 2.5e-07, 1),
    'ogb-1': (1, 14, 2.5e-07, 1),
    'gcamp6-rs09': (1.4, 25, 5.2e-07, 3.2),
    'gcamp6rs09': (1.4, 25, 5.2e-07, 3.2),
    'gcamp6-rs06': (1.2, 15, 3.2e-07, 3),
    'gcamp6rs06': (1.2, 15, 3.2e-07, 3),
    'jgcamp7f': (1, 30.2, 1.74e-07, 2.3),
    'jgcamp7s': (1, 40.4, 6.8e-08, 2.49),
    'jgcamp7b': (1, 22.1, 8.2e-08, 3.06),
    'jgcamp7c': (1, 145.6, 2.98e-07, 2.44),
}


def sat_nonlin(CB=None, prot_type=None):
    """Convert bound-calcium concentration to fluorescence with a Hill curve.

    Computes ``F = F0 + F0 * A / (1 + (Kd / CB) ** n)``, using the parameters
    in ``_SAT_NONLIN_PARAMS`` for the given indicator.

    Parameters
    ----------
    CB : float or numpy.ndarray
        Bound-calcium concentration; evaluated element-wise for arrays.
    prot_type : str, optional
        Indicator name, matched case-insensitively. ``None`` selects GCaMP6f.
        An unrecognised name emits a warning and falls back to GCaMP6f.

    Returns
    -------
    float or numpy.ndarray
        Fluorescence, with the same shape as ``CB``.
    """
    key = 'gcamp6f' if prot_type is None else str(prot_type).lower()
    try:
        F0, amp, kd, hill = _SAT_NONLIN_PARAMS[key]
    except KeyError:
        warnings.warn('Unknown protien type. Defaultin to GCaMP6f...\n')
        F0, amp, kd, hill = _SAT_NONLIN_PARAMS['gcamp6f']

    F = amp * (1.0 / (1 + (kd / CB) ** hill))
    return F0 + F0 * F