## AIMS

- Model how synaptic strengths are dynamically modified based on pre- and postsynaptic spikes (a synapse is the point at which two neurons connect)

- Implement various STDP rules to investigate connectivity, learning and memory formation

- Analyze the impact of different STDP parameters and network topologies

## to do list :
    - 80 excitatory neurons DONE
    - 20 inhibitory neurons DONE
    - raster plot DONE
    - either use 2D or 3D spatially structured network 
        (start with LIF if possible) DONE
    - Simulate pre learning VS Simulate post learning (how does synapse strength change in LIF? )
    - Investigate Hodgkin-Huxley model (how does synapse strength change?)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import nest 
import nest.voltage_trace
nest.set_verbosity("M_WARNING")  # only show NEST messages at warning level or above
nest.ResetKernel()  # start from a fresh kernel: no nodes, no connections, clock at 0 ms

# PRE LEARNING

In [None]:
# Shared LIF parameters for both populations: constant external input
# current I_e and membrane time constant tau_m.
params_dict = {"I_e": 200.0, "tau_m": 20.0}

excitatory_neuron_nodes = nest.Create("iaf_psc_alpha", 80, params=params_dict)
inhibitory_neuron_nodes = nest.Create("iaf_psc_alpha", 20, params=params_dict)

In [None]:
# Display the NodeCollection handle of the excitatory population.
print(f"{excitatory_neuron_nodes}")

In [None]:
# Display the NodeCollection handle of the inhibitory population.
print(f"{inhibitory_neuron_nodes}")

### measurement tool

In [None]:
# Recording device for the membrane potentials of the neurons it is attached to.
voltmeter = nest.Create(model="voltmeter")
print(f"{voltmeter}")

In [None]:
# Records the spikes of the neurons connected to it. NEST 3 renamed the NEST 2.x
# "spike_detector" device (https://nest-simulator.readthedocs.io/en/v2.18.0/models/detector.html)
# to "spike_recorder".
spike_recorder = nest.Create("spike_recorder")
print(spike_recorder)

In [None]:
weight = 20.0
delay = 1.0
p = 0.2

# Each ordered (pre, post) pair is connected independently with probability p.
connection_spec = {"rule": "pairwise_bernoulli", "p": p}

# (source, target, synaptic weight). Excitatory projections use `weight`,
# inhibitory ones a negative weight. The order matches the original Connect
# sequence so random connection draws happen in the same order.
projections = [
    (excitatory_neuron_nodes, excitatory_neuron_nodes, weight),
    (inhibitory_neuron_nodes, inhibitory_neuron_nodes, -50.0),
    (excitatory_neuron_nodes, inhibitory_neuron_nodes, weight),
    (inhibitory_neuron_nodes, excitatory_neuron_nodes, -50.0),
]

for pre, post, w in projections:
    nest.Connect(pre, post, connection_spec, syn_spec={"weight": w, "delay": delay})

## Parameters definition

- I_e -> External current 
- tau_m -> Membrane time constant: how quickly the neuron's membrane potential decays back to its resting value
- synapse_model -> Specifying use of STDP model so that the synapses (connections) can learn
- weight -> Synapse strength
- delay -> time it takes for one spike to travel from one neuron to another
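- alpha -> Asymmetry parameter of the STDP rule: scales depression relative to potentiation (the learning-rate step size is lambda)
- tau_plus -> Time constant (ms) of the potentiation window: how long after a presynaptic spike a postsynaptic spike can still strengthen the synapse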


In [None]:
# The voltmeter samples both populations first; then both populations report
# their spikes to the shared spike recorder.
for population in (excitatory_neuron_nodes, inhibitory_neuron_nodes):
    nest.Connect(voltmeter, population)

for population in (excitatory_neuron_nodes, inhibitory_neuron_nodes):
    nest.Connect(population, spike_recorder)

In [None]:
# Run the pre-learning network (connections made without a synapse_model) for 400 ms.
pre_learning_duration = 400.0
nest.Simulate(pre_learning_duration)

In [None]:
# No voltmeter parameters are changed here: the voltmeter keeps the settings it
# was created with. (An empty parameter dictionary would set nothing.)

In [None]:
# Plot every recorded membrane-potential trace and move the legend into a
# multi-column box below the axes so it does not cover the traces.
nest.voltage_trace.from_device(voltmeter)
legend_kwargs = dict(
    loc='upper center',
    bbox_to_anchor=(0.5, -0.05),
    fancybox=True,
    shadow=True,
    ncol=5,
)
plt.legend(**legend_kwargs)
plt.show()

In [None]:
import nest.raster_plot

# Raster plot of all spikes from the pre-learning run, plus the histogram panel
# requested by hist=True.
show_histogram = True
nest.raster_plot.from_device(spike_recorder, hist=show_histogram)
plt.show()

In [None]:
# Excitatory -> inhibitory connections of the pre-learning network and their
# weights. These connections were made without a synapse_model (NEST's default
# static_synapse), so the values are fixed references, not plastic weights.
initial_connections = nest.GetConnections(
    source=excitatory_neuron_nodes,
    target=inhibitory_neuron_nodes,
)
initial_weights = nest.GetStatus(initial_connections, keys='weight')

# POST LEARNING

In [None]:
# nest.ResetKernel()

In [None]:
# LIF parameters of the plastic (post-learning) network. They are identical to
# the pre-learning network, so the two networks differ only in their synapses.
params_dict_2 = {"I_e": 200.0, "tau_m": 20.0}

# Shared STDP synapse settings. The model key is "synapse_model", the NEST 3
# name also used by the Connect calls of this network (the NEST 2.x key "model"
# is not valid there), so this dict is a usable syn_spec on its own.
syn_spec_stdp = {
    "synapse_model": "stdp_synapse",
    "weight": 20.0,
    "delay": 1.0,
}

excitatory_neuron_nodes_2 = nest.Create(model="iaf_psc_alpha", n=80, params=params_dict_2)
inhibitory_neuron_nodes_2 = nest.Create(model="iaf_psc_alpha", n=20, params=params_dict_2)

##  connection excitatory-excitatory, inhibitory-inhibitory, excitatory-inhibitory, inhibitory-excitatory 


- Note: inhibitory weights must be negative so that inhibitory neurons reduce the activity of other neurons and keep the network activity stable

In [None]:

# Sparse random connectivity: every ordered (pre, post) pair is connected
# independently with probability 0.2.
connection_spec_2 = {"rule": "pairwise_bernoulli", "p": 0.2}

# stdp_synapse requires `weight` and `Wmax` to have the same sign and computes
# every update relative to Wmax. With the default positive Wmax (100.0), a
# negative inhibitory weight is clipped to 0 by the first depression step,
# silently removing inhibition. Inhibitory STDP synapses therefore get a
# negative Wmax of the same magnitude.
inhibitory_stdp_wmax = -100.0

# excitatory -> excitatory (plastic)
nest.Connect(excitatory_neuron_nodes_2,
             excitatory_neuron_nodes_2,
             connection_spec_2,
             syn_spec={
                 "synapse_model": "stdp_synapse",
                 "weight": syn_spec_stdp["weight"],
                 "delay": syn_spec_stdp["delay"]
             })

# inhibitory -> inhibitory (plastic; negative weight bounded by a negative Wmax)
nest.Connect(inhibitory_neuron_nodes_2,
             inhibitory_neuron_nodes_2,
             connection_spec_2,
             syn_spec={
                 "synapse_model": "stdp_synapse",
                 "weight": -50.0,
                 "delay": syn_spec_stdp["delay"],
                 "Wmax": inhibitory_stdp_wmax
             })

# excitatory -> inhibitory (plastic)
nest.Connect(excitatory_neuron_nodes_2,
             inhibitory_neuron_nodes_2,
             connection_spec_2,
             syn_spec={
                 "synapse_model": "stdp_synapse",
                 "weight": syn_spec_stdp["weight"],
                 "delay": syn_spec_stdp["delay"]
             })

# inhibitory -> excitatory (plastic; negative weight bounded by a negative Wmax)
nest.Connect(inhibitory_neuron_nodes_2,
             excitatory_neuron_nodes_2,
             connection_spec_2,
             syn_spec={
                 "synapse_model": "stdp_synapse",
                 "weight": -50.0,
                 "delay": syn_spec_stdp["delay"],
                 "Wmax": inhibitory_stdp_wmax
             })

In [None]:
# Dedicated spike recorder for the plastic (STDP) network.
spike_recorder_2 = nest.Create(model="spike_recorder")
print(f"{spike_recorder_2}")

In [None]:
# 1000 Hz Poisson input into the first 10 excitatory neurons of the plastic
# network. No syn_spec is given, so NEST's default connection settings apply.
stimulus_2 = nest.Create("poisson_generator", params={"rate": 1000.0})
driven_neurons = excitatory_neuron_nodes_2[:10]
nest.Connect(stimulus_2, driven_neurons)

# Record spikes from both populations of the plastic network.
for population_2 in (excitatory_neuron_nodes_2, inhibitory_neuron_nodes_2):
    nest.Connect(population_2, spike_recorder_2)

In [None]:
# Advance the kernel by another 400 ms. nest.ResetKernel() is not called before
# this, so the pre-learning network keeps running alongside the plastic one.
learning_duration = 400
nest.Simulate(learning_duration)

In [None]:
# Raster plot (with histogram panel) of the plastic network's spikes.
recorder_to_plot = spike_recorder_2
nest.raster_plot.from_device(recorder_to_plot, hist=True)
plt.show()

In [None]:
# Excitatory -> inhibitory STDP connections of the plastic network and their
# weights at this point of the simulation.
final_connections = nest.GetConnections(
    source=excitatory_neuron_nodes_2,
    target=inhibitory_neuron_nodes_2,
)
final_weights = nest.GetStatus(final_connections, keys='weight')

In [None]:
# Compare excitatory -> inhibitory weights before and after plasticity.
# `initial_weights` comes from the pre-learning network, a different set of
# non-plastic synapses. Every plastic E->I synapse was created with
# syn_spec_stdp["weight"], so that value is the true pre-learning baseline of
# the synapses whose final weights are in `final_weights`.
stdp_initial_weight = syn_spec_stdp["weight"]
final_mean_weight = np.mean(final_weights)

print(f"Initial mean weight (PRE LEARNING PHASE): {np.mean(initial_weights)}")
print(f"Initial STDP E->I weight (before plasticity): {stdp_initial_weight}")
print(f"Final mean weight: {final_mean_weight}")
print(f"Mean change of STDP E->I weights: {final_mean_weight - stdp_initial_weight}")

## Current state: the mean excitatory→inhibitory synaptic weight has increased after learning

# Training the SNN to understand patterns through temporal-sequential learning

# Each sequence will be trained on a separate group of excitatory neurons: 80 excitatory neurons / 3 ≈ 26.7, i.e. about 27 neurons per group

# Sequence times must lie after the current simulation time: nest.ResetKernel() is never called, so the two 400 ms runs above have already advanced the clock to 800 ms

In [None]:
# nest.ResetKernel() is not called between the phases, so the kernel clock has
# already advanced past the earlier runs. Spike times are therefore defined as
# offsets (ms) from the current simulation time; fixed absolute times at or
# before the current time would lie in the past and never stimulate the network.
# Each sequence repeats every 100 ms, and the three sequences are staggered by
# 20 ms. All offsets fit inside the following 400 ms run.
current_time = nest.GetKernelStatus("biological_time")

sequence_offsets = {
    "SEQUENCE_1": [10.0, 110.0, 210.0, 310.0],
    "SEQUENCE_2": [30.0, 130.0, 230.0, 330.0],
    "SEQUENCE_3": [50.0, 150.0, 250.0, 350.0],
}

sequence_times = {
    name: [current_time + offset for offset in offsets]
    for name, offsets in sequence_offsets.items()
}

In [None]:
# Split the excitatory population of the plastic (stdp_synapse) network into
# three consecutive groups. The pre-learning network uses non-plastic synapses
# and therefore cannot learn the sequences. Each group has ceil(N / 3) neurons
# and the last group takes the remainder (27, 27 and 26 for N = 80), so every
# excitatory neuron belongs to exactly one group.
n_exc_2 = len(excitatory_neuron_nodes_2)
group_size = -(-n_exc_2 // 3)  # ceiling division
group_A = excitatory_neuron_nodes_2[:group_size]
group_B = excitatory_neuron_nodes_2[group_size:2 * group_size]
group_C = excitatory_neuron_nodes_2[2 * group_size:]

In [None]:
# Show the neurons assigned to each sequence group, in A, B, C order.
for group in (group_A, group_B, group_C):
    print(group)

## Now to use `spike_generator` to generate spikes from our sequence array of spike times 

In [None]:
# One spike_generator per sequence, each emitting that sequence's spike times.
seq_1_generated, seq_2_generated, seq_3_generated = (
    nest.Create("spike_generator", params={"spike_times": sequence_times[sequence_name]})
    for sequence_name in ("SEQUENCE_1", "SEQUENCE_2", "SEQUENCE_3")
)

In [None]:
# Display the three sequence generators.
for sequence_generator in (seq_1_generated, seq_2_generated, seq_3_generated):
    print(sequence_generator)

In [None]:
# Sequence 1 drives group A, sequence 2 group B, sequence 3 group C.
generator_to_group = (
    (seq_1_generated, group_A),
    (seq_2_generated, group_B),
    (seq_3_generated, group_C),
)
for sequence_generator, target_group in generator_to_group:
    nest.Connect(sequence_generator, target_group)

## Creating a separate spike recorder for each group of excitatory neurons so each group can be analyzed separately

In [None]:
# One spike recorder per excitatory group, created in A, B, C order.
spike_recorder_A, spike_recorder_B, spike_recorder_C = (
    nest.Create("spike_recorder") for _ in range(3)
)

In [None]:
# Route each group's spikes to its own recorder.
group_to_recorder = (
    (group_A, spike_recorder_A),
    (group_B, spike_recorder_B),
    (group_C, spike_recorder_C),
)
for source_group, group_recorder in group_to_recorder:
    nest.Connect(source_group, group_recorder)

In [None]:
# Run the sequence-stimulation phase for another 400 ms.
sequence_phase_duration = 400
nest.Simulate(sequence_phase_duration)

In [None]:
# One raster plot (with histogram panel) per group, titled with the sequence
# that drives it.
rasters_to_plot = (
    (spike_recorder_A, "Group A (Sequence 1)"),
    (spike_recorder_B, "Group B (Sequence 2)"),
    (spike_recorder_C, "Group C (Sequence 3)"),
)
for recorder, plot_title in rasters_to_plot:
    nest.raster_plot.from_device(recorder, hist=True)
    plt.title(plot_title)
    plt.show()

In [None]:
# spike_recorder.get("events")

In [None]:
# Print the spike times stored in each sequence generator.
labelled_generators = (("A", seq_1_generated), ("B", seq_2_generated), ("C", seq_3_generated))
for label, sequence_generator in labelled_generators:
    print(f"Spike generator {label} times:", nest.GetStatus(sequence_generator, 'spike_times'))