## Title: Shaping Spiking Neural Network Connectivity with Spatiotemporal Activity
#### Candidate Number: 291089
#### Supervisor: Dr Johanna Senk
#### Degree: MSc Artificial Intelligence and Adaptive Systems

##### README

To run this simulation you will need to:
- Find the 'Define Network Parameters' section and set the desired parameters.
- Run the network using the run_network function under the 'Running the Network' section.
- Most visualisations will be generated automatically but additional figures are available under the 'Visualisations' section.
- Run the 'Metrics' section to generate metrics after the testing phase has been run.

## Import libraries

In [None]:
import nest
import nest.voltage_trace
import os

import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
import matplotlib.animation as animation

import networkx as nx
import matplotlib.cm as cm
import matplotlib.colors as colors

from scipy.stats import spearmanr
import math
import random

In [None]:
# Ask the NEST kernel to use four local threads.
# NOTE(review): buildNetwork.__init__ calls nest.ResetKernel(), which resets kernel parameters,
# so this setting most likely does not survive into the simulation - confirm before relying on it.
nest.SetKernelStatus(dict(local_num_threads=4))

## Build Class

In [None]:
class buildNetwork:
    '''
    This class builds the spiking neural network. It sets up the 2D neuron grid, sets up the neurons with the desired
    parameters, and removes self connections. It also creates a folder for results and resets the NEST kernel every
    time it is instantiated.

    Parameters
    -----------
    grid_rows: int
        Number of rows in the neuron grid.

    grid_cols: int
        Number of columns in the neuron grid.

    title: str
        Title used for naming the results folder and files.

    structural_plasticity: bool
        If True, enables structural plasticity (pruning of weak synapses) in the network.

    reverse: bool
        If True, the stimulation pattern is applied in reverse column order.

    synapse_weight: float
        Initial synaptic weight for the recurrent connections.

    alpha: float
        Value passed as "alpha" to the NEST stdp_synapse defaults.

    ld: float
        Value passed as "lambda" to the NEST stdp_synapse defaults.

    prob: float
        Bernoulli probability for connectivity within the mask.

    radius: float
        Radius of the circular connectivity mask.

    spike_delay: float
        Delay (ms) between successive column stimulations during training.

    spike_weight: float
        Weight for the training spike generators.

    synapse_delay: float
        Delay for the recurrent synapses.

    multiplier: float
        Disconnection multiplier for structural plasticity: synapses whose weight falls below
        synapse_weight * multiplier are pruned.

    test_weight: float
        Weight for the testing spike generator.


    Attributes
    -----------
    total_neurons: int
        Total neurons in the grid.

    deleted: list
        (source, target) pairs pruned in the most recent disconnection step.

    all_connected: list
        All connections that were made by the connection rule.

    '''
    def __init__(self, grid_rows, grid_cols, title, structural_plasticity, reverse, synapse_weight, alpha, ld,
                 prob, radius, spike_delay, spike_weight, synapse_delay, multiplier, test_weight):
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols
        self.structural = structural_plasticity
        self.reverse = reverse
        self.total_neurons = grid_rows * grid_cols
        self.synapse_weight = synapse_weight
        self.alpha = alpha
        self.ld = ld
        self.bern_prob = prob
        self.radius = radius
        self.spike_delay = spike_delay
        self.spike_weight = spike_weight
        self.synapse_delay = synapse_delay
        self.multiplier = multiplier
        self.test_weight = test_weight
        self.title = title
        self.all_connected = []
        # Initialised here so draw_deleted_graph() also works before the first pruning step.
        self.deleted = []
        self.neuron_grid = self.spike_recorder = self.voltmeter = self.layer = self.time = self.threshold = None

        self.createFolder()
        nest.ResetKernel()
        self.setUp()


    def createFolder(self):
        '''
        A folder for all metrics and figures is created (named after the title).
        '''
        directory_name = f"{self.title}"
        try:
            os.mkdir(directory_name)
            print(f"Directory '{directory_name}' created successfully.")
        except FileExistsError:
            print(f"Directory '{directory_name}' already exists.")


    def setUp(self):
        '''
        This function sets up the grid by assigning positions, connections and synapse parameters.
        It also builds the neuron lookup structures used by the rest of the model, removes self connections
        and plots the layer.

        '''
        #The neuron grid is created
        self.layer = nest.Create('iaf_psc_exp',
                                 positions=nest.spatial.grid(shape=[self.grid_rows, self.grid_cols],
                                                             extent=[self.grid_rows, self.grid_cols],
                                                             edge_wrap=True))

        #ld used to be stored but never applied; it is now passed to the STDP rule as "lambda".
        nest.SetDefaults("stdp_synapse", {"Wmax": 1000.0, "alpha": self.alpha, "lambda": self.ld})

        #The connectivity mask is created
        conn_dict = {"rule": "pairwise_bernoulli", "p": self.bern_prob, "mask": {"circular": {"radius": self.radius}}}

        #The synapse parameters are set
        self.syn_dict = {"synapse_model": "stdp_synapse", "weight": self.synapse_weight, "delay": self.synapse_delay}
        nest.Connect(self.layer, self.layer, conn_dict, self.syn_dict)

        #Neurons are sorted in configurations that will be used throughout the model
        self.neuron_grid = np.array(self.layer).reshape((self.grid_rows, self.grid_cols))
        self.all_neurons = sorted(self.neuron_grid.ravel().tolist())
        self.neuron_ids = set(self.layer.tolist())
        #column_ids[c] lists the neurons of grid column c
        self.column_ids = list(map(list, zip(*self.neuron_grid)))

        #Self connections are removed
        for conn in nest.GetConnections():
            if conn.source == conn.target:
                nest.Disconnect(conn)

        #Neuron grid is plotted
        fig = nest.PlotLayer(self.layer, nodesize=80)
        ctr = nest.FindCenterElement(self.layer)
        nest.PlotTargets(
            ctr,
            self.layer,
            fig=fig,
            src_size=250,
            tgt_color="red",
            tgt_size=20,
            mask_color="red",
            probability_cmap="Greens",
        )


    def _grid_graph(self):
        '''
        Return an edgeless directed graph holding every grid neuron, plus their drawing positions.
        Neurons are laid out row by row in ascending ID order at (col, -row).
        '''
        G = nx.DiGraph()
        positions = {}
        for idx, neuron_id in enumerate(sorted(self.neuron_ids)):
            if neuron_id <= self.total_neurons:
                row, col = divmod(idx, self.grid_cols)
                positions[neuron_id] = (col, -row)
                G.add_node(neuron_id)
        return G, positions


    def _edge_colour(self, weight):
        '''
        Return the arrow colour for a synapse weight: blue above 310, red at or below the pruning threshold
        (once one has been computed), gray otherwise.
        '''
        if weight > 310:
            return 'blue'
        if self.threshold is not None and weight <= self.threshold:
            return 'red'
        return 'gray'


    def draw_graph(self, timestep = 0):
        '''
        This function draws the network connectivity graph and assigns the colours of the arrows depending on the weight
        of the synapse and widths depending on the strength.

            Parameters:
                timestep: int
                The specific iteration of the training time the graph connectivity corresponds to.

        '''
        G, positions = self._grid_graph()

        for conn in reversed(list(nest.GetConnections())):
            s, t, w = conn.source, conn.target, conn.weight
            if s != t and s <= self.total_neurons and t <= self.total_neurons:
                G.add_edge(s, t, weight=w)

        #Colours and widths are derived from G.edges() so they follow the same order networkx draws the edges in
        #(appending colours in connection order does not match that order).
        weighted_edges = list(G.edges(data='weight'))
        edge_colours = [self._edge_colour(w) for _, _, w in weighted_edges]
        widths = [np.exp(w / 100) / 10 for _, _, w in weighted_edges]

        nx.draw(G, pos=positions, width=widths, with_labels=True, edge_color=edge_colours, node_size=800,
                font_color='white', font_size=20, node_color='black', arrows=True)
        plt.title(f"Network Connectivity at Time: {timestep}ms")

        file_path = os.path.join(self.title, f"graph_at_time_{timestep}_{self.title}.png")
        plt.savefig(file_path)
        plt.show()

    def draw_deleted_graph(self, timestep = 0):
        '''
        Draws the deleted connections from the network connectivity graph.

            Parameters:
                timestep: int
                The specific iteration of the training time the graph corresponds to.

        '''
        G, positions = self._grid_graph()

        #All deleted connections are drawn as red arrows
        G.add_edges_from(self.deleted)

        plt.figure(figsize=(10, 8))
        nx.draw(G, pos=positions, width=5, with_labels=True, edge_color='red', node_size=3000, font_color='white',
                font_size=36, node_color='black', arrows=True)
        plt.title(f"Deleted Connections at Time: {timestep}ms", fontsize=36)
        file_path = os.path.join(self.title, f"del_graph_at_time_{timestep}_{self.title}.png")
        plt.savefig(file_path)
        plt.show()

    def generate_spike(self, time):
        '''
        This function calculates the specific time that each column has to spike for all of them to receive the same
        number of spikes given the timing of the spike delay and the total training time. The function then assigns
        these times to each neuron within a column.

            Parameters:
                time: int
                The total training time.

        '''
        #Candidate stimulation times from 10 ms up to the training time, spike_delay ms apart.
        spiking_times = list(range(10, int(time), int(self.spike_delay)))

        #Times are dealt round-robin to the columns, then every column is truncated to the shortest list
        #so that all columns receive the same number of spikes.
        set_spike_times = [spiking_times[col::self.grid_cols] for col in range(self.grid_cols)]
        n_spikes = min(len(times) for times in set_spike_times)
        set_spike_times = [times[:n_spikes] for times in set_spike_times]

        #Each column is driven by its own spike generator with the calculated times.
        for i in range(self.grid_cols):
            spike_generator = nest.Create('spike_generator')
            spike_generator.set(spike_times=set_spike_times[i])
            #If the reverse wave is activated the timings are assigned to the columns right to left
            col = self.grid_cols - 1 - i if self.reverse else i
            for j in range(self.grid_rows):
                nest.Connect(spike_generator, [self.neuron_grid[j][col]], syn_spec={'weight': self.spike_weight})


    def connect_recorders(self):
        '''
        This function connects a voltmeter and a spike recorder to all of the neurons in the grid.

        '''
        self.voltmeter = nest.Create('voltmeter')
        self.spike_recorder = nest.Create("spike_recorder")
        nest.Connect(self.layer, self.spike_recorder)
        nest.Connect(self.voltmeter, self.all_neurons)


    def _weight_snapshot(self, column):
        '''
        Return (connections, frame) where frame holds source, target and the current weight (in a column named
        `column`) of every connection whose source and target are both grid neurons.
        '''
        conns = nest.GetConnections()
        df = pd.DataFrame({'source': conns.source, 'target': conns.target, column: conns.weight})
        return conns, df[df['source'].isin(self.neuron_ids) & df['target'].isin(self.neuron_ids)]


    def simulate(self, training_time):
        '''
        This function carries out the training phase of the model. The spike generators are assigned and all recording
        devices are connected. A dataframe with the initial connections is created. The simulation is then run in
        increments of 100 ms until the total training time is completed. At each iteration: the dataframe is updated
        with the new weights, a connectivity graph is drawn, a heatmap is drawn, the neurons that spiked are recorded
        and the membrane potential of each neuron is recorded. In the structural plasticity model the disconnection
        and/or connection rule is called and the deleted connections graph is drawn.

            Parameters:
                training_time: int
                The total training time.

        '''

        #Spike generators are set up
        self.generate_spike(training_time)
        self.connect_recorders() #The voltmeter and spike recorder are connected
        self.time = training_time

        #Initial weights are recorded in a dataframe
        conns, self.df_conns = self._weight_snapshot('0')

        #Initial graphs are drawn
        self.draw_graph()
        self.create_heatmap()

        #Neuron death is optional
        #self.neuronal_death(conns)

        for step in range(1, int(training_time / 100) + 1):
            timestep = step * 100
            #100 ms of the training phase is run
            nest.Simulate(100)

            #All synaptic weights of the neurons are added to the dataframe as a new time column;
            #connections created since the last step are first appended as new rows.
            conns, df_conns_new = self._weight_snapshot(f'{timestep}')
            known_pairs = self.df_conns.set_index(['source', 'target']).index
            new_rows = df_conns_new[~df_conns_new.set_index(['source', 'target']).index.isin(known_pairs)]
            self.df_conns = pd.concat([self.df_conns, new_rows[['source', 'target']]], ignore_index=True)
            self.df_conns = pd.merge(self.df_conns, df_conns_new[['source', 'target', f'{timestep}']],
                                     on=['source', 'target'], how='left')

            #Metrics are collected
            self.events = nest.GetStatus(self.spike_recorder, "events")[0]
            self.senders, self.times = self.events["senders"], self.events["times"]

            self.voltage_events = nest.GetStatus(self.voltmeter, "events")[0]
            self.voltages = self.voltage_events["V_m"]
            self.voltage_times = self.voltage_events["times"]
            self.voltage_senders = self.voltage_events["senders"]

            #Graphs are drawn
            self.draw_graph(timestep=timestep)
            self.create_heatmap(timestep=timestep)

            #Metrics and figures can be printed at each time step.
            #self.print_metrics(timestep = timestep)
            #self.plot_spikes(timestep = timestep)
            #self.plot_voltage(timestep = timestep)

            #If the structural model is on the plasticity rule is implemented
            if self.structural:
                self.apply_structural_plasticity_rule_disconnection(conns)
                self.draw_deleted_graph(timestep=timestep)
                #self.apply_structural_plasticity_rule_connection()

        self.df_conns.to_csv(os.path.join(self.title, f"neuron_connections_{self.title}.csv"), index=False)
        time_columns = self.df_conns.columns[2:].tolist()
        self.time_points = [int(col) for col in time_columns]

    def neuronal_death(self, conns):
        '''
        Deletes a random neuron to simulate neuron death by removing all of its recurrent synapses.

            Parameters:
                conns: collection
                The collection of all connections in the network.

        '''
        lesion = random.choice(list(self.all_neurons))
        for conn in conns:
            #Disconnecting all the synapses where the neuron is a presynaptic neuron.
            if conn.source == lesion and conn.target <= self.total_neurons:
                nest.Disconnect(conn)
            #Disconnecting all the synapses where the neuron is a postsynaptic neuron.
            elif conn.target == lesion and conn.source <= self.total_neurons:
                nest.Disconnect(conn)

    def apply_structural_plasticity_rule_disconnection(self, conns):
        '''
        Applies the disconnection structural plasticity rule by calculating the threshold and looping through the
        connections to see which connections fall under the threshold. It also stores all disconnected synapses in a list.

            Parameters:
                conns: collection
                The collection of all connections in the network.

        '''
        #Threshold is calculated
        self.threshold = self.synapse_weight * self.multiplier
        self.deleted = []
        for conn in conns:
            #Ensuring the IDs are of neurons in the grid and then checking if the weights are below the threshold
            if conn.source <= self.total_neurons and conn.target <= self.total_neurons and conn.weight < self.threshold:
                if random.random() <= 1: #Chance is set to 1 in this study
                    nest.Disconnect(conn)
                    self.deleted.append((conn.source, conn.target))


    def apply_structural_plasticity_rule_connection(self):
        '''
        Applies the connection structural plasticity rule by randomly connecting two neurons that were not previously
        connected. It also stores all connected synapses in a list.

        '''
        #The algorithm has 51 chances to pick a pair of distinct neurons that are not already connected.
        #For the size of this network this is sufficient.
        for _ in range(51):
            source = nest.NodeCollection([random.choice(list(self.all_neurons))])
            target = nest.NodeCollection([random.choice(list(self.all_neurons))])
            connection = nest.GetConnections(source, target)
            #Checks that this connection does not already exist and is not a self connection
            if len(connection) == 0 and source != target:
                if random.random() <= 1: #Chance is set to 1 in this study
                    nest.Connect(source, target, syn_spec=self.syn_dict)
                    self.all_connected.append((source, target))
                break


    def plot_weight_change(self, source, target):
        '''
        This function plots the weight changes of a synapse between a pair of given neurons using the dataframe
        collected from the simulation function, together with the source neuron's membrane potential and spike times.

            Parameters:
                source: int
                    The neuron ID of the chosen source neuron.
                target: int
                    The neuron ID of the chosen target neuron.

        '''
        row = self.df_conns[(self.df_conns['source'] == source) & (self.df_conns['target'] == target)]
        if row.empty:
            print('No connection found.')
            return
        weight_list = row.iloc[0][2:].tolist()

        voltage_mask = np.array(self.voltage_senders) == source
        neuron_times = np.array(self.voltage_times)[voltage_mask]
        neuron_voltages = np.array(self.voltages)[voltage_mask]

        spike_mask = np.array(self.senders) == source
        neuron_spike_times = np.array(self.times)[spike_mask]

        fig, ax1 = plt.subplots()
        ax1.plot(self.time_points, weight_list)
        ax1.set_xlabel('Time (ms)')
        ax1.set_ylabel('Weight')

        #Membrane potential on a secondary axis
        ax2 = ax1.twinx()
        ax2.plot(neuron_times, neuron_voltages)

        for val in neuron_spike_times:
            ax1.axvline(x=val, color='r', linestyle='--', linewidth=2)

        plt.show()


    def create_heatmap(self, timestep = 0):
        '''
        This function creates a heatmap of the weights between all neurons. Pairs without a connection get weight 0
        and are masked out.

            Parameters:
                timestep: int
                The specific iteration of the training phase the heatmap corresponds to.

        '''
        #Row/column index of each neuron in the (source x target) weight matrix
        index = {nid: k for k, nid in enumerate(self.all_neurons)}
        weights = np.zeros((len(self.all_neurons), len(self.all_neurons)))

        #One pass over all connections instead of one GetConnections() call per neuron pair
        for conn in nest.GetConnections():
            s, t = conn.source, conn.target
            if s in index and t in index:
                weights[index[s], index[t]] = conn.weight

        self.network_connections = weights
        mask = self.network_connections == 0

        #Plotting parameters
        plt.figure(figsize=(20, 20))
        plt.rcParams.update({'font.size': 24})
        ax = sns.heatmap(self.network_connections, vmin=250, vmax=350, cmap='viridis',
                         annot=False, fmt=".2f", square=True, linewidths=.5, mask=mask)
        ax.invert_yaxis()
        ax.set_xlabel("Target Neuron", fontsize=28)
        ax.set_ylabel("Source Neuron", fontsize=28)
        ax.set_xticklabels(self.all_neurons, rotation=90, fontsize=22)
        ax.set_yticklabels(self.all_neurons, rotation=0, fontsize=22)

        file_path = os.path.join(self.title, f"weight_heatmap_{timestep}_{self.title}.png")
        plt.savefig(file_path)
        plt.show()

    def print_metrics(self, timestep = 0):
        '''
        This function prints the metrics of the network at any given iteration.

            Parameters:
                timestep: int
                The specific iteration of the training phase the metrics correspond to.

        '''
        print("Total spikes:", len(self.senders))
        print("Neurons that spiked:", set(self.senders))

    def plot_spikes(self, timestep = 0):
        '''
        This function creates a spike raster plot of the network at any given iteration.

            Parameters:
                timestep: int
                The specific iteration of the training phase the plot corresponds to.

        '''
        plt.figure(figsize=(15, 5))
        plt.scatter(self.times, self.senders, s=5, color='black')
        plt.xlabel("Time (ms)")
        plt.ylabel("Neuron")
        file_path = os.path.join(self.title, f"spike_plot_at_{timestep}_{self.title}.png")
        plt.savefig(file_path)
        plt.show()


    def plot_voltage(self, timestep = 0):
        '''
        This function creates a voltage trace of the network at any given iteration.

            Parameters:
                timestep: int
                The specific iteration of the training phase the plot corresponds to.

        '''
        nest.voltage_trace.from_device(self.voltmeter)
        file_path = os.path.join(self.title, f"voltage_plot_at_{timestep}_{self.title}.png")
        plt.savefig(file_path)
        plt.show()


    def stop_synaptic_plasticity(self):
        '''
        This function stops synaptic plasticity by replacing all stdp synapses with static synapses of the same weight.
        '''
        for conn in nest.GetConnections():
            if conn.synapse_model == 'stdp_synapse':
                source, target, weight = conn.source, conn.target, conn.weight
                nest.Connect([source], [target],
                             syn_spec={"synapse_model": "static_synapse", "weight": weight, "delay": self.synapse_delay})
                #Previous stdp connection is removed
                nest.Disconnect(conn)


    def test_wave(self, testing_time):
        '''
        This function carries out the testing phase of the model. A new spike generator is made and set to spike 100 ms
        after the end of the training phase. The spike generator is assigned to the first column of the neuron grid or
        the last if a reverse wave is being tested. A new voltmeter and spike recorder are assigned to the neurons. The
        test phase is then simulated and all metrics are calculated. The order of spikes is collected in a dataframe to
        be used to calculate the wave correlation. A voltage trace and spike raster plot are plotted.

            Parameters:
                testing_time: int
                The total testing time.

        '''

        #Test spike generator is created and set to spike
        test_spike_generator = nest.Create('spike_generator')
        test_spike_generator.set(spike_times=[self.time + 100])

        #Spike generator is connected to first or last column depending on wave direction
        col = -1 if self.reverse else 0
        for j in range(self.grid_rows):
            nest.Connect(test_spike_generator, [self.neuron_grid[j][col]], syn_spec={'weight': self.test_weight})

        #Voltmeter and spike recorder are created and connected to all neurons
        self.test_voltmeter = nest.Create('voltmeter')
        self.test_spike_recorder = nest.Create("spike_recorder")

        nest.Connect(self.layer, self.test_spike_recorder)
        nest.Connect(self.test_voltmeter, self.all_neurons)

        #Testing phase commences
        nest.Simulate(testing_time)

        #Collecting membrane potential and spiking metrics
        self.t_ev = nest.GetStatus(self.test_spike_recorder, "events")[0]
        self.test_senders, self.test_times = self.t_ev["senders"], self.t_ev["times"]

        self.tv_ev = nest.GetStatus(self.test_voltmeter, "events")[0]
        self.test_voltages = self.tv_ev["V_m"]
        self.test_voltage_times = self.tv_ev["times"]
        self.test_voltage_senders = self.tv_ev["senders"]

        nest.voltage_trace.from_device(self.test_voltmeter)
        file_path = os.path.join(self.title, f"test_voltage_plot_{self.title}.png")
        plt.savefig(file_path)
        plt.show()

        #Spiking order is recorded for wave correlation metrics
        print("Total spikes:", len(self.test_senders))
        print("Neurons that spiked:", set(self.test_senders))
        self.test_spike_df = pd.DataFrame(self.test_senders, columns=["Neuron ID"])
        self.test_spike_df.to_csv(os.path.join(self.title, f"test_spiking_neurons_{self.title}.csv"), index=False)

        #Spike raster is plotted
        plt.figure(figsize=(15, 5))
        plt.scatter(self.test_times, self.test_senders, s=5, color='black')
        plt.xlabel("Time (ms)")
        plt.ylabel("Neuron")
        file_path = os.path.join(self.title, f"spike_plot_test_wave_{self.title}.png")
        plt.savefig(file_path)
        plt.show()

    def calculate_wave_correlation(self):
        '''
        This function calculates the wave correlation by obtaining the order of the test spikes from the dataframe
        collected during the test phase. It then calculates the expected order of spiking for a neuron grid of that
        size. The wave correlation is calculated by finding the Spearman's rank score for the ideal order and the
        propagated wave order. It calculates the score for 1 - 10 cycles so it can be determined at which iteration
        the wave degrades. Cycles for which not enough test spikes were recorded are skipped (the loop stops).

            Returns:
                The Spearman rho of the last cycle that could be evaluated, or nan if not even one full cycle of
                spikes was recorded.

        '''
        #The first grid_rows spikes are skipped (as in the original analysis).
        neuron_list = self.test_spike_df["Neuron ID"].tolist()[self.grid_rows:]

        #Ideal sequence: column by column through the grid
        # NOTE(review): the ideal order is always column 0 -> last column, even when self.reverse is True - confirm intended.
        base = [neuron for column in self.column_ids for neuron in column]

        rho = np.nan
        for num_of_cycles in range(1, 11):
            cycle_spikes = (self.grid_rows * self.grid_cols) * num_of_cycles
            full_cycle = neuron_list[:cycle_spikes]
            ideal = base * num_of_cycles

            #spearmanr needs equally long inputs; a shorter spike list means the wave died out.
            if len(full_cycle) < len(ideal):
                print(f"Only {len(full_cycle)} test spikes available; cannot evaluate {num_of_cycles} cycle(s).")
                break

            rho, _ = spearmanr(full_cycle, ideal)
            print(rho)
        return rho

    def get_radial_plot(self, node, test=False):
        '''
        This function creates a radial plot of the weights from a neuron to its eight direct grid neighbours.
        Missing neighbours or connections are drawn at the minimum value (250).

            Parameters:
                node: int
                The neuron that a plot will be generated for.

                test: bool
                This saves the file as either a training or test plot.

        '''
        self.positions = nest.GetPosition(self.layer)
        neuron_ids = self.layer.tolist()
        position_to_neuron = {tuple(pos): nid for pos, nid in zip(self.positions, neuron_ids)}

        #Obtains coordinates of neurons directly next to the chosen neuron
        x, y = self.positions[neuron_ids.index(node)]
        pos = [(x, y-1), (x-1, y-1), (x-1, y), (x-1, y+1), (x, y+1), (x+1, y+1), (x+1, y), (x+1, y-1)]
        min_val = 250

        #Obtains the neurons directly next to the chosen neuron (None where there is no neighbour)
        nid = [position_to_neuron.get(p) for p in pos]

        #Obtains the weight of the connection between the chosen neuron and each neighbour
        conn_lookup = {(conn.source, conn.target): conn.weight for conn in nest.GetConnections()}
        weight_list = [min_val if j is None else conn_lookup.get((node, j), min_val) for j in nid]

        #Plotting parameters
        N = 8
        theta = np.linspace(0.0, 2 * np.pi, N, endpoint=False)
        width = np.pi / 4.25

        fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
        norm = colors.Normalize(vmin=250, vmax=360)
        #matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap is still available.
        cmap = plt.get_cmap('viridis')
        bar_colours = [cmap(norm(w)) for w in weight_list]

        ax.bar(theta, height=np.array(weight_list) - min_val, width=width, bottom=min_val, color=bar_colours, alpha=0.8)
        ax.set_rlim(min_val, 360)
        ax.tick_params(labelsize=12)

        file_path = os.path.join(self.title, f"radial_node_{node}_{test}_{self.title}.png")
        plt.savefig(file_path)
        plt.show()


    def animate_spikes(self, frame_dt=1):
        '''
        This function animates the propagated test wave and saves it as a GIF.

            Parameters:
                frame_dt: float
                Time window (ms) shown in each frame.

        '''
        rows, cols = self.neuron_grid.shape
        self.neuron_id_to_pos = {self.neuron_grid[i, j]: (i, j) for i in range(rows) for j in range(cols)}
        t_min, t_max = self.test_times.min(), self.test_times.max()
        frames = np.arange(t_min, t_max, frame_dt)

        fig, ax = plt.subplots(figsize=(6, 6))
        scatterp = ax.scatter([], [], s=100, c='red')

        ax.set_xlim(-0.5, cols - 0.5)
        ax.set_ylim(-0.5, rows - 0.5)
        ax.set_xticks(range(cols))
        ax.set_yticks(range(rows))
        ax.set_title("")

        def update(frame_time):
            #Neurons that spiked within [frame_time, frame_time + frame_dt)
            mask = (self.test_times >= frame_time) & (self.test_times < frame_time + frame_dt)
            active_ids = self.test_senders[mask]
            active_pos = [self.neuron_id_to_pos[nid] for nid in active_ids if nid in self.neuron_id_to_pos]

            if active_pos:
                x, y = zip(*[(col, rows - 1 - row) for row, col in active_pos])
            else:
                x, y = [], []

            scatterp.set_offsets(np.c_[x, y])
            return scatterp,

        ani = animation.FuncAnimation(fig, update, frames=frames, interval=9999, blit=True)

        #Saved before show() so the figure is still open when the GIF is written.
        file_path = os.path.join(self.title, f"snn_activity_{self.title}.gif")
        ani.save(file_path, writer='pillow', fps=2)
        plt.show()


    def count_connections(self):
        '''
        This function counts forward (to a later column), backward (to an earlier column) and intracolumn connections
        between grid neurons in the network at the time it is called.

            Returns:
                (forward, backward, intracolumn) counts.

        '''
        #Grid column index of every neuron; covers all grid_cols columns
        column_of = {nid: col for col, ids in enumerate(self.column_ids) for nid in ids}
        forward = backward = intracolumn = 0

        for conn in nest.GetConnections():
            src_col = column_of.get(conn.source)
            tgt_col = column_of.get(conn.target)
            if src_col is None or tgt_col is None:
                continue

            if src_col == tgt_col:
                intracolumn += 1
            elif src_col < tgt_col:
                forward += 1
            else:
                backward += 1

        print(f"Forward: {forward}, Backward: {backward}, Intracolumn: {intracolumn}")
        return forward, backward, intracolumn

    def calculate_in_degree_and_strength(self):
        '''
        This function calculates the in/out degree and in/out strength (summed weight) of all neurons in the network
        at the time it is called, counting only connections between grid neurons.

            Returns:
                (in_degree, in_strength, out_degree, out_strength) dictionaries keyed by neuron ID.

        '''
        in_degree = {nid: 0 for nid in self.neuron_ids}
        in_strength = {nid: 0.0 for nid in self.neuron_ids}
        out_degree = {nid: 0 for nid in self.neuron_ids}
        out_strength = {nid: 0.0 for nid in self.neuron_ids}

        for conn in nest.GetConnections():
            source, target, w = conn.source, conn.target, conn.weight
            if target in self.neuron_ids and source in self.neuron_ids:
                in_degree[target] += 1
                in_strength[target] += w
                out_degree[source] += 1
                out_strength[source] += w

        return in_degree, in_strength, out_degree, out_strength
    


## Define Network Parameters

In [None]:
# Keyword arguments for buildNetwork(**params); the values below are for a 5x5 grid.
params = dict(
    grid_rows=5,
    grid_cols=5,
    title="structural_plasticity_model",
    structural_plasticity=True,
    reverse=False,
    synapse_weight=300,
    alpha=0.8,
    ld=0.01,
    prob=1,
    radius=2.5,
    spike_delay=20,
    spike_weight=3000,
    synapse_delay=1.5,
    multiplier=0.9983,
    test_weight=7000,
)

For each grid size (keys are rows and columns, e.g. 55 = 5 × 5, 56 = 5 × 6) the following test weights, multipliers and training times are needed:\
Test Weight = {55:7000,56:15475,57:17200} \
Multiplier = {55:0.9983,56:0.99,57:0.99}\
Training Time = {55:1000,56:1200,57:1300}

## Running the Network

To run the network, simulate the training phase and then the testing phase. To stop synaptic plasticity, call `stop_synaptic_plasticity()` between the training and testing phases.

In [None]:
def run_network(training_time=1000, testing_time=1000, freeze_synapses=False, network_params=None):
    '''
    Build the network, run the training phase and then the testing phase.

        Parameters:
            training_time: int
            Duration (ms) of the training phase. Larger grids need longer training (e.g. 1200 ms for 5x6).
            testing_time: int
            Duration (ms) of the testing phase.
            freeze_synapses: bool
            If True, STDP synapses are replaced by static synapses between training and testing.
            network_params: dict or None
            Keyword arguments for buildNetwork; defaults to the module-level `params` dict.

        Returns:
            The trained and tested buildNetwork instance.
    '''
    #Build the network
    network = buildNetwork(**(params if network_params is None else network_params))
    #Simulate training
    network.simulate(training_time)

    if freeze_synapses:
        network.stop_synaptic_plasticity()

    #Simulate testing
    network.test_wave(testing_time)
    return network

In [None]:
# Build, train and test the network with the module-level `params`; the result is used by the cells below.
network = run_network()

## Metrics

In [None]:
# Wave correlation: Spearman rank score against the ideal column order for 1-10 cycles (printed per cycle).
wave_rho = network.calculate_wave_correlation()

# Forward / backward / intra-column connection counts (printed).
# Can also be run before training to see the initial connectivity.
connection_counts = network.count_connections()

# In/out degree and strength of every neuron; left as the cell's last expression so the notebook displays it.
# Can also be run before training to see the initial connectivity.
network.calculate_in_degree_and_strength()

## Visualisations

Most visualisations are automatically generated during the simulation.

In [None]:
# Animate the test wave. Any failure (e.g. no test spikes recorded, or the GIF writer) is reported, not raised.
try:
    network.animate_spikes()
except Exception as err:
    print(f"Error: {err}")

In [None]:
# Radial plot of the weights from neuron 6 to its direct neighbours.
# Change the neuron ID as needed; can also be run before testing to see the training effects.
network.get_radial_plot(6)

# Weight change of a forward and a backward connection, given as (source, target) pairs.
for source, target in [(1, 2), (2, 1)]:
    network.plot_weight_change(source, target)