In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

'''
STDP modulated with reward

Adapted from Fig. 1c of:
Eugene M. Izhikevich 
Solving the distal reward problem through linkage of STDP and dopamine signaling. 
Cerebral cortex 17, no. 10 (2007): 2443-2452.

Note:
The variable "mode" can switch the behavior of the synapse from "Classical STDP" to "Dopamine modulated STDP".

Author: Guillaume Dumas (Institut Pasteur)
Date: 2018-08-24
'''
from brian2 import *

# Parameters
# All constants below are looked up by name from the equation / threshold /
# reset / on_pre strings of the Brian2 objects created later, so their names
# must stay in sync with those strings.
simulation_duration = 240 * second

## Neurons
taum = 10 * ms   # time constant of dv/dt
Ee = 0 * mV      # enters the drive term ge * (Ee - vr)
vt = -54 * mV    # spike threshold (threshold='v>vt')
vr = -60 * mV    # value v is reset to after a spike
El = -74 * mV    # value v relaxes to when ge is zero
taue = 5 * ms    # decay time constant of ge

## STDP
taupre = 20 * ms
taupost = taupre
gmax = 0.01      # bound used by the clip() calls on c and s
# Trace increments: written relative to gmax first, then scaled by gmax.
dApre = 0.01
dApost = -dApre * taupre / taupost * 1.05
dApost = dApost * gmax
dApre = dApre * gmax

## Dopamine signaling
tauc = 1000 * ms     # decay time constant of the synaptic variable c
taud = 200 * ms      # decay time constant of the synaptic variable d
taus = 1 * ms        # time scale of ds/dt = mode * c * d / taus
epsilon_dopa = 5e-3  # amount added to d by each delivered reward spike

# Setting the stage
# The simulation is stochastic (Poisson input spikes, probabilistic synaptic
# connectivity). Fix Brian's random seed before any of those objects are
# created so that a fresh "Restart & Run All" reproduces the same run.
# Change or remove the seed to sample a different realisation.
seed(4321)
network = Network()

## Stimuli section
num_neurons = 100

# One Poisson spike source per neuron, all firing at the same mean rate.
input_rate = 1*Hz
input = PoissonGroup(num_neurons, input_rate)

# Neurons with a conductance-like variable ge that decays with taue and
# drives v; v is reset to vr whenever it exceeds vt.
neurons = NeuronGroup(num_neurons, '''dv/dt = (ge * (Ee-vr) + El - v) / taum : volt
                                      dge/dt = -ge / taue : 1''',
                      method='linear',
                      threshold='v>vt',
                      reset='v = vr')
neurons.v = vr

# Records the spikes of every neuron together with v at spike time.
neurons_monitor = SpikeMonitor(neurons, ['v'], record=True)

# Source n of the Poisson group projects only onto neuron n. Each input spike
# adds s = 100 mV to v, which is more than the distance between El and vt.
one_to_one = list(range(num_neurons))
synapse = Synapses(input, neurons,
                   on_pre='v += s',
                   model='''s: volt''')
synapse.connect(i=one_to_one, j=one_to_one)
synapse.s = 100. * mV

network.add(input, neurons, neurons_monitor, synapse)

## STDP section
epsilon = 0.1 # sparseness of synaptic connections

# Recurrent plastic synapses between the neurons.
#   * A presynaptic spike adds the weight s to the postsynaptic ge.
#   * Apre / Apost are the pre- and postsynaptic STDP traces.
#   * `mode` selects where the STDP update is applied:
#       mode = 0: the traces are added directly to s (clipped to +-gmax);
#       mode = 1: the traces are added to c (clipped to +-gmax) and s only
#                 changes through ds/dt = c * d / taus, i.e. while d is
#                 non-zero.
synapse_stdp = Synapses(neurons, neurons,
                   model='''mode: 1
                         dc/dt = -c / tauc : 1 (clock-driven)
                         dd/dt = -d / taud : 1 (clock-driven)
                         ds/dt = mode * c * d / taus : 1 (clock-driven)
                         dApre/dt = -Apre / taupre : 1 (event-driven)
                         dApost/dt = -Apost / taupost : 1 (event-driven)''',
                   on_pre='''ge += s
                          Apre += dApre
                          c = clip(c + mode * Apost, -gmax, gmax)
                          s = clip(s + (1-mode) * Apost, -gmax, gmax)
                          ''',
                   on_post='''Apost += dApost
                          c = clip(c + mode * Apre, -gmax, gmax)
                          s = clip(s + (1-mode) * Apre, -gmax, gmax)
                          ''',
                   method='euler'
                   )
synapse_stdp.connect(p=epsilon)

# Make sure the synapse from neuron 0 to neuron 1 exists and find its index k.
# The source/target index arrays are compared in one vectorised step instead
# of iterating over the synapses one by one in Python.
is_target = (synapse_stdp.i[:] == 0) & (synapse_stdp.j[:] == 1)
if not is_target.any():
    synapse_stdp.connect(i=0, j=1)
    # Recompute the mask so it also covers the synapse that was just added.
    is_target = (synapse_stdp.i[:] == 0) & (synapse_stdp.j[:] == 1)
# argmax of a boolean mask is the position of its first True entry.
k = int(is_target.argmax())

# Initial state; `mode` is set again right before the simulation is run.
synapse_stdp.mode = 0
synapse_stdp.s = 1e-10
synapse_stdp.c = 1e-10
synapse_stdp.d = 0

network.add(synapse_stdp)

# Record s, c and d of the 0 -> 1 synapse only.
synapse_stdp_monitor = StateMonitor(synapse_stdp, ['s', 'c', 'd'], record=[k])
network.add(synapse_stdp_monitor)

## Dopamine signaling section

# Single-unit detector with two state variables, both in seconds:
#   unlock - set to 101 ms by a spike of neuron 0, then decreases over time;
#   spike  - set to clip(unlock, 0 s, 101 ms) by a spike of neuron 1.
# The unit fires when spike exceeds 1 ms and spike is then reset to 0 s.
check_reward = NeuronGroup(1, '''dunlock/dt = (-1*second-unlock*0.001)/(1*second) : second
                                 spike : second''',
                          method='linear',
                          threshold='spike > 1*ms',
                          reset='spike = 0*second')
check_monitor = StateMonitor(check_reward, 'unlock', record=True)
check_monitor2 = SpikeMonitor(check_reward, 'spike', record=True)

# Neuron 0 arms the detector, neuron 1 samples it.
check_pre = Synapses(neurons, check_reward,
                     model='',
                     on_pre='unlock_post = 101*ms',
                     method='exact')
check_post = Synapses(neurons, check_reward,
                      model='',
                      on_pre='spike_post = clip(unlock_post, 0*second, 101*ms)',
                      method='exact')
check_pre.connect(i=0, j=0)
check_post.connect(i=1, j=0)

# Relay unit: every detector spike adds 2 V to v, which is above the 1 V
# threshold, so the unit fires and v is reset to 0 V.
dopamine = NeuronGroup(1, '''v : volt''', threshold='v>1*volt', reset='v=0*volt')
dopamine_trigger = Synapses(check_reward, dopamine,
                            model='',
                            on_pre='v_post += 2*volt',
                            method='exact')
dopamine_trigger.connect(p=1.)

# Each relay spike adds epsilon_dopa to the variable d of the STDP synapses
# it is connected to, delivered with a delay of one second.
reward = Synapses(dopamine, synapse_stdp,
                  model='',
                  on_pre='''d_post += epsilon_dopa''',
                  method='exact')
reward.connect(p=1.)
reward.delay='1*second'

network.add(check_reward, check_monitor, check_monitor2,
            check_pre, check_post,
            dopamine, dopamine_trigger,
            reward)

# Simulation
# synapse_stdp.mode selects the learning rule used for the whole run:
#   0 -> classical STDP
#   1 -> dopamine modulated STDP
synapse_stdp.mode = 1

network.run(simulation_duration, report='text')


Dot product of non row/column vectors has been deprecated since SymPy
1.2. Use * to take matrix products instead. See
https://github.com/sympy/sympy/issues/13815 for more info.

  useinstead="* to take matrix products").warn()


Starting simulation at t=0. s for a duration of 240. s
3.5643000000000002 (1%) simulated in 10s, estimated 11m 3s remaining.
7.19 (2%) simulated in 20s, estimated 10m 48s remaining.
10.4852 (4%) simulated in 30s, estimated 10m 57s remaining.
13.478800000000001 (5%) simulated in 40s, estimated 11m 12s remaining.
16.6324 (6%) simulated in 50s, estimated 11m 12s remaining.
19.9739 (8%) simulated in 1m 0s, estimated 11m 1s remaining.
23.3626 (9%) simulated in 1m 10s, estimated 10m 49s remaining.
26.75 (11%) simulated in 1m 20s, estimated 10m 38s remaining.
30.1029 (12%) simulated in 1m 30s, estimated 10m 28s remaining.
33.5236 (13%) simulated in 1m 40s, estimated 10m 16s remaining.
37.266400000000004 (15%) simulated in 1m 50s, estimated 9m 58s remaining.
40.989000000000004 (17%) simulated in 2m 0s, estimated 9m 43s remaining.
44.6858 (18%) simulated in 2m 10s, estimated 9m 28s remaining.
48.4196 (20%) simulated in 2m 20s, estimated 9m 14s remaining.
52.1516 (21%) simulated in 2m 30s, estim

In [None]:
# Visualisation
# Pull the recorded spike times and neuron indices out of the monitor once and
# select the spikes of neurons 0 and 1 with boolean masks, instead of indexing
# the monitor element by element in a Python loop.
spike_times = neurons_monitor.t[:]
spike_sources = neurons_monitor.i[:]
spikes_0 = spike_times[spike_sources == 0]
spikes_1 = spike_times[spike_sources == 1]

# Green: recorded synaptic strength, normalised by gmax.
# Blue / red crosses on the zero line: spikes of neuron 0 / neuron 1.
plot(synapse_stdp_monitor.t/second, synapse_stdp_monitor.s.T/gmax, 'g-')
plot(spikes_0/second, [0]*len(spikes_0), 'b+')
plot(spikes_1/second, [0]*len(spikes_1), 'r+')
xlim([0, simulation_duration/second])
ylabel('Synaptic strength s(t)')
xlabel('Time (s)')
show()

In [None]:
# State of the reward detector: the continuously recorded `unlock` variable
# (blue line) and the value of `spike` at each detector spike (red dots).
plt.plot(check_monitor.t/second, check_monitor.unlock.T, 'b-', label='unlock')
plt.plot(check_monitor2.t/second, check_monitor2.spike, 'ro', label='spike')
xlim([0, simulation_duration/second])
# Only the range from 0 to 10 ms is shown.
ylim([0, 0.01])
ylabel('check_reward state: unlock, spike (s)')
xlabel('Time (s)')
legend()
show()

In [None]:
# Distribution of the synaptic strengths at the end of the run.
# synapse_stdp_monitor records a single synapse (record=[k]), so indexing its
# `s` with [-1] yields that one synapse's whole time course rather than the
# final weights. Read the current weights of all synapses from the Synapses
# object instead, normalised by gmax like the time-course plot.
plt.hist(synapse_stdp.s[:] / gmax, density=True)
xlabel('Synaptic strength s / gmax')
ylabel('Probability density')
show()