-
Notifications
You must be signed in to change notification settings - Fork 12
/
simulation.jl
85 lines (77 loc) · 3.28 KB
/
simulation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
    simplify_simulation(sys, time)

Prepare `sys` for simulation.

`sys` is converted to an `ODESystem` and passed through `structural_simplify`; `time` is
stripped of its units with `ustrip(Float64, ms, time)`.

Returns the tuple `(t_val, simplified)`: the unitless duration followed by the simplified
system.
"""
function simplify_simulation(sys, time)
    ode = convert(ODESystem, sys)
    duration = ustrip(Float64, ms, time)
    simplified = structural_simplify(ode)
    return (duration, simplified)
end
"""
$(TYPEDSIGNATURES)
Compile a simulation of a single `neuron` or `network` of neurons for a specified
duration, `time`, and return it as an `ODEProblem` spanning `0.0` to `time`.
If `return_system == true`, returns the simplified `ODESystem` instead.
"""
function Simulation(neuron::AbstractCompartmentSystem, time::Time; return_system = false,
                    jac = false, sparse = false,
                    parallel = Symbolics.SerialForm())
    t_val, simplified = simplify_simulation(neuron, time)

    # Caller only wants the simplified system: skip logging and problem construction.
    return_system && return simplified

    # Log the plain-text rendering of the simplified system before building the problem.
    @info repr("text/plain", simplified)
    tspan = (0.0, t_val)
    return ODEProblem(simplified, [], tspan, []; jac, sparse, parallel)
end
"""
    NetworkParameters{T}

Pairs a parameter vector with the topology of the network it belongs to.

# Fields
- `ps::Vector{T}`: the wrapped parameter values.
- `topology::NetworkTopology`: the network topology carried alongside the parameters.
"""
struct NetworkParameters{T}
    ps::Vector{T}
    topology::NetworkTopology
end
# Indexing a `NetworkParameters` forwards to its wrapped parameter vector `ps`.
function Base.getindex(x::NetworkParameters, i)
    return x.ps[i]
end
# Return the `topology` field of `x`; `getfield` reads the field directly rather than
# going through `getproperty`.
function topology(x::NetworkParameters)
    return getfield(x, :topology)
end
"""
    get_weights(integrator, model)

Return the entry stored under `model` in the graph of the network topology held by
`integrator.p`.

NOTE(review): presumably `integrator.p` is a `NetworkParameters` and the entry holds the
synaptic weights for `model` -- confirm against callers.
"""
function get_weights(integrator, model)
    return graph(topology(integrator.p))[model]
end
"""
    Simulation(network::NeuronalNetworkSystem, time::Time; kwargs...)

Build a simulation of `network` lasting for the duration `time`.

The network is simplified first. If `return_system == true` the simplified system is
returned as is. Otherwise an `ODEProblem` over `(0.0, t_val)` is returned, with `jac`,
`sparse` and `parallel` forwarded to it.

When at least one synaptic system of `network` is event-based, the problem is built with a
callback from `generate_callback` (configured by `continuous_events` and `refractory`), and
its parameters are wrapped in a `NetworkParameters` together with the network topology.
"""
function Simulation(network::NeuronalNetworkSystem, time::Time; return_system = false,
                    jac = false, sparse = false, parallel = Symbolics.SerialForm(),
                    continuous_events = false, refractory = true)
    t_val, simplified = simplify_simulation(network, time)
    if return_system
        return simplified
    end

    tspan = (0.0, t_val)
    has_events = any(iseventbased.(synaptic_systems(network)))

    if has_events
        spike_cb = generate_callback(network, simplified; continuous_events, refractory)
        prob = ODEProblem(simplified, [], tspan, []; callback = spike_cb, jac, sparse,
                          parallel)
        # Swap in parameters that also carry the topology.
        wrapped = NetworkParameters(prob.p, get_topology(network))
        return remake(prob; p = wrapped)
    end

    # No event-based synapses: a plain problem without callbacks suffices.
    return ODEProblem(simplified, [], tspan, []; jac, sparse, parallel)
end
"""
    generate_callback_condition(network, simplified; continuous_events, refractory)

Build the spike-detection condition(s) for `network`.

The voltage indices are obtained from `map_voltage_indices` with `roots_only = true`. With
`continuous_events` set, a single `ContinuousSpikeDetection` over all of those indices is
returned; otherwise a vector with one `DiscreteSpikeDetection` per index is returned.
`refractory` is only passed to the discrete conditions.
"""
function generate_callback_condition(network, simplified; continuous_events, refractory)
    root_indices = map_voltage_indices(network, simplified; roots_only = true)

    # if continuous, condition has vector cb signature: cond(out, u, t, integrator)
    continuous_events && return ContinuousSpikeDetection(root_indices)

    # discrete condition for each compartment
    return [DiscreteSpikeDetection(idx, refractory) for idx in root_indices]
end
"""
    generate_callback_affects(network, simplified)

Create one `SpikeAffect` for every synaptic system of `network` and bundle them into a
`NetworkAffects`. The tailcall slot is filled with `nothing`.
"""
function generate_callback_affects(network, simplified)
    # Explicit `Any` element type: the container is deliberately untyped.
    spike_affects = Any[SpikeAffect(synapse, network, simplified)
                        for synapse in synaptic_systems(network)]
    tailcall = nothing # placeholder for voltage reset
    return NetworkAffects(spike_affects, tailcall)
end
"""
    generate_callback(network, simplified; continuous_events, refractory)

Assemble the spike callback(s) for `network`.

With `continuous_events` set, returns one `VectorContinuousCallback` whose length is the
number of voltage indices in the condition. Otherwise returns a `CallbackSet` of
`DiscreteCallback`s, pairing each discrete condition with the network affect bound to a
compartment index.
"""
function generate_callback(network, simplified; continuous_events, refractory)
    condition = generate_callback_condition(network, simplified;
                                            continuous_events = continuous_events,
                                            refractory = refractory)
    affect = generate_callback_affects(network, simplified)

    if continuous_events
        n_conditions = length(condition.voltage_indices)
        return VectorContinuousCallback(condition, affect, n_conditions)
    end

    # One affect per root compartment; `Fix2` binds the compartment index as the second
    # argument of the network affect.
    n_roots = length(root_compartments(get_topology(network)))
    indexed_affects = Any[Base.Fix2(affect, i) for i in 1:n_roots]

    callbacks = [DiscreteCallback(cond, aff) for (cond, aff) in zip(condition, indexed_affects)]
    return CallbackSet(callbacks...)
end