diff --git a/CHANGES.md b/CHANGES.md index 46f8c76a5..798e7fe57 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,26 @@ +# Version 0.8.0 + +- Optimized implementation of raster and spikeHist plotting using Pandas + +- Added option for filename to saveData() + +- Removed pop cellModelClass when saving + +- Removed cell h object keys when saving + +- Added support for evolutionary algorithm optimization (via Inspyred) and usage example + +- cfg.popAvgRates now accepts a time range to calculate rates (e.g. to discard initial period) + +- Fixed bug initializing batch 'mpi_bulletin' and batch tutorial example + +- Fixed bug: removed '\_labelid' from netParams when saving + +- Fixed bug: made self.scaleConnWeightModels False when not used (avoids saving weird dict in json) + +- Fixed bug in Pickle file encoding so works in Python3 + + # Version 0.7.9 - Extended metadata structure to interact with NetPyNE-UI @@ -34,8 +57,6 @@ - Fixed bug: delete sections after import cell only if section exists -- Fixed bug initalizing batch 'mpi_bulletin' and batch tutorial example - # Version 0.7.8 diff --git a/doc/source/code/tut1.py b/doc/source/code/tut1.py index b5030ca0f..25aef268d 100644 --- a/doc/source/code/tut1.py +++ b/doc/source/code/tut1.py @@ -1,5 +1,3 @@ import HHTut from netpyne import sim sim.createSimulateAnalyze(netParams = HHTut.netParams, simConfig = HHTut.simConfig) - -# import pylab; pylab.show() # this line is only necessary in certain systems where figures appear empty diff --git a/doc/source/index.rst b/doc/source/index.rst index ac8193b7f..bd8786eb1 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -8,7 +8,7 @@ Welcome to NetPyNE's documentation! NetPyNE is a python package to facilitate the development, parallel simulation and analysis of biological neuronal networks using the NEURON simulator. -Check out our new `NetPyNE GUI teaser video `_! The GUI will be released July 2018, ready for our `Tutorial at CNS18 `_! 
+We have also released an alpha version of the NetPyNE GUI -- `see video here `_ (open in new tab)! See a `step-by-step tutorial `_! Join our `NetPyNE mailing list `_ to receive updates on version releases and other major announcements. diff --git a/doc/source/install.rst b/doc/source/install.rst index c35744643..133f923d6 100644 --- a/doc/source/install.rst +++ b/doc/source/install.rst @@ -3,10 +3,11 @@ Installation ======================================= + Requirements ------------ -The NetPyNE package requires Python 2.7 (www.python.org). +The NetPyNE package requires Python 2 or 3 (both are supported) (download from www.python.org). Additionally, running parallelized simulations of networks requires the NEURON simulator with python and MPI support. See NEURON's `installation instructions `_ and `documentation `_ @@ -20,16 +21,27 @@ Note: It is possible to use the NetPyNE package without NEURON, to convert model Install via pip (latest released version) ----------------------------------------- +Python 2 +^^^^^^^^^^^^ + To install the the package run ``pip install netpyne`` (Linux or Mac OS) or ``python -m pip install netpyne`` (Windows) -To upgrade to a new version run ``pip install netpyne -U`` (Linux or Mac OS) or ``python -m pip install -U pip`` (Windows) +To upgrade to a new version run ``pip install netpyne -U`` (Linux or Mac OS) or ``python -m pip install -U netpyne`` (Windows) -If you need to install ``pip`` go to `this link `_ +Python 3 +^^^^^^^^^^^^ -The NetPyNE package source files, as well as example models, are available via github at: https://github.com/Neurosim-lab/netpyne +To install the package run ``pip3 install netpyne_py3`` (Linux or Mac OS) or ``python -m pip install netpyne_py3`` (Windows) + +To upgrade to a new version run ``pip3 install netpyne_py3 -U`` (Linux or Mac OS) or ``python -m pip install -U netpyne_py3`` (Windows) + + +pip +^^^^^^^^^^ +If you need to install ``pip`` go to `this link `_ -Install via pip (development 
version) +Install via pip (development version; only for Python 2) -------------------------------------- This will install the version in the github "development" branch -- it will include some of the latest enhancements and bug fixes, but could also include temporary bugs: @@ -39,4 +51,13 @@ This will install the version in the github "development" branch -- it will incl 3) git checkout development 4) pip install -e . -pip will add a symlink in the default python packages folder to the cloned netpyne folder (so you don't need to modify PYTHONPATH). If new changes are available just need to pull from cloned netpyne repo. \ No newline at end of file +pip will add a symlink in the default python packages folder to the cloned netpyne folder (so you don't need to modify PYTHONPATH). If new changes are available just need to pull from cloned netpyne repo. + +Install NetPyNE GUI (alpha version; only for Python 2) +-------------------------------------- + +You can install the tool via `pip `_ (requires `NEURON `_ ), or using our pre-packaged `Docker `_ or `Virtual Machine `_. 
+ +Github +------------ +The NetPyNE package source files, as well as example models, are available via github at: https://github.com/Neurosim-lab/netpyne diff --git a/examples/batchCell/batch.py b/examples/batchCell/batch.py index 9328914a5..2e860de3b 100644 --- a/examples/batchCell/batch.py +++ b/examples/batchCell/batch.py @@ -22,7 +22,7 @@ def runBatch(b, label): b.batchLabel = label b.saveFolder = 'data/'+b.batchLabel b.method = 'grid' - b.runCfg = {'type': 'mpi', + b.runCfg = {'type': 'mpi_bulletin', 'script': 'init.py', 'skip': True} diff --git a/examples/batchCell/cells/SPI6.py b/examples/batchCell/cells/SPI6.py index ea33c84cc..f5233de8d 100644 --- a/examples/batchCell/cells/SPI6.py +++ b/examples/batchCell/cells/SPI6.py @@ -3,6 +3,8 @@ from neuron import h from math import exp,log +h.load_file('stdrun.hoc') + Vrest = -88.5366550238 h.v_init = -75.0413649414 h.celsius = 34.0 # for in vitro opt diff --git a/netpyne/analysis.py b/netpyne/analysis.py index 2dcefc6cd..a263d9325 100644 --- a/netpyne/analysis.py +++ b/netpyne/analysis.py @@ -41,14 +41,15 @@ def exception(function): """ @functools.wraps(function) def wrapper(*args, **kwargs): + import sys try: return function(*args, **kwargs) except Exception as e: - # print + # print err = "There was an exception in %s():"%(function.__name__) - print(("%s \n %s"%(err,e))) + print(("%s \n %s \n%s"%(err,e,sys.exc_info()))) return -1 - + return wrapper @@ -71,20 +72,20 @@ def plotData (): # Print timings if sim.cfg.timing: - + sim.timing('stop', 'plotTime') print((' Done; plotting time = %0.2f s' % sim.timingData['plotTime'])) - + sim.timing('stop', 'totalTime') sumTime = sum([t for k,t in sim.timingData.items() if k not in ['totalTime']]) - if sim.timingData['totalTime'] <= 1.2*sumTime: # Print total time (only if makes sense) + if sim.timingData['totalTime'] <= 1.2*sumTime: # Print total time (only if makes sense) print(('\nTotal time = %0.2f s' % sim.timingData['totalTime'])) 
###################################################################################################################################################### ## Round to n sig figures ###################################################################################################################################################### def _roundFigures(x, n): - return round(x, -int(np.floor(np.log10(abs(x)))) + (n - 1)) + return round(x, -int(np.floor(np.log10(abs(x)))) + (n - 1)) ###################################################################################################################################################### ## show figure @@ -116,7 +117,7 @@ def _saveFigData(figData, fileName=None, type=''): print(('Saving figure data as %s ... ' % (fileName))) with open(fileName, 'w') as fileObj: json.dump(figData, fileObj) - else: + else: print('File extension to save figure data not recognized') @@ -184,6 +185,25 @@ def _smooth1d(x,window_len=11,window='hanning'): return y[(window_len/2-1):-(window_len/2)] +###################################################################################################################################################### +## Get subset of spkt, spkid based on a timeRange and cellGids list; ~10x speedup over list iterate +###################################################################################################################################################### +def getSpktSpkid(cellGids=[], timeRange=None, allCells=False): + '''return spike ids and times; with allCells=True just need to identify slice of time so can omit cellGids''' + import pandas as pd + from . 
import sim + df = pd.DataFrame(pd.lib.to_object_array([sim.allSimData['spkt'], sim.allSimData['spkid']]).transpose(), columns=['spkt', 'spkid']) + if timeRange: + min, max = [int(df['spkt'].searchsorted(timeRange[i])) for i in range(2)] # binary search faster than query + else: # timeRange None or empty list means all times + min, max = 0, len(df) + if len(cellGids)==0 or allCells: # get all by either using flag or giving empty list -- can get rid of the flag + sel = df[min:max] + else: + sel = df[min:max].query('spkid in @cellGids') + return sel, sel['spkt'].tolist(), sel['spkid'].tolist() # will want to return sel as well for further sorting + + ###################################################################################################################################################### ## Get subset of cells and netstims indicated by include list ###################################################################################################################################################### @@ -196,34 +216,34 @@ def getCellsInclude(include): cells = [] netStimLabels = [] for condition in include: - if condition == 'all': # all cells + Netstims + if condition == 'all': # all cells + Netstims cellGids = [c['gid'] for c in allCells] cells = list(allCells) netStimLabels = list(allNetStimLabels) return cells, cellGids, netStimLabels - elif condition == 'allCells': # all cells + elif condition == 'allCells': # all cells cellGids = [c['gid'] for c in allCells] cells = list(allCells) - elif condition == 'allNetStims': # all cells + Netstims + elif condition == 'allNetStims': # all cells + Netstims netStimLabels = list(allNetStimLabels) - elif isinstance(condition, int): # cell gid + elif isinstance(condition, int): # cell gid cellGids.append(condition) - + elif isinstance(condition, str): # entire pop if condition in allNetStimLabels: netStimLabels.append(condition) else: cellGids.extend([c['gid'] for c in allCells if c['tags']['pop']==condition]) - + # 
subset of a pop with relative indices # when load from json gets converted to list (added as exception) - elif (isinstance(condition, (list,tuple)) - and len(condition)==2 - and isinstance(condition[0], str) - and isinstance(condition[1], (list,int))): + elif (isinstance(condition, (list,tuple)) + and len(condition)==2 + and isinstance(condition[0], str) + and isinstance(condition[1], (list,int))): cellsPop = [c['gid'] for c in allCells if c['tags']['pop']==condition[0]] if isinstance(condition[1], list): cellGids.extend([gid for i,gid in enumerate(cellsPop) if i in condition[1]]) @@ -232,9 +252,9 @@ def getCellsInclude(include): elif isinstance(condition, (list,tuple)): # subset for subcond in condition: - if isinstance(subcond, int): # cell gid + if isinstance(subcond, int): # cell gid cellGids.append(subcond) - + elif isinstance(subcond, str): # entire pop if subcond in allNetStimLabels: netStimLabels.append(subcond) @@ -256,21 +276,21 @@ def getCellsIncludeTags(include, tags, tagsFormat=None): cellGids = [] # using list with indices - if tagsFormat or 'format' in allCells: + if tagsFormat or 'format' in allCells: if not tagsFormat: tagsFormat = allCells.pop('format') popIndex = tagsFormat.index('pop') for condition in include: - if condition in ['all', 'allCells']: # all cells + if condition in ['all', 'allCells']: # all cells cellGids = list(allCells.keys()) return cellGids - elif isinstance(condition, int): # cell gid + elif isinstance(condition, int): # cell gid cellGids.append(condition) - + elif isinstance(condition, str): # entire pop cellGids.extend([gid for gid,c in allCells.items() if c[popIndex]==condition]) - + elif isinstance(condition, tuple): # subset of a pop with relative indices cellsPop = [gid for gid,c in allCells.items() if c[popIndex]==condition[0]] if isinstance(condition[1], list): @@ -280,18 +300,18 @@ def getCellsIncludeTags(include, tags, tagsFormat=None): # using dict with keys else: - + for condition in include: - if condition in 
['all', 'allCells']: # all cells + if condition in ['all', 'allCells']: # all cells cellGids = list(allCells.keys()) return cellGids - elif isinstance(condition, int): # cell gid + elif isinstance(condition, int): # cell gid cellGids.append(condition) - + elif isinstance(condition, str): # entire pop cellGids.extend([gid for gid,c in allCells.items() if c['pop']==condition]) - + elif isinstance(condition, tuple): # subset of a pop with relative indices cellsPop = [gid for gid,c in allCells.items() if c['pop']==condition[0]] if isinstance(condition[1], list): @@ -310,12 +330,12 @@ def getCellsIncludeTags(include, tags, tagsFormat=None): def syncMeasure (): from . import sim - t0=-1 - width=1 + t0=-1 + width=1 cnt=0 for spkt in sim.allSimData['spkt']: - if (spkt>=t0+width): - t0=spkt + if (spkt>=t0+width): + t0=spkt cnt+=1 return 1-cnt/(sim.cfg.duration/width) @@ -324,10 +344,10 @@ def syncMeasure (): ## Calculate avg and peak rate of different subsets of cells for specific time period ###################################################################################################################################################### @exception -def calculateRate (include = ['allCells', 'eachPop'], peakBin = 5, timeRange = None): - ''' +def calculateRate (include = ['allCells', 'eachPop'], peakBin = 5, timeRange = None): + ''' Calculate avg and peak rate of different subsets of cells for specific time period - - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. + - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. 
Note: one line per item, not grouped (default: ['allCells', 'eachPop']) - timeRange ([start:stop]): Time range of spikes shown; if None shows all (default: None) - peakBin (int): Histogram bin size used to calculate peak firing rate; if None, peak rate not calculated (default: 5) @@ -339,7 +359,7 @@ def calculateRate (include = ['allCells', 'eachPop'], peakBin = 5, timeRange = N print('Calculating avg and peak firing rates ...') # Replace 'eachPop' with list of pops - if 'eachPop' in include: + if 'eachPop' in include: include.remove('eachPop') for pop in sim.net.allPops: include.append(pop) @@ -360,7 +380,7 @@ def calculateRate (include = ['allCells', 'eachPop'], peakBin = 5, timeRange = N spkinds,spkts = list(zip(*[(spkgid,spkt) for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids])) except: spkinds,spkts = [],[] - else: + else: spkinds,spkts = [],[] # Add NetStim spikes @@ -372,7 +392,7 @@ def calculateRate (include = ['allCells', 'eachPop'], peakBin = 5, timeRange = N for stimLabel,stimSpks in stims.items() for spk in stimSpks if stimLabel == netStimLabel] if len(netStimSpks) > 0: lastInd = max(spkinds) if len(spkinds)>0 else 0 - spktsNew = netStimSpks + spktsNew = netStimSpks spkindsNew = [lastInd+1+i for i in range(len(netStimSpks))] spkts.extend(spktsNew) spkinds.extend(spkindsNew) @@ -381,7 +401,7 @@ def calculateRate (include = ['allCells', 'eachPop'], peakBin = 5, timeRange = N if peakBin: histo = np.histogram(spkts, bins = np.arange(timeRange[0], timeRange[1], peakBin)) histoT = histo[1][:-1]+peakBin/2 - histoCount = histo[0] + histoCount = histo[0] histData.append(histoCount) @@ -395,20 +415,20 @@ def calculateRate (include = ['allCells', 'eachPop'], peakBin = 5, timeRange = N ###################################################################################################################################################### -## Plot avg and peak rates at different time periods +## Plot avg and peak rates at 
different time periods ###################################################################################################################################################### @exception -def plotRates (include =['allCells', 'eachPop'], peakBin = 5, timeRanges = None, timeRangeLabels = None, colors = None, figSize = ((5,5)), saveData = None, +def plotRates (include =['allCells', 'eachPop'], peakBin = 5, timeRanges = None, timeRangeLabels = None, colors = None, figSize = ((5,5)), saveData = None, ylim = None, saveFig = None, showFig = True): - ''' + ''' Calculate avg and peak rate of different subsets of cells for specific time period - - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. + - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. Note: one line per item, not grouped (default: ['allCells', 'eachPop']) - timeRanges ([[start1:stop1], [start2:stop2]]): List of time range of spikes shown; if None shows all (default: None) - timeRangeLabels (['preStim', 'postStim']): List of labels for each time range period (default: None) - peakBin (int): Histogram bin size used to calculate peak firing rate; if None, peak rate not calculated (default: 5) - figSize ((width, height)): Size of figure (default: (10,8)) - - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; + - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; if set to True uses filename from simConfig (default: None) - saveFig (None|True|'fileName'): File name where to save the figure (default: None) if set to True uses filename from simConfig (default: None) @@ -451,14 +471,14 @@ def plotRates (include =['allCells', 'eachPop'], peakBin = 5, timeRanges = None, pass # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = 
sim.cfg.filename+'_'+'avgRates.png' plt.savefig(filename) - # show fig + # show fig if showFig: _showFigure() # peak @@ -479,19 +499,19 @@ def plotRates (include =['allCells', 'eachPop'], peakBin = 5, timeRanges = None, pass # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'peakRates.png' plt.savefig(filename) - # show fig + # show fig if showFig: _showFigure() else: fig1, fig2 = None, None - + # save figure data if saveData: figData = {'includeList': includeList, 'timeRanges': timeRanges, 'avgs': avgs, 'peaks': peaks} @@ -504,19 +524,19 @@ def plotRates (include =['allCells', 'eachPop'], peakBin = 5, timeRanges = None, ###################################################################################################################################################### -## Plot sync at different time periods +## Plot sync at different time periods ###################################################################################################################################################### @exception -def plotSyncs (include =['allCells', 'eachPop'], timeRanges = None, timeRangeLabels = None, colors = None, figSize = ((5,5)), saveData = None, +def plotSyncs (include =['allCells', 'eachPop'], timeRanges = None, timeRangeLabels = None, colors = None, figSize = ((5,5)), saveData = None, saveFig = None, showFig = True): - ''' + ''' Calculate avg and peak rate of different subsets of cells for specific time period - - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. + - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. 
Note: one line per item, not grouped (default: ['allCells', 'eachPop']) - timeRanges ([[start1:stop1], [start2:stop2]]): List of time range of spikes shown; if None shows all (default: None) - timeRangeLabels (['preStim', 'postStim']): List of labels for each time range period (default: None) - figSize ((width, height)): Size of figure (default: (10,8)) - - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; + - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; if set to True uses filename from simConfig (default: None) - saveFig (None|True|'fileName'): File name where to save the figure (default: None) if set to True uses filename from simConfig (default: None) @@ -553,14 +573,14 @@ def plotSyncs (include =['allCells', 'eachPop'], timeRanges = None, timeRangeLab ax1.legend(include) # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'sync.png' plt.savefig(filename) - # show fig + # show fig if showFig: _showFigure() # save figure data @@ -568,7 +588,7 @@ def plotSyncs (include =['allCells', 'eachPop'], timeRanges = None, timeRangeLab figData = {'includeList': includeList, 'timeRanges': timeRanges, 'syncs': syncs} _saveFigData(figData, saveData, 'raster') - + return fig1, syncs @@ -576,14 +596,14 @@ def plotSyncs (include =['allCells', 'eachPop'], timeRanges = None, timeRangeLab ###################################################################################################################################################### -## Raster plot +## Raster plot ###################################################################################################################################################### -@exception +@exception def plotRaster (include = ['allCells'], timeRange = None, maxSpikes = 1e8, orderBy = 'gid', orderInverse = False, labels = 'legend', popRates = False, - spikeHist = 
None, spikeHistBin = 5, syncLines = False, lw = 2, marker = '|', markerSize=5, popColors = None, figSize = (10,8), dpi = 100, saveData = None, saveFig = None, - showFig = True): - ''' - Raster plot of network cells + spikeHist = None, spikeHistBin = 5, syncLines = False, lw = 2, marker = '|', markerSize=5, popColors = None, figSize = (10,8), dpi = 100, saveData = None, saveFig = None, + showFig = True): + ''' + Raster plot of network cells - include (['all',|'allCells',|'allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): Cells to include (default: 'allCells') - timeRange ([start:stop]): Time range of spikes shown; if None shows all (default: None) - maxSpikes (int): maximum number of spikes that will be plotted (default: 1e8) @@ -599,7 +619,7 @@ def plotRaster (include = ['allCells'], timeRange = None, maxSpikes = 1e8, order - popColors (dict): Dictionary with color (value) used for each population (key) (default: None) - figSize ((width, height)): Size of figure (default: (10,8)) - dpi (int): Dots per inch to save fig (default: 100) - - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; + - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; if set to True uses filename from simConfig (default: None) - saveFig (None|True|'fileName'): File name where to save the figure (default: None) if set to True uses filename from simConfig (default: None) @@ -607,15 +627,31 @@ def plotRaster (include = ['allCells'], timeRange = None, maxSpikes = 1e8, order - Returns figure handle ''' - from . 
import sim + import pandas as pd print('Plotting raster...') # Select cells to include cells, cellGids, netStimLabels = getCellsInclude(include) - selectedPops = [cell['tags']['pop'] for cell in cells] - popLabels = [pop for pop in sim.net.allPops if pop in selectedPops] # preserves original ordering + + df = pd.DataFrame.from_records(cells) + df = pd.concat([df.drop('tags', axis=1), pd.DataFrame.from_records(df['tags'].tolist())], axis=1) + + keep = ['pop', 'gid', 'conns'] + + if isinstance(orderBy, str) and orderBy not in cells[0]['tags']: # if orderBy property doesn't exist or is not numeric, use gid + orderBy = 'gid' + elif isinstance(orderBy, str) and not isinstance(cells[0]['tags'][orderBy], Number): + orderBy = 'gid' + + if isinstance(orderBy, list): + keep = keep + list(set(orderBy) - set(keep)) + elif orderBy not in keep: + keep.append(orderBy) + df = df[keep] + + popLabels = [pop for pop in sim.net.allPops if pop in df['pop'].unique()] #preserves original ordering if netStimLabels: popLabels.append('NetStims') popColorsTmp = {popLabel: colorList[ipop%len(colorList)] for ipop,popLabel in enumerate(popLabels)} # dict with color for each pop if popColors: popColorsTmp.update(popColors) @@ -623,64 +659,45 @@ def plotRaster (include = ['allCells'], timeRange = None, maxSpikes = 1e8, order if len(cellGids) > 0: gidColors = {cell['gid']: popColors[cell['tags']['pop']] for cell in cells} # dict with color for each gid try: - spkgids,spkts = list(zip(*[(spkgid,spkt) for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids])) + sel, spkts,spkgids = getSpktSpkid(cellGids=cellGids, timeRange=timeRange, allCells=(include == ['allCells'])) except: + import sys + print((sys.exc_info())) spkgids, spkts = [], [] - spkgidColors = [gidColors[spkgid] for spkgid in spkgids] + sel = pd.DataFrame(columns=['spkt', 'spkid']) + sel['spkgidColor'] = sel['spkid'].map(gidColors) + df['gidColor'] = df['pop'].map(popColors) + df.set_index('gid', 
inplace=True) # Order by - if len(cellGids) > 0: - if isinstance(orderBy, str) and orderBy not in cells[0]['tags']: # if orderBy property doesn't exist or is not numeric, use gid - orderBy = 'gid' - elif isinstance(orderBy, str) and not isinstance(cells[0]['tags'][orderBy], Number): - orderBy = 'gid' - ylabelText = 'Cells (ordered by %s)'%(orderBy) - - if orderBy == 'gid': - yorder = [cell[orderBy] for cell in cells] - #sortedGids = {gid:i for i,(y,gid) in enumerate(sorted(zip(yorder,cellGids)))} - sortedGids = [gid for y,gid in sorted(zip(yorder,cellGids))] - - elif isinstance(orderBy, str): - yorder = [cell['tags'][orderBy] for cell in cells] - #sortedGids = {gid:i for i,(y,gid) in enumerate(sorted(zip(yorder,cellGids)))} - sortedGids = [gid for y,gid in sorted(zip(yorder,cellGids))] - elif isinstance(orderBy, list) and len(orderBy) == 2: - yorders = [[popLabels.index(cell['tags'][orderElem]) if orderElem=='pop' else cell['tags'][orderElem] - for cell in cells] for orderElem in orderBy] - #sortedGids = {gid:i for i, in enumerate(sorted(zip(yorders[0], yorders[1], cellGids)))} - sortedGids = [gid for (y0, y1, gid) in sorted(zip(yorders[0], yorders[1], cellGids))] - - #sortedGids = {gid:i for i, (y, gid) in enumerate(sorted(zip(yorder, cellGids)))} - spkinds = [sortedGids.index(gid) for gid in spkgids] + if len(df) > 0: + ylabelText = 'Cells (ordered by %s)'%(orderBy) + df = df.sort_values(by=orderBy) + sel['spkind'] = sel['spkid'].apply(df.index.get_loc) else: - spkts = [] - spkinds = [] - spkgidColors = [] + sel = pd.DataFrame(columns=['spkt', 'spkid', 'spkind']) ylabelText = '' - # Add NetStim spikes - spkts,spkgidColors = list(spkts), list(spkgidColors) - numCellSpks = len(spkts) + numCellSpks = len(sel) numNetStims = 0 for netStimLabel in netStimLabels: netStimSpks = [spk for cell,stims in sim.allSimData['stims'].items() \ for stimLabel,stimSpks in stims.items() for spk in stimSpks if stimLabel == netStimLabel] if len(netStimSpks) > 0: - lastInd = max(spkinds) 
if len(spkinds)>0 else 0 - spktsNew = netStimSpks + # lastInd = max(spkinds) if len(spkinds)>0 else 0 + lastInd = sel['spkind'].max() if len(sel['spkind']) > 0 else 0 + spktsNew = netStimSpks spkindsNew = [lastInd+1+i for i in range(len(netStimSpks))] - spkts.extend(spktsNew) - spkinds.extend(spkindsNew) - for i in range(len(spktsNew)): - spkgidColors.append(popColors['NetStims']) + ns = pd.DataFrame(list(zip(spktsNew, spkindsNew)), columns=['spkt', 'spkind']) + ns['spkgidColor'] = popColors['NetStims'] + sel = pd.concat([sel, ns]) numNetStims += 1 else: pass #print netStimLabel+' produced no spikes' - if len(cellGids)>0 and numNetStims: + if len(cellGids)>0 and numNetStims: ylabelText = ylabelText + ' and NetStims (at the end)' elif numNetStims: ylabelText = ylabelText + 'NetStims' @@ -690,76 +707,74 @@ def plotRaster (include = ['allCells'], timeRange = None, maxSpikes = 1e8, order return None # Time Range + #### Time range is already queried in getSpktSpkid??? #### if timeRange == [0,sim.cfg.duration]: pass elif timeRange is None: timeRange = [0,sim.cfg.duration] else: - spkinds,spkts,spkgidColors = list(zip(*[(spkind,spkt,spkgidColor) for spkind,spkt,spkgidColor in zip(spkinds,spkts,spkgidColors) - if timeRange[0] <= spkt <= timeRange[1]])) + sel = sel.query('spkt >= @timeRange[0] and spkt <= @timeRange[1]') + # Limit to maxSpikes - if (len(spkts)>maxSpikes): - print((' Showing only the first %i out of %i spikes' % (maxSpikes, len(spkts)))) # Limit num of spikes + if (len(sel)>maxSpikes): + print((' Showing only the first %i out of %i spikes' % (maxSpikes, len(sel)))) # Limit num of spikes if numNetStims: # sort first if have netStims - spkts, spkinds, spkgidColors = list(zip(*sorted(zip(spkts, spkinds, spkgidColors)))) - spkts = spkts[:maxSpikes] - spkinds = spkinds[:maxSpikes] - spkgidColors = spkgidColors[:maxSpikes] - timeRange[1] = max(spkts) + sel = sel.sort_values(by='spkt') + sel = sel.iloc[:maxSpikes] + timeRange[1] = sel['spkt'].max() + + # Calculate 
spike histogram - # Calculate spike histogram if spikeHist: - histo = np.histogram(spkts, bins = np.arange(timeRange[0], timeRange[1], spikeHistBin)) + histo = np.histogram(sel['spkt'].tolist(), bins = np.arange(timeRange[0], timeRange[1], spikeHistBin)) histoT = histo[1][:-1]+spikeHistBin/2 histoCount = histo[0] - # Plot spikes fig,ax1 = plt.subplots(figsize=figSize) fontsiz = 12 - + if spikeHist == 'subplot': gs = gridspec.GridSpec(2, 1,height_ratios=[2,1]) ax1=plt.subplot(gs[0]) - - ax1.scatter(spkts, spkinds, lw=lw, s=markerSize, marker=marker, color = spkgidColors) # Create raster + sel['spkt'] = sel['spkt'].apply(pd.to_numeric) + sel.plot.scatter(ax=ax1, x='spkt', y='spkind', lw=lw, s=markerSize, marker=marker, c=sel['spkgidColor'].tolist()) # Create raster ax1.set_xlim(timeRange) - + # Plot stats - gidPops = [cell['tags']['pop'] for cell in cells] + gidPops = df['pop'].tolist() popNumCells = [float(gidPops.count(pop)) for pop in popLabels] if numCellSpks else [0] * len(popLabels) - totalSpikes = len(spkts) - totalConnections = sum([len(cell['conns']) for cell in cells]) - numCells = len(cells) - firingRate = float(totalSpikes)/(numCells+numNetStims)/(timeRange[1]-timeRange[0])*1e3 if totalSpikes>0 else 0 # Calculate firing rate + totalSpikes = len(sel) + totalConnections = sum([len(conns) for conns in df['conns']]) + numCells = len(cells) + firingRate = float(totalSpikes)/(numCells+numNetStims)/(timeRange[1]-timeRange[0])*1e3 if totalSpikes>0 else 0 # Calculate firing rate connsPerCell = totalConnections/float(numCells) if numCells>0 else 0 # Calculate the number of connections per cell - + if popRates: avgRates = {} - tsecs = (timeRange[1]-timeRange[0])/1e3 + tsecs = (timeRange[1]-timeRange[0])/1e3 for i,(pop, popNum) in enumerate(zip(popLabels, popNumCells)): if numCells > 0 and pop != 'NetStims': if numCellSpks == 0: avgRates[pop] = 0 else: - avgRates[pop] = len([spkid for spkid in spkinds[:numCellSpks-1] if 
sim.net.allCells[sortedGids[int(spkid)]]['tags']['pop']==pop])/popNum/tsecs + avgRates[pop] = len([spkid for spkid in sel['spkind'].iloc[:numCellSpks-1] if df['pop'].iloc[int(spkid)]==pop])/popNum/tsecs if numNetStims: popNumCells[-1] = numNetStims - avgRates['NetStims'] = len([spkid for spkid in spkinds[numCellSpks:]])/numNetStims/tsecs + avgRates['NetStims'] = len([spkid for spkid in sel['spkind'].iloc[numCellSpks:]])/numNetStims/tsecs - # Plot synchrony lines - if syncLines: - for spkt in spkts: + # Plot synchrony lines + if syncLines: + for spkt in sel['spkt'].tolist(): ax1.plot((spkt, spkt), (0, len(cells)+numNetStims), 'r-', linewidth=0.1) plt.title('cells=%i syns/cell=%0.1f rate=%0.1f Hz sync=%0.2f' % (numCells,connsPerCell,firingRate,syncMeasure()), fontsize=fontsiz) else: plt.title('cells=%i syns/cell=%0.1f rate=%0.1f Hz' % (numCells,connsPerCell,firingRate), fontsize=fontsiz) - # Axis ax1.set_xlabel('Time (ms)', fontsize=fontsiz) ax1.set_ylabel(ylabelText, fontsize=fontsiz) ax1.set_xlim(timeRange) - ax1.set_ylim(-1, len(cells)+numNetStims+1) + ax1.set_ylim(-1, len(cells)+numNetStims+1) # Add legend if popRates: @@ -773,7 +788,7 @@ def plotRaster (include = ['allCells'], timeRange = None, maxSpikes = 1e8, order maxLabelLen = max([len(l) for l in popLabels]) rightOffset = 0.85 if popRates else 0.9 plt.subplots_adjust(right=(rightOffset-0.012*maxLabelLen)) - + elif labels == 'overlay': ax = plt.gca() tx = 1.01 @@ -794,6 +809,7 @@ def plotRaster (include = ['allCells'], timeRange = None, maxSpikes = 1e8, order maxLabelLen = min(6, max([len(l) for l in labels])) plt.subplots_adjust(right=(0.95-0.011*maxLabelLen)) + # Plot spike hist if spikeHist == 'overlay': ax2 = ax1.twinx() @@ -811,36 +827,36 @@ def plotRaster (include = ['allCells'], timeRange = None, maxSpikes = 1e8, order # save figure data if saveData: - figData = {'spkTimes': spkts, 'spkInds': spkinds, 'spkColors': spkgidColors, 'cellGids': cellGids, 'sortedGids': sortedGids, 'numNetStims': 
numNetStims, + figData = {'spkTimes': sel['spkt'].tolist(), 'spkInds': sel['spkind'].tolist(), 'spkColors': sel['spkgidColor'].tolist(), 'cellGids': cellGids, 'sortedGids': df.index.tolist(), 'numNetStims': numNetStims, 'include': include, 'timeRange': timeRange, 'maxSpikes': maxSpikes, 'orderBy': orderBy, 'orderInverse': orderInverse, 'spikeHist': spikeHist, 'syncLines': syncLines} _saveFigData(figData, saveData, 'raster') - + # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'raster.png' plt.savefig(filename, dpi=dpi) - # show fig + # show fig if showFig: _showFigure() - return fig, {'include': include, 'spkts': spkts, 'spkinds': spkinds, 'timeRange': timeRange} + return fig, {} ###################################################################################################################################################### ## Plot spike histogram ###################################################################################################################################################### @exception -def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize = 5, overlay=True, graphType='line', yaxis = 'rate', - popColors = [], norm = False, dpi = 100, figSize = (10,8), smooth=None, filtFreq = False, filtOrder=3, axis = 'on', saveData = None, - saveFig = None, showFig = True, **kwargs): - ''' +def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize = 5, overlay=True, graphType='line', yaxis = 'rate', + popColors = [], norm = False, dpi = 100, figSize = (10,8), smooth=None, filtFreq = False, filtOrder=3, axis = 'on', saveData = None, + saveFig = None, showFig = True, **kwargs): + ''' Plot spike histogram - - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. 
+ - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. Note: one line per item, not grouped (default: ['allCells', 'eachPop']) - timeRange ([start:stop]): Time range of spikes shown; if None shows all (default: None) - binSize (int): Size in ms of each bin (default: 5) @@ -863,17 +879,17 @@ def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize print('Plotting spike histogram...') # Replace 'eachPop' with list of pops - if 'eachPop' in include: + if 'eachPop' in include: include.remove('eachPop') for pop in sim.net.allPops: include.append(pop) # Y-axis label - if yaxis == 'rate': + if yaxis == 'rate': if norm: yaxisLabel = 'Normalized firing rate' else: yaxisLabel = 'Avg cell firing rate (Hz)' - elif yaxis == 'count': + elif yaxis == 'count': if norm: yaxisLabel = 'Normalized spike count' else: @@ -892,7 +908,7 @@ def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize # create fig fig,ax1 = plt.subplots(figsize=figSize) fontsiz = 12 - + # Plot separate line for each entry in include for iplot,subset in enumerate(include): cells, cellGids, netStimLabels = getCellsInclude([subset]) @@ -904,7 +920,7 @@ def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize spkinds,spkts = list(zip(*[(spkgid,spkt) for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids])) except: spkinds,spkts = [],[] - else: + else: spkinds,spkts = [],[] # Add NetStim spikes @@ -916,7 +932,7 @@ def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize for stimLabel,stimSpks in stims.items() for spk in stimSpks if stimLabel == netStimLabel] if len(netStimSpks) > 0: lastInd = max(spkinds) if len(spkinds)>0 else 0 - spktsNew = netStimSpks + spktsNew = netStimSpks spkindsNew = [lastInd+1+i for i in range(len(netStimSpks))] spkts.extend(spktsNew) spkinds.extend(spkindsNew) @@ -924,15 +940,15 @@ def 
plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize histo = np.histogram(spkts, bins = np.arange(timeRange[0], timeRange[1], binSize)) histoT = histo[1][:-1]+binSize/2 - histoCount = histo[0] + histoCount = histo[0] - if yaxis=='rate': + if yaxis=='rate': histoCount = histoCount * (1000.0 / binSize) / (len(cellGids)+numNetStims) # convert to firing rate if filtFreq: from scipy import signal fs = 1000.0/binSize - nyquist = fs/2.0 + nyquist = fs/2.0 if isinstance(filtFreq, list): # bandpass Wn = [filtFreq[0]/nyquist, filtFreq[1]/nyquist] b, a = signal.butter(filtOrder, Wn, btype='bandpass') @@ -946,23 +962,23 @@ def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize if smooth: histoCount = _smooth1d(histoCount, smooth)[:len(histoT)] - + histData.append(histoCount) - color = popColors[subset] if subset in popColors else colorList[iplot%len(colorList)] + color = popColors[subset] if subset in popColors else colorList[iplot%len(colorList)] - if not overlay: + if not overlay: plt.subplot(len(include),1,iplot+1) # if subplot, create new subplot plt.title (str(subset), fontsize=fontsiz) color = 'blue' - + if graphType == 'line': plt.plot (histoT, histoCount, linewidth=1.0, color = color) elif graphType == 'bar': #plt.bar(histoT, histoCount, width = binSize, color = color, fill=False) plt.plot (histoT, histoCount, linewidth=1.0, color = color, ls='steps') - if iplot == 0: + if iplot == 0: plt.xlabel('Time (ms)', fontsize=fontsiz) plt.ylabel(yaxisLabel, fontsize=fontsiz) # add yaxis in opposite side plt.xlim(timeRange) @@ -976,7 +992,7 @@ def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize # Add legend if overlay: for i,subset in enumerate(include): - color = popColors[subset] if subset in popColors else colorList[i%len(colorList)] + color = popColors[subset] if subset in popColors else colorList[i%len(colorList)] plt.plot(0,0,color=color,label=str(subset)) plt.legend(fontsize=fontsiz, 
bbox_to_anchor=(1.04, 1), loc=2, borderaxespad=0.) maxLabelLen = min(10,max([len(str(l)) for l in include])) @@ -986,31 +1002,31 @@ def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize if axis == 'off': ax = plt.gca() scalebarLoc = kwargs.get('scalebarLoc', 7) - round_to_n = lambda x, n, m: int(np.round(round(x, -int(np.floor(np.log10(abs(x)))) + (n - 1)) / m)) * m + round_to_n = lambda x, n, m: int(np.round(round(x, -int(np.floor(np.log10(abs(x)))) + (n - 1)) / m)) * m sizex = round_to_n((timeRange[1]-timeRange[0])/10.0, 1, 50) - add_scalebar(ax, hidex=False, hidey=True, matchx=False, matchy=True, sizex=sizex, sizey=None, - unitsx='ms', unitsy='Hz', scalex=1, scaley=1, loc=scalebarLoc, pad=2, borderpad=0.5, sep=4, prop=None, barcolor="black", barwidth=3) + add_scalebar(ax, hidex=False, hidey=True, matchx=False, matchy=True, sizex=sizex, sizey=None, + unitsx='ms', unitsy='Hz', scalex=1, scaley=1, loc=scalebarLoc, pad=2, borderpad=0.5, sep=4, prop=None, barcolor="black", barwidth=3) plt.axis(axis) # save figure data if saveData: figData = {'histData': histData, 'histT': histoT, 'include': include, 'timeRange': timeRange, 'binSize': binSize, 'saveData': saveData, 'saveFig': saveFig, 'showFig': showFig} - + _saveFigData(figData, saveData, 'spikeHist') - + # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'spikeHist.png' plt.savefig(filename, dpi=dpi) - # show fig + # show fig if showFig: _showFigure() - return fig, {'include': include, 'histData': histData, 'histoT': histoT, 'timeRange': timeRange} + return fig, {'histData': histData, 'histoT': histoT} @@ -1020,10 +1036,10 @@ def plotSpikeHist (include = ['allCells', 'eachPop'], timeRange = None, binSize #@exception def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRange = None, graphType='boxplot', stats = ['rate', 'isicv'], bins = 50, popColors = [], histlogy = False, histlogx = False, 
histmin = 0.0, density = False, includeRate0=False, legendLabels = None, normfit = False, - fontsize=14, histShading=True, xlim = None, dpi = 100, figSize = (6,8), saveData = None, saveFig = None, showFig = True, **kwargs): - ''' + fontsize=14, histShading=True, xlim = None, dpi = 100, figSize = (6,8), saveData = None, saveFig = None, showFig = True, **kwargs): + ''' Plot spike histogram - - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. + - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. Note: one line per item, not grouped (default: ['allCells', 'eachPop']) - timeRange ([start:stop]): Time range of spikes shown; if None shows all (default: None) - graphType ('boxplot', 'histogram'): Type of graph to use (default: 'boxplot') @@ -1059,7 +1075,7 @@ def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang xlabels = {'rate': 'Rate (Hz)', 'isicv': 'Irregularity (ISI CV)', 'sync': 'Synchrony', ' pairsync': 'Pairwise synchrony'} # Replace 'eachPop' with list of pops - if 'eachPop' in include: + if 'eachPop' in include: include.remove('eachPop') for pop in sim.net.allPops: include.append(pop) @@ -1070,7 +1086,7 @@ def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang for stat in stats: # create fig fig,ax1 = plt.subplots(figsize=figSize) - fontsiz = fontsize + fontsiz = fontsize xlabel = xlabels[stat] statData = [] @@ -1092,11 +1108,11 @@ def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang # Select cells to include if len(cellGids) > 0: try: - spkinds,spkts = list(zip(*[(spkgid,spkt) for spkgid,spkt in + spkinds,spkts = list(zip(*[(spkgid,spkt) for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids])) except: spkinds,spkts = [],[] - else: + else: spkinds,spkts = [],[] # Add NetStim spikes @@ -1105,23 +1121,23 @@ def 
plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang if 'stims' in sim.allSimData: for netStimLabel in netStimLabels: netStimSpks = [spk for cell,stims in sim.allSimData['stims'].items() \ - for stimLabel,stimSpks in stims.items() + for stimLabel,stimSpks in stims.items() for spk in stimSpks if stimLabel == netStimLabel] if len(netStimSpks) > 0: lastInd = max(spkinds) if len(spkinds)>0 else 0 - spktsNew = netStimSpks + spktsNew = netStimSpks spkindsNew = [lastInd+1+i for i in range(len(netStimSpks))] spkts.extend(spktsNew) spkinds.extend(spkindsNew) numNetStims += 1 try: - spkts,spkinds = list(zip(*[(spkt, spkind) for spkt, spkind in zip(spkts, spkinds) + spkts,spkinds = list(zip(*[(spkt, spkind) for spkt, spkind in zip(spkts, spkinds) if timeRange[0] <= spkt <= timeRange[1]])) except: pass # if scatter get gids and ynorm - if graphType == 'scatter': + if graphType == 'scatter': if includeRate0: gids = cellGids else: @@ -1136,10 +1152,10 @@ def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang toRate = 1e3/(timeRange[1]-timeRange[0]) if includeRate0: rates = [spkinds.count(gid)*toRate for gid in cellGids] \ - if len(spkinds)>0 else [0]*len(cellGids) #cellGids] #set(spkinds)] + if len(spkinds)>0 else [0]*len(cellGids) #cellGids] #set(spkinds)] else: rates = [spkinds.count(gid)*toRate for gid in set(spkinds)] \ - if len(spkinds)>0 else [0] #cellGids] #set(spkinds)] + if len(spkinds)>0 else [0] #cellGids] #set(spkinds)] statData.insert(0, rates) @@ -1147,23 +1163,23 @@ def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang # Inter-spike interval (ISI) coefficient of variation (CV) stats elif stat == 'isicv': import numpy as np - spkmat = [[spkt for spkind,spkt in zip(spkinds,spkts) if spkind==gid] + spkmat = [[spkt for spkind,spkt in zip(spkinds,spkts) if spkind==gid] for gid in set(spkinds)] isimat = [[t - s for s, t in zip(spks, spks[1:])] for spks in spkmat if len(spks)>10] - isicv = 
[np.std(x) / np.mean(x) if len(x)>0 else 0 for x in isimat] # if len(x)>0] - statData.insert(0, isicv) + isicv = [np.std(x) / np.mean(x) if len(x)>0 else 0 for x in isimat] # if len(x)>0] + statData.insert(0, isicv) # synchrony elif stat in ['sync', 'pairsync']: - try: - import pyspike + try: + import pyspike except: print("Error: plotSpikeStats() requires the PySpike python package \ to calculate synchrony (try: pip install pyspike)") return 0 - - spkmat = [pyspike.SpikeTrain([spkt for spkind,spkt in zip(spkinds,spkts) + + spkmat = [pyspike.SpikeTrain([spkt for spkind,spkt in zip(spkinds,spkts) if spkind==gid], timeRange) for gid in set(spkinds)] if stat == 'sync': # (SPIKE-Sync measure)' # see http://www.scholarpedia.org/article/Measures_of_spike_train_synchrony @@ -1172,27 +1188,27 @@ def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang elif stat == 'pairsync': # (SPIKE-Sync measure)' # see http://www.scholarpedia.org/article/Measures_of_spike_train_synchrony syncMat = np.mean(pyspike.spike_sync_matrix(spkmat), 0) - + statData.insert(0, syncMat) - colors.insert(0, popColors[subset] if subset in popColors + colors.insert(0, popColors[subset] if subset in popColors else colorList[iplot%len(colorList)]) # colors in inverse order # if 'allCells' included make it black if include[0] == 'allCells': #if graphType == 'boxplot': - colors.insert(len(include), (0.5,0.5,0.5)) # + colors.insert(len(include), (0.5,0.5,0.5)) # del colors[0] # boxplot if graphType == 'boxplot': meanpointprops = dict(marker=(5,1,0), markeredgecolor='black', markerfacecolor='white') labels = legendLabels if legendLabels else include - bp=plt.boxplot(statData, labels=labels, notch=False, sym='k+', meanprops=meanpointprops, + bp=plt.boxplot(statData, labels=labels, notch=False, sym='k+', meanprops=meanpointprops, whis=1.5, widths=0.6, vert=False, showmeans=True, patch_artist=True) plt.xlabel(xlabel, fontsize=fontsiz) - plt.ylabel('Population', fontsize=fontsiz) + 
plt.ylabel('Population', fontsize=fontsiz) icolor=0 borderColor = 'k' @@ -1234,27 +1250,27 @@ def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang nmax = 0 pdfmax = 0 binmax = 0 - for i,data in enumerate(statData): # fix + for i,data in enumerate(statData): # fix if histlogx: histbins = np.logspace(np.log10(histmin), np.log10(max(data)), bins) else: histbins = bins - if histmin: # min value + if histmin: # min value data = np.array(data) data = data[data>histmin] - + # if histlogy: # data = [np.log10(x) for x in data] if density: - weights = np.ones_like(data)/float(len(data)) - else: + weights = np.ones_like(data)/float(len(data)) + else: weights = np.ones_like(data) n, binedges,_ = plt.hist(data, bins=histbins, histtype='step', color=colors[i], linewidth=2, weights=weights)#, normed=1)#, normed=density)# weights=weights) if histShading: - plt.hist(data, bins=histbins, alpha=0.05, color=colors[i], linewidth=0, weights=weights) + plt.hist(data, bins=histbins, alpha=0.05, color=colors[i], linewidth=0, weights=weights) label = legendLabels[-i-1] if legendLabels else str(include[-i-1]) if histShading: plt.hist([-10], bins=histbins, fc=((colors[i][0], colors[i][1], colors[i][2],0.05)), edgecolor=colors[i], linewidth=2, label=label) @@ -1262,16 +1278,16 @@ def plotSpikeStats (include = ['allCells', 'eachPop'], statDataIn = {}, timeRang plt.hist([-10], bins=histbins, fc=((1,1,1),0), edgecolor=colors[i], linewidth=2, label=label) nmax = max(nmax, max(n)) binmax = max(binmax, binedges[-1]) - if histlogx: + if histlogx: plt.xscale('log') if normfit: def lognorm(meaninput, stdinput, binedges, n, popLabel, color): - from scipy import stats + from scipy import stats M = float(meaninput) # Geometric mean == median s = float(stdinput) # Geometric standard deviation mu = np.log10(M) # Mean of log(X) - sigma = np.log10(s) # Standard deviation of log(X) + sigma = np.log10(s) # Standard deviation of log(X) shape = sigma # Scipy's shape parameter scale = 
np.power(10, mu) # Scipy's scale parameter x = [(binedges[i]+binedges[i+1])/2.0 for i in range(len(binedges)-1)] #np.linspace(histmin, 30, num=400) # values for x-axis @@ -1314,11 +1330,11 @@ def lognorm(meaninput, stdinput, binedges, n, popLabel, color): median, binedges, _ = stats.binned_statistic(ynorms, data, 'median', bins=bins) #p25 = lambda x: np.percentile(x, 25) #p75 = lambda x: np.percentile(x, 75) - + std, binedges, _ = stats.binned_statistic(ynorms, data, 'std', bins=bins) #per25, binedges, _ = stats.binned_statistic(ynorms, data, p25, bins=bins) #per75, binedges, _ = stats.binned_statistic(ynorms, data, p75, bins=bins) - + label = legendLabels[-i-1] if legendLabels else str(include[-i-1]) if kwargs.get('differentColor', None): threshold = kwargs['differentColor'][0] @@ -1328,7 +1344,7 @@ def lognorm(meaninput, stdinput, binedges, n, popLabel, color): else: plt.scatter(ynorms, data, color=[0/255.0,215/255.0,255/255.0], label=label, s=2) #[88/255.0,204/255.0,20/255.0] binstep = binedges[1]-binedges[0] - bincenters = [b+binstep/2 for b in binedges[:-1]] + bincenters = [b+binstep/2 for b in binedges[:-1]] plt.errorbar(bincenters, mean, yerr=std, color=[6/255.0,70/255.0,163/255.0], fmt = 'o-',capthick=1, capsize=5) #[44/255.0,53/255.0,127/255.0] #plt.errorbar(bincenters, mean, yerr=[mean-per25,per75-mean], fmt='go-',capthick=1, capsize=5) ylims=plt.ylim() @@ -1342,7 +1358,7 @@ def lognorm(meaninput, stdinput, binedges, n, popLabel, color): # elif graphType == 'bar': # print range(1, len(statData)+1), statData - # plt.bar(range(1, len(statData)+1), statData, tick_label=include[::-1], + # plt.bar(range(1, len(statData)+1), statData, tick_label=include[::-1], # orientation='horizontal', colors=colors) try: @@ -1357,17 +1373,17 @@ def lognorm(meaninput, stdinput, binedges, n, popLabel, color): _saveFigData(figData, saveData, 'spikeStats_'+stat) # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = 
saveFig+'_'+'spikeStat_'+graphType+'_'+stat+'.png' else: filename = sim.cfg.filename+'_'+'spikeStat_'+graphType+'_'+stat+'.png' plt.savefig(filename, dpi=dpi) - # show fig + # show fig if showFig: _showFigure() - return fig, {'include': include, 'statData': statData, 'gidsData':gidsData, 'ynormsData':ynormsData} + return fig, {'statData': statData, 'gidsData':gidsData, 'ynormsData':ynormsData} @@ -1375,11 +1391,11 @@ def lognorm(meaninput, stdinput, binedges, n, popLabel, color): ## Plot spike histogram ###################################################################################################################################################### @exception -def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = 5, maxFreq = 100, NFFT = 256, noverlap = 128, smooth = 0, overlay=True, ylim = None, - popColors = {}, figSize = (10,8), saveData = None, saveFig = None, showFig = True): - ''' +def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = 5, maxFreq = 100, NFFT = 256, noverlap = 128, smooth = 0, overlay=True, ylim = None, + popColors = {}, figSize = (10,8), saveData = None, saveFig = None, showFig = True): + ''' Plot firing rate power spectral density (PSD) - - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. + - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of data series to include. Note: one line per item, not grouped (default: ['allCells', 'eachPop']) - timeRange ([start:stop]): Time range of spikes shown; if None shows all (default: None) - binSize (int): Size in ms of spike bins (default: 5) @@ -1403,9 +1419,9 @@ def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = from . 
import sim print('Plotting firing rate power spectral density (PSD) ...') - + # Replace 'eachPop' with list of pops - if 'eachPop' in include: + if 'eachPop' in include: include.remove('eachPop') for pop in sim.net.allPops: include.append(pop) @@ -1418,11 +1434,11 @@ def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = # create fig fig,ax1 = plt.subplots(figsize=figSize) fontsiz = 12 - + allPower, allSignal, allFreqs=[], [], [] # Plot separate line for each entry in include for iplot,subset in enumerate(include): - cells, cellGids, netStimLabels = getCellsInclude([subset]) + cells, cellGids, netStimLabels = getCellsInclude([subset]) numNetStims = 0 # Select cells to include @@ -1431,7 +1447,7 @@ def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = spkinds,spkts = list(zip(*[(spkgid,spkt) for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids])) except: spkinds,spkts = [],[] - else: + else: spkinds,spkts = [],[] @@ -1444,7 +1460,7 @@ def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = for stimLabel,stimSpks in stims.items() for spk in stimSpks if stimLabel == netStimLabel] if len(netStimSpks) > 0: lastInd = max(spkinds) if len(spkinds)>0 else 0 - spktsNew = netStimSpks + spktsNew = netStimSpks spkindsNew = [lastInd+1+i for i in range(len(netStimSpks))] spkts.extend(spktsNew) spkinds.extend(spkindsNew) @@ -1452,20 +1468,20 @@ def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = histo = np.histogram(spkts, bins = np.arange(timeRange[0], timeRange[1], binSize)) histoT = histo[1][:-1]+binSize/2 - histoCount = histo[0] + histoCount = histo[0] histoCount = histoCount * (1000.0 / binSize) / (len(cellGids)+numNetStims) # convert to rates histData.append(histoCount) - color = popColors[subset] if isinstance(subset, (str, tuple)) and subset in popColors else colorList[iplot%len(colorList)] + color = popColors[subset] if 
isinstance(subset, (str, tuple)) and subset in popColors else colorList[iplot%len(colorList)] - if not overlay: + if not overlay: plt.subplot(len(include),1,iplot+1) # if subplot, create new subplot title (str(subset), fontsize=fontsiz) color = 'blue' - + Fs = 1000.0/binSize # ACTUALLY DEPENDS ON BIN WINDOW!!! RATE NOT SPIKE! - power = mlab.psd(histoCount, Fs=Fs, NFFT=NFFT, detrend=mlab.detrend_none, window=mlab.window_hanning, + power = mlab.psd(histoCount, Fs=Fs, NFFT=NFFT, detrend=mlab.detrend_none, window=mlab.window_hanning, noverlap=noverlap, pad_to=None, sides='default', scale_by_freq=None) if smooth: @@ -1494,7 +1510,7 @@ def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = # Add legend if overlay: for i,subset in enumerate(include): - color = popColors[subset] if isinstance(subset, str) and subset in popColors else colorList[i%len(colorList)] + color = popColors[subset] if isinstance(subset, str) and subset in popColors else colorList[i%len(colorList)] plt.plot(0,0,color=color,label=str(subset)) plt.legend(fontsize=fontsiz, loc=1)#, bbox_to_anchor=(1.04, 1), loc=2, borderaxespad=0.) 
maxLabelLen = min(10,max([len(str(l)) for l in include])) @@ -1505,21 +1521,21 @@ def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = if saveData: figData = {'histData': histData, 'histT': histoT, 'include': include, 'timeRange': timeRange, 'binSize': binSize, 'saveData': saveData, 'saveFig': saveFig, 'showFig': showFig} - + _saveFigData(figData, saveData, 'spikeHist') - + # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'spikePSD.png' plt.savefig(filename) - # show fig + # show fig if showFig: _showFigure() - return fig, {'allSignal': allSignal, 'allPower': allPower, 'allFreqs':allFreqs} + return fig, {'allSignal':allSignal, 'allPower':allPower, 'allFreqs':allFreqs} @@ -1528,21 +1544,21 @@ def plotRatePSD (include = ['allCells', 'eachPop'], timeRange = None, binSize = ###################################################################################################################################################### @exception def plotTraces (include = None, timeRange = None, overlay = False, oneFigPer = 'cell', rerun = False, colors = None, ylim = None, axis='on', - figSize = (10,8), saveData = None, saveFig = None, showFig = True): - ''' + figSize = (10,8), saveData = None, saveFig = None, showFig = True): + ''' Plot recorded traces - - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of cells for which to plot + - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of cells for which to plot the recorded traces (default: []) - timeRange ([start:stop]): Time range of spikes shown; if None shows all (default: None) - overlay (True|False): Whether to overlay the data lines or plot in separate subplots (default: False) - - oneFigPer ('cell'|'trace'): Whether to plot one figure per cell (showing multiple traces) + - oneFigPer ('cell'|'trace'): Whether to plot one figure per cell 
(showing multiple traces) or per trace (showing multiple cells) (default: 'cell') - rerun (True|False): rerun simulation so new set of cells gets recorded (default: False) - colors (list): List of normalized RGB colors to use for traces - ylim (list): Y-axis limits - axis ('on'|'off'): Whether to show axis or not; if not, then a scalebar is included (default: 'on') - figSize ((width, height)): Size of figure (default: (10,8)) - - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; + - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; if set to True uses filename from simConfig (default: None) - saveFig (None|True|'fileName'): File name where to save the figure; if set to True uses filename from simConfig (default: None) @@ -1559,15 +1575,15 @@ def plotTraces (include = None, timeRange = None, overlay = False, oneFigPer = ' include = sim.cfg.analysis['plotTraces']['include'] + sim.cfg.recordCells else: include = sim.cfg.recordCells - + global colorList - if isinstance(colors, list): + if isinstance(colors, list): colorList2 = colors else: colorList2 = colorList # rerun simulation so new include cells get recorded from - if rerun: + if rerun: cellsRecord = [cell.gid for cell in sim.getCellsList(include)] for cellRecord in cellsRecord: if cellRecord not in sim.cfg.recordCells: @@ -1608,19 +1624,19 @@ def plotFigPerTrace(subGids): plt.xlim(timeRange) if ylim: plt.ylim(ylim) plt.title('Cell %d, Pop %s '%(int(gid), gidPops[gid])) - + if axis == 'off': # if no axis, add scalebar ax = plt.gca() sizex = (timeRange[1]-timeRange[0])/20.0 # yl = plt.ylim() # plt.ylim(yl[0]-0.2*(yl[1]-yl[0]), yl[1]) - add_scalebar(ax, hidex=False, hidey=True, matchx=False, matchy=True, sizex=sizex, sizey=None, - unitsx='ms', unitsy='mV', scalex=1, scaley=1, loc=1, pad=-1, borderpad=0.5, sep=4, prop=None, barcolor="black", barwidth=3) - plt.axis(axis) + add_scalebar(ax, hidex=False, hidey=True, 
matchx=False, matchy=True, sizex=sizex, sizey=None, + unitsx='ms', unitsy='mV', scalex=1, scaley=1, loc=1, pad=-1, borderpad=0.5, sep=4, prop=None, barcolor="black", barwidth=3) + plt.axis(axis) if overlay: #maxLabelLen = 10 - #plt.subplots_adjust(right=(0.9-0.012*maxLabelLen)) + #plt.subplots_adjust(right=(0.9-0.012*maxLabelLen)) #plt.legend(fontsize=fontsiz, bbox_to_anchor=(1.04, 1), loc=2, borderaxespad=0.) plt.legend() @@ -1651,8 +1667,8 @@ def plotFigPerTrace(subGids): plt.xlim(timeRange) if ylim: plt.ylim(ylim) if itrace==0: plt.title('Cell %d, Pop %s '%(int(gid), gidPops[gid])) - - if overlay: + + if overlay: #maxLabelLen = 10 #plt.subplots_adjust(right=(0.9-0.012*maxLabelLen)) plt.legend()#fontsize=fontsiz, bbox_to_anchor=(1.04, 1), loc=2, borderaxespad=0.) @@ -1662,10 +1678,10 @@ def plotFigPerTrace(subGids): sizex = timeRange[1]-timeRange[0]/20 yl = plt.ylim() plt.ylim(yl[0]-0.2*(yl[1]-yl[0]), yl[1]) # leave space for scalebar? - add_scalebar(ax, hidex=False, hidey=True, matchx=False, matchy=True, sizex=sizex, sizey=None, - unitsx='ms', unitsy='mV', scalex=1, scaley=1, loc=4, pad=10, borderpad=0.5, sep=3, prop=None, barcolor="black", barwidth=2) - plt.axis(axis) - + add_scalebar(ax, hidex=False, hidey=True, matchx=False, matchy=True, sizex=sizex, sizey=None, + unitsx='ms', unitsy='mV', scalex=1, scaley=1, loc=4, pad=10, borderpad=0.5, sep=3, prop=None, barcolor="black", barwidth=2) + plt.axis(axis) + # Plot one fig per trace elif oneFigPer == 'trace': plotFigPerTrace(cellGids) @@ -1686,11 +1702,11 @@ def plotFigPerTrace(subGids): if saveData: figData = {'tracesData': tracesData, 'include': include, 'timeRange': timeRange, 'oneFigPer': oneFigPer, 'saveData': saveData, 'saveFig': saveFig, 'showFig': showFig} - + _saveFigData(figData, saveData, 'traces') - + # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: @@ -1702,10 +1718,10 @@ def plotFigPerTrace(subGids): else: plt.savefig(filename) - # show fig + # show fig 
if showFig: _showFigure() - return figs, {'tracesData': tracesData, 'include': include} + return figs, {} def invertDictMapping(d): """ Invert mapping of dictionary (i.e. map values to list of keys) """ @@ -1720,16 +1736,16 @@ def invertDictMapping(d): ## Plot cell shape ###################################################################################################################################################### @exception -def plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, showElectrodes = False, synStyle = '.', synSiz=3, dist=0.6, cvar=None, cvals=None, - iv=False, ivprops=None, includeAxon=True, bkgColor = None, figSize = (10,8), saveData = None, dpi = 300, saveFig = None, showFig = True): - ''' +def plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, showElectrodes = False, synStyle = '.', synSiz=3, dist=0.6, cvar=None, cvals=None, + iv=False, ivprops=None, includeAxon=True, bkgColor = None, figSize = (10,8), saveData = None, dpi = 300, saveFig = None, showFig = True): + ''' Plot 3D cell shape using NEURON Interview PlotShape - - includePre: (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of presynaptic cells to consider + - includePre: (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of presynaptic cells to consider when plotting connections (default: ['all']) - includePost: (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): List of cells to show shape of (default: ['all']) - showSyns (True|False): Show synaptic connections in 3D view (default: False) - showElectrodes (True|False): Show LFP electrodes in 3D view (default: False) - - synStyle: Style of marker to show synapses (default: '.') + - synStyle: Style of marker to show synapses (default: '.') - dist: 3D distance (like zoom) (default: 0.6) - synSize: Size of marker to show synapses (default: 3) - cvar: ('numSyns'|'weightNorm') Variable to represent in 
shape plot (default: None) @@ -1737,9 +1753,9 @@ def plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, sh - iv: Use NEURON Interviews (instead of matplotlib) to show shape plot (default: None) - ivprops: Dict of properties to plot using Interviews (default: None) - includeAxon: Include axon in shape plot (default: True) - - bkgColor (list/tuple with 4 floats): RGBA list/tuple with bakcground color eg. (0.5, 0.2, 0.1, 1.0) (default: None) + - bkgColor (list/tuple with 4 floats): RGBA list/tuple with bakcground color eg. (0.5, 0.2, 0.1, 1.0) (default: None) - figSize ((width, height)): Size of figure (default: (10,8)) - - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; + - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; if set to True uses filename from simConfig (default: None) - saveFig (None|True|'fileName'): File name where to save the figure; if set to True uses filename from simConfig (default: None) @@ -1762,9 +1778,9 @@ def plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, sh if not iv: # plot using Python instead of interviews from mpl_toolkits.mplot3d import Axes3D from netpyne.support import morphology as morph # code adapted from https://github.com/ahwillia/PyNeuron-Toolbox - + # create secList from include - + secs = None # Set cvals and secs @@ -1774,7 +1790,7 @@ def plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, sh # weighNorm if cvar == 'weightNorm': for cellPost in cellsPost: - cellSecs = list(cellPost.secs.values()) if includeAxon else [s for s in list(cellPost.secs.values()) if 'axon' not in s['hSec'].hname()] + cellSecs = list(cellPost.secs.values()) if includeAxon else [s for s in list(cellPost.secs.values()) if 'axon' not in s['hSec'].hname()] for sec in cellSecs: if 'weightNorm' in sec: secs.append(sec['hSec']) @@ -1798,24 +1814,24 @@ def plotShape (includePost = 
['all'], includePre = ['all'], showSyns = False, sh cvals = np.array(cvals) if not secs: secs = [s['hSec'] for cellPost in cellsPost for s in list(cellPost.secs.values())] - if not includeAxon: + if not includeAxon: secs = [sec for sec in secs if 'axon' not in sec.hname()] # Plot shapeplot cbLabels = {'numSyns': 'number of synapses', 'weightNorm': 'weight scaling'} fig=plt.figure(figsize=figSize) shapeax = plt.subplot(111, projection='3d') - shapeax.elev=90 # 90 + shapeax.elev=90 # 90 shapeax.azim=-90 # -90 shapeax.dist=dist*shapeax.dist plt.axis('equal') cmap=plt.cm.jet #plt.cm.rainbow #plt.cm.jet #YlOrBr_r morph.shapeplot(h,shapeax, sections=secs, cvals=cvals, cmap=cmap) fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - if not cvals==None and len(cvals)>0: + if not cvals==None and len(cvals)>0: sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=np.min(cvals), vmax=np.max(cvals))) sm._A = [] # fake up the array of the scalar mappable - cb = plt.colorbar(sm, fraction=0.15, shrink=0.5, pad=0.01, aspect=20) + cb = plt.colorbar(sm, fraction=0.15, shrink=0.5, pad=0.01, aspect=20) if cvar: cb.set_label(cbLabels[cvar], rotation=90) if bkgColor: @@ -1849,14 +1865,14 @@ def plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, sh shapeax.set_xticklabels([]) # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_shape.png' plt.savefig(filename, dpi=dpi) - # show fig + # show fig if showFig: _showFigure() else: # Plot using Interviews @@ -1866,8 +1882,8 @@ def plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, sh secList = h.SectionList() if not ivprops: ivprops = {'colorSecs': 1, 'colorSyns':2 ,'style': 'O', 'siz':5} - - for cell in [c for c in cellsPost]: + + for cell in [c for c in cellsPost]: for sec in list(cell.secs.values()): if 'axon' in sec['hSec'].hname() and not includeAxon: continue sec['hSec'].push() @@ -1881,14 +1897,14 @@ def 
plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, sh # colorsPre[prePop] = colorCounter # find synMech using conn['loc'], conn['sec'] and conn['synMech'] - fig.point_mark(synMech['hSyn'], ivprops['colorSyns'], ivprops['style'], ivprops['siz']) + fig.point_mark(synMech['hSyn'], ivprops['colorSyns'], ivprops['style'], ivprops['siz']) fig.observe(secList) fig.color_list(secList, ivprops['colorSecs']) fig.flush() fig.show(0) # show real diam # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: @@ -1903,13 +1919,13 @@ def plotShape (includePost = ['all'], includePre = ['all'], showSyns = False, sh ## Plot LFP (time-resolved, power spectral density, time-frequency and 3D locations) ###################################################################################################################################################### #@exception -def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectrogram', 'locations'], timeRange = None, NFFT = 256, noverlap = 128, +def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectrogram', 'locations'], timeRange = None, NFFT = 256, noverlap = 128, nperseg = 256, maxFreq = 100, smooth = 0, separation = 1.0, includeAxon=True, logx=False, logy=False, norm=False, dpi = 200, overlay=False, filtFreq = False, filtOrder=3, detrend=False, - colors = None, figSize = (8,8), saveData = None, saveFig = None, showFig = True): - ''' + colors = None, figSize = (8,8), saveData = None, saveFig = None, showFig = True): + ''' Plot LFP - electrodes (list): List of electrodes to include; 'avg'=avg of all electrodes; 'all'=each electrode separately (default: ['avg', 'all']) - - plots (list): list of plot types to show (default: ['timeSeries', 'PSD', 'timeFreq', 'locations']) + - plots (list): list of plot types to show (default: ['timeSeries', 'PSD', 'timeFreq', 'locations']) - timeRange ([start:stop]): Time range of spikes shown; if None 
shows all (default: None) - NFFT (int, power of 2): Number of data points used in each block for the PSD and time-freq FFT (default: 256) - noverlap (int, 1: plt.text(timeRange[0]-0.14*(timeRange[1]-timeRange[0]), (len(electrodes)*ydisp)/2.0, 'LFP electrode', color='k', ha='left', va='bottom', fontsize=fontsiz, rotation=90) plt.ylim(-offset, (len(electrodes))*ydisp) - else: + else: plt.suptitle('LFP Signal', fontsize=fontsiz, fontweight='bold') ax.invert_yaxis() plt.xlabel('time (ms)', fontsize=fontsiz) @@ -2021,7 +2037,7 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro plt.subplots_adjust(bottom=0.1, top=1.0, right=1.0) # calculate scalebar size and add scalebar - round_to_n = lambda x, n, m: int(np.ceil(round(x, -int(np.floor(np.log10(abs(x)))) + (n - 1)) / m)) * m + round_to_n = lambda x, n, m: int(np.ceil(round(x, -int(np.floor(np.log10(abs(x)))) + (n - 1)) / m)) * m scaley = 1000.0 # values in mV but want to convert to uV m = 10.0 sizey = 100/scaley @@ -2033,13 +2049,13 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro m /= 10.0 labely = '%.3g $\mu$V'%(sizey*scaley)#)[1:] if len(electrodes) > 1: - add_scalebar(ax,hidey=True, matchy=False, hidex=False, matchx=False, sizex=0, sizey=-sizey, labely=labely, unitsy='$\mu$V', scaley=scaley, + add_scalebar(ax,hidey=True, matchy=False, hidex=False, matchx=False, sizex=0, sizey=-sizey, labely=labely, unitsy='$\mu$V', scaley=scaley, loc=3, pad=0.5, borderpad=0.5, sep=3, prop=None, barcolor="black", barwidth=2) else: - add_scalebar(ax, hidey=True, matchy=False, hidex=True, matchx=True, sizex=None, sizey=-sizey, labely=labely, unitsy='$\mu$V', scaley=scaley, + add_scalebar(ax, hidey=True, matchy=False, hidex=True, matchx=True, sizex=None, sizey=-sizey, labely=labely, unitsy='$\mu$V', scaley=scaley, unitsx='ms', loc=3, pad=0.5, borderpad=0.5, sep=3, prop=None, barcolor="black", barwidth=2) # save figure - if saveFig: + if saveFig: if isinstance(saveFig, 
str): filename = saveFig else: @@ -2071,9 +2087,9 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro lfpPlot = lfp[:, elec] color = colors[i%len(colors)] lw=1.5 - + Fs = int(1000.0/sim.cfg.recordStep) - power = mlab.psd(lfpPlot, Fs=Fs, NFFT=NFFT, detrend=mlab.detrend_none, window=mlab.window_hanning, + power = mlab.psd(lfpPlot, Fs=Fs, NFFT=NFFT, detrend=mlab.detrend_none, window=mlab.window_hanning, noverlap=noverlap, pad_to=None, sides='default', scale_by_freq=None) if smooth: @@ -2090,7 +2106,7 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro if len(electrodes) > 1 and not overlay: plt.title('Electrode %s'%(str(elec)), fontsize=fontsiz-2) plt.ylabel('dB/Hz', fontsize=fontsiz) - + # ALTERNATIVE PSD CALCULATION USING WELCH # from http://joelyancey.com/lfp-python-practice/ # from scipy import signal as spsig @@ -2118,7 +2134,7 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro #from IPython import embed; embed() # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: @@ -2131,16 +2147,16 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro numCols = np.round(len(electrodes) / maxPlots) + 1 figs.append(plt.figure(figsize=(figSize[0]*numCols, figSize[1]))) #t = np.arange(timeRange[0], timeRange[1], sim.cfg.recordStep) - + from scipy import signal as spsig logx_spec = [] - + for i,elec in enumerate(electrodes): if elec == 'avg': lfpPlot = np.mean(lfp, axis=1) elif isinstance(elec, Number) and elec <= sim.net.recXElectrode.nsites: lfpPlot = lfp[:, elec] - # creates spectrogram over a range of data + # creates spectrogram over a range of data # from: http://joelyancey.com/lfp-python-practice/ fs = int(1000.0/sim.cfg.recordStep) f, t_spec, x_spec = spsig.spectrogram(lfpPlot, fs=fs, window='hanning', @@ -2175,9 +2191,9 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro 
plt.tight_layout() plt.suptitle('LFP spectrogram', size=fontsiz, fontweight='bold') plt.subplots_adjust(bottom=0.08, top=0.90) - + # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: @@ -2195,12 +2211,12 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro for secName, sec in cell.secs.items(): nseg = sec['hSec'].nseg #.geom.nseg if 'axon' in secName: - for j in range(i,i+nseg): del trSegs[j] + for j in range(i,i+nseg): del trSegs[j] i+=nseg - cvals.extend(trSegs) - + cvals.extend(trSegs) + includePost = [c.gid for c in sim.net.compartCells] - fig = sim.analysis.plotShape(includePost=includePost, showElectrodes=electrodes, cvals=cvals, includeAxon=includeAxon, dpi=dpi, saveFig=saveFig, showFig=showFig, figSize=figSize) + fig = sim.analysis.plotShape(includePost=includePost, showElectrodes=electrodes, cvals=cvals, includeAxon=includeAxon, dpi=dpi, saveFig=saveFig, showFig=showFig, figSize=figSize)[0] figs.append(fig) @@ -2208,14 +2224,14 @@ def plotLFP (electrodes = ['avg', 'all'], plots = ['timeSeries', 'PSD', 'spectro if saveData: figData = {'LFP': lfp, 'electrodes': electrodes, 'timeRange': timeRange, 'saveData': saveData, 'saveFig': saveFig, 'showFig': showFig} - + _saveFigData(figData, saveData, 'lfp') - # show fig + # show fig if showFig: _showFigure() - return figs, {'LFP': lfp, 'electrodes': electrodes, 'saveData': saveData} + return figs, data ###################################################################################################################################################### ## Support function for plotConn() - calculate conn using data from sim object @@ -2231,7 +2247,7 @@ def list_of_dict_unique_by_key(seq, key): return [x for x in seq if x[key] not in seen and not seen_add(x[key])] # adapt indices/keys based on compact vs long conn format - if sim.cfg.compactConnFormat: + if sim.cfg.compactConnFormat: connsFormat = sim.cfg.compactConnFormat # set indices of fields 
to read compact format (no keys) @@ -2241,14 +2257,14 @@ def list_of_dict_unique_by_key(seq, key): weightIndex = connsFormat.index('weight') if 'weight' in connsFormat else missing.append('weight') delayIndex = connsFormat.index('delay') if 'delay' in connsFormat else missing.append('delay') preLabelIndex = connsFormat.index('preLabel') if 'preLabel' in connsFormat else -1 - + if len(missing) > 0: print(" Error: cfg.compactConnFormat missing:") print(missing) - return None, None, None - else: + return None, None, None + else: # using long conn format (dict) - preGidIndex = 'preGid' + preGidIndex = 'preGid' synMechIndex = 'synMech' weightIndex = 'weight' delayIndex = 'delay' @@ -2257,18 +2273,18 @@ def list_of_dict_unique_by_key(seq, key): # Calculate pre and post cells involved cellsPre, cellGidsPre, netStimPopsPre = getCellsInclude(includePre) if includePre == includePost: - cellsPost, cellGidsPost, netStimPopsPost = cellsPre, cellGidsPre, netStimPopsPre + cellsPost, cellGidsPost, netStimPopsPost = cellsPre, cellGidsPre, netStimPopsPre else: - cellsPost, cellGidsPost, netStimPopsPost = getCellsInclude(includePost) + cellsPost, cellGidsPost, netStimPopsPost = getCellsInclude(includePost) if isinstance(synMech, str): synMech = [synMech] # make sure synMech is a list - + # Calculate matrix if grouped by cell - if groupBy == 'cell': - if feature in ['weight', 'delay', 'numConns']: + if groupBy == 'cell': + if feature in ['weight', 'delay', 'numConns']: connMatrix = np.zeros((len(cellGidsPre), len(cellGidsPost))) countMatrix = np.zeros((len(cellGidsPre), len(cellGidsPost))) - else: + else: print('Conn matrix with groupBy="cell" only supports features= "weight", "delay" or "numConns"') return fig cellIndsPre = {cell['gid']: ind for ind,cell in enumerate(cellsPre)} @@ -2278,10 +2294,10 @@ def list_of_dict_unique_by_key(seq, key): if len(cellsPre) > 0 and len(cellsPost) > 0: if orderBy not in cellsPre[0]['tags'] or orderBy not in cellsPost[0]['tags']: # if orderBy 
property doesn't exist or is not numeric, use gid orderBy = 'gid' - elif not isinstance(cellsPre[0]['tags'][orderBy], Number) or not isinstance(cellsPost[0]['tags'][orderBy], Number): - orderBy = 'gid' - - if orderBy == 'gid': + elif not isinstance(cellsPre[0]['tags'][orderBy], Number) or not isinstance(cellsPost[0]['tags'][orderBy], Number): + orderBy = 'gid' + + if orderBy == 'gid': yorderPre = [cell[orderBy] for cell in cellsPre] yorderPost = [cell[orderBy] for cell in cellsPost] else: @@ -2302,7 +2318,7 @@ def list_of_dict_unique_by_key(seq, key): for cell in cellsPost: # for each postsyn cell if synOrConn=='syn': - cellConns = cell['conns'] # include all synapses + cellConns = cell['conns'] # include all synapses else: cellConns = list_of_dict_unique_by_key(cell['conns'], preGidIndex) @@ -2317,14 +2333,14 @@ def list_of_dict_unique_by_key(seq, key): connMatrix[cellIndsPre[conn[preGidIndex]], cellIndsPost[cell['gid']]] += conn[featureIndex] countMatrix[cellIndsPre[conn[preGidIndex]], cellIndsPost[cell['gid']]] += 1 - if feature in ['weight', 'delay']: connMatrix = connMatrix / countMatrix - elif feature in ['numConns']: connMatrix = countMatrix + if feature in ['weight', 'delay']: connMatrix = connMatrix / countMatrix + elif feature in ['numConns']: connMatrix = countMatrix - pre, post = cellsPre, cellsPost + pre, post = cellsPre, cellsPost # Calculate matrix if grouped by pop - elif groupBy == 'pop': - + elif groupBy == 'pop': + # get list of pops popsTempPre = list(set([cell['tags']['pop'] for cell in cellsPre])) popsPre = [pop for pop in sim.net.allPops if pop in popsTempPre]+netStimPopsPre @@ -2337,14 +2353,14 @@ def list_of_dict_unique_by_key(seq, key): popsTempPost = list(set([cell['tags']['pop'] for cell in cellsPost])) popsPost = [pop for pop in sim.net.allPops if pop in popsTempPost]+netStimPopsPost popIndsPost = {pop: ind for ind,pop in enumerate(popsPost)} - + # initialize matrices - if feature in ['weight', 'strength']: + if feature in ['weight', 
'strength']: weightMatrix = np.zeros((len(popsPre), len(popsPost))) - elif feature == 'delay': + elif feature == 'delay': delayMatrix = np.zeros((len(popsPre), len(popsPost))) countMatrix = np.zeros((len(popsPre), len(popsPost))) - + # calculate max num conns per pre and post pair of pops numCellsPopPre = {} for pop in popsPre: @@ -2367,17 +2383,17 @@ def list_of_dict_unique_by_key(seq, key): if feature == 'convergence': maxPostConnMatrix = np.zeros((len(popsPre), len(popsPost))) if feature == 'divergence': maxPreConnMatrix = np.zeros((len(popsPre), len(popsPost))) for prePop in popsPre: - for postPop in popsPost: + for postPop in popsPost: if numCellsPopPre[prePop] == -1: numCellsPopPre[prePop] = numCellsPopPost[postPop] maxConnMatrix[popIndsPre[prePop], popIndsPost[postPop]] = numCellsPopPre[prePop]*numCellsPopPost[postPop] if feature == 'convergence': maxPostConnMatrix[popIndsPre[prePop], popIndsPost[postPop]] = numCellsPopPost[postPop] if feature == 'divergence': maxPreConnMatrix[popIndsPre[prePop], popIndsPost[postPop]] = numCellsPopPre[prePop] - + # Calculate conn matrix for cell in cellsPost: # for each postsyn cell if synOrConn=='syn': - cellConns = cell['conns'] # include all synapses + cellConns = cell['conns'] # include all synapses else: cellConns = list_of_dict_unique_by_key(cell['conns'], preGidIndex) @@ -2390,16 +2406,16 @@ def list_of_dict_unique_by_key(seq, key): else: preCell = next((cell for cell in cellsPre if cell['gid']==conn[preGidIndex]), None) prePopLabel = preCell['tags']['pop'] if preCell else None - + if prePopLabel in popIndsPre: - if feature in ['weight', 'strength']: + if feature in ['weight', 'strength']: weightMatrix[popIndsPre[prePopLabel], popIndsPost[cell['tags']['pop']]] += conn[weightIndex] - elif feature == 'delay': - delayMatrix[popIndsPre[prePopLabel], popIndsPost[cell['tags']['pop']]] += conn[delayIndex] - countMatrix[popIndsPre[prePopLabel], popIndsPost[cell['tags']['pop']]] += 1 + elif feature == 'delay': + 
delayMatrix[popIndsPre[prePopLabel], popIndsPost[cell['tags']['pop']]] += conn[delayIndex] + countMatrix[popIndsPre[prePopLabel], popIndsPost[cell['tags']['pop']]] += 1 + + pre, post = popsPre, popsPost - pre, post = popsPre, popsPost - # Calculate matrix if grouped by numeric tag (eg. 'y') elif groupBy in sim.net.allCells[0]['tags'] and isinstance(sim.net.allCells[0]['tags'][groupBy], Number): if not isinstance(groupByIntervalPre, Number) or not isinstance(groupByIntervalPost, Number): @@ -2409,33 +2425,33 @@ def list_of_dict_unique_by_key(seq, key): # group cells by 'groupBy' feature (eg. 'y') in intervals of 'groupByInterval') cellValuesPre = [cell['tags'][groupBy] for cell in cellsPre] minValuePre = _roundFigures(groupByIntervalPre * np.floor(min(cellValuesPre) / groupByIntervalPre), 3) - maxValuePre = _roundFigures(groupByIntervalPre * np.ceil(max(cellValuesPre) / groupByIntervalPre), 3) + maxValuePre = _roundFigures(groupByIntervalPre * np.ceil(max(cellValuesPre) / groupByIntervalPre), 3) groupsPre = np.arange(minValuePre, maxValuePre, groupByIntervalPre) groupsPre = [_roundFigures(x,3) for x in groupsPre] if includePre == includePost: - groupsPost = groupsPre + groupsPost = groupsPre else: cellValuesPost = [cell['tags'][groupBy] for cell in cellsPost] minValuePost = _roundFigures(groupByIntervalPost * np.floor(min(cellValuesPost) / groupByIntervalPost), 3) - maxValuePost = _roundFigures(groupByIntervalPost * np.ceil(max(cellValuesPost) / groupByIntervalPost), 3) + maxValuePost = _roundFigures(groupByIntervalPost * np.ceil(max(cellValuesPost) / groupByIntervalPost), 3) groupsPost = np.arange(minValuePost, maxValuePost, groupByIntervalPost) groupsPost = [_roundFigures(x,3) for x in groupsPost] # only allow matrix sizes >= 2x2 [why?] 
- # if len(groupsPre) < 2 or len(groupsPost) < 2: + # if len(groupsPre) < 2 or len(groupsPost) < 2: # print 'groupBy %s with groupByIntervalPre %s and groupByIntervalPost %s results in <2 groups'%(str(groupBy), str(groupByIntervalPre), str(groupByIntervalPre)) # return # set indices for pre and post groups groupIndsPre = {group: ind for ind,group in enumerate(groupsPre)} groupIndsPost = {group: ind for ind,group in enumerate(groupsPost)} - + # initialize matrices - if feature in ['weight', 'strength']: + if feature in ['weight', 'strength']: weightMatrix = np.zeros((len(groupsPre), len(groupsPost))) - elif feature == 'delay': + elif feature == 'delay': delayMatrix = np.zeros((len(groupsPre), len(groupsPost))) countMatrix = np.zeros((len(groupsPre), len(groupsPost))) @@ -2443,9 +2459,9 @@ def list_of_dict_unique_by_key(seq, key): numCellsGroupPre = {} for groupPre in groupsPre: numCellsGroupPre[groupPre] = len([cell for cell in cellsPre if groupPre <= cell['tags'][groupBy] < (groupPre+groupByIntervalPre)]) - + if includePre == includePost: - numCellsGroupPost = numCellsGroupPre + numCellsGroupPost = numCellsGroupPre else: numCellsGroupPost = {} for groupPost in groupsPost: @@ -2456,16 +2472,16 @@ def list_of_dict_unique_by_key(seq, key): if feature == 'convergence': maxPostConnMatrix = np.zeros((len(groupsPre), len(groupsPost))) if feature == 'divergence': maxPreConnMatrix = np.zeros((len(groupsPre), len(groupsPost))) for preGroup in groupsPre: - for postGroup in groupsPost: + for postGroup in groupsPost: if numCellsGroupPre[preGroup] == -1: numCellsGroupPre[preGroup] = numCellsGroupPost[postGroup] maxConnMatrix[groupIndsPre[preGroup], groupIndsPost[postGroup]] = numCellsGroupPre[preGroup]*numCellsGroupPost[postGroup] if feature == 'convergence': maxPostConnMatrix[groupIndsPre[preGroup], groupIndsPost[postGroup]] = numCellsGroupPost[postGroup] if feature == 'divergence': maxPreConnMatrix[groupIndsPre[preGroup], groupIndsPost[postGroup]] = numCellsGroupPre[preGroup] 
- + # Calculate conn matrix for cell in cellsPost: # for each postsyn cell if synOrConn=='syn': - cellConns = cell['conns'] # include all synapses + cellConns = cell['conns'] # include all synapses else: cellConns = list_of_dict_unique_by_key(cell['conns'], preGidIndex) @@ -2484,25 +2500,25 @@ def list_of_dict_unique_by_key(seq, key): postGroup = _roundFigures(groupByIntervalPost * np.floor(cell['tags'][groupBy] / groupByIntervalPost), 3) if preGroup in groupIndsPre: - if feature in ['weight', 'strength']: + if feature in ['weight', 'strength']: weightMatrix[groupIndsPre[preGroup], groupIndsPost[postGroup]] += conn[weightIndex] - elif feature == 'delay': - delayMatrix[groupIndsPre[preGroup], groupIndsPost[postGroup]] += conn[delayIndex] - countMatrix[groupIndsPre[preGroup], groupIndsPost[postGroup]] += 1 + elif feature == 'delay': + delayMatrix[groupIndsPre[preGroup], groupIndsPost[postGroup]] += conn[delayIndex] + countMatrix[groupIndsPre[preGroup], groupIndsPost[postGroup]] += 1 - - pre, post = groupsPre, groupsPost + + pre, post = groupsPre, groupsPost # no valid groupBy - else: + else: print('groupBy (%s) is not valid'%(str(groupBy))) return # normalize by number of postsyn cells if groupBy != 'cell': - if feature == 'weight': - connMatrix = weightMatrix / countMatrix # avg weight per conn (fix to remove divide by zero warning) - elif feature == 'delay': + if feature == 'weight': + connMatrix = weightMatrix / countMatrix # avg weight per conn (fix to remove divide by zero warning) + elif feature == 'delay': connMatrix = delayMatrix / countMatrix elif feature == 'numConns': connMatrix = countMatrix @@ -2525,10 +2541,10 @@ def list_of_dict_unique_by_key(seq, key): ###################################################################################################################################################### def __plotConnCalculateFromFile__(includePre, includePost, feature, orderBy, groupBy, groupByIntervalPre, groupByIntervalPost, synOrConn, synMech, 
connsFile, tagsFile): - + from . import sim import json - from time import time + from time import time def list_of_dict_unique_by_key(seq, index): seen = set() @@ -2552,7 +2568,7 @@ def list_of_dict_unique_by_key(seq, index): del connsTmp print('Finished loading; total time (s): %.2f'%(time()-start)) - + # find pre and post cells if tags and conns: cellGidsPre = getCellsIncludeTags(includePre, tags, tagsFormat) @@ -2561,7 +2577,7 @@ def list_of_dict_unique_by_key(seq, index): else: cellGidsPost = getCellsIncludeTags(includePost, tags, tagsFormat) else: - print('Error loading tags and conns from file') + print('Error loading tags and conns from file') return None, None, None @@ -2573,22 +2589,22 @@ def list_of_dict_unique_by_key(seq, index): weightIndex = connsFormat.index('weight') if 'weight' in connsFormat else missing.append('weight') delayIndex = connsFormat.index('delay') if 'delay' in connsFormat else missing.append('delay') preLabelIndex = connsFormat.index('preLabel') if 'preLabel' in connsFormat else -1 - + if len(missing) > 0: print("Missing:") print(missing) - return None, None, None + return None, None, None if isinstance(synMech, str): synMech = [synMech] # make sure synMech is a list - + # Calculate matrix if grouped by cell - if groupBy == 'cell': + if groupBy == 'cell': print('plotConn from file for groupBy=cell not implemented yet') - return None, None, None + return None, None, None # Calculate matrix if grouped by pop - elif groupBy == 'pop': - + elif groupBy == 'pop': + # get list of pops print(' Obtaining list of populations ...') popsPre = list(set([tags[gid][popIndex] for gid in cellGidsPre])) @@ -2602,14 +2618,14 @@ def list_of_dict_unique_by_key(seq, index): else: popsPost = list(set([tags[gid][popIndex] for gid in cellGidsPost])) popIndsPost = {pop: ind for ind,pop in enumerate(popsPost)} - + # initialize matrices - if feature in ['weight', 'strength']: + if feature in ['weight', 'strength']: weightMatrix = np.zeros((len(popsPre), 
len(popsPost))) - elif feature == 'delay': + elif feature == 'delay': delayMatrix = np.zeros((len(popsPre), len(popsPost))) countMatrix = np.zeros((len(popsPre), len(popsPost))) - + # calculate max num conns per pre and post pair of pops print(' Calculating max num conns for each pair of population ...') numCellsPopPre = {} @@ -2633,18 +2649,18 @@ def list_of_dict_unique_by_key(seq, index): if feature == 'convergence': maxPostConnMatrix = np.zeros((len(popsPre), len(popsPost))) if feature == 'divergence': maxPreConnMatrix = np.zeros((len(popsPre), len(popsPost))) for prePop in popsPre: - for postPop in popsPost: + for postPop in popsPost: if numCellsPopPre[prePop] == -1: numCellsPopPre[prePop] = numCellsPopPost[postPop] maxConnMatrix[popIndsPre[prePop], popIndsPost[postPop]] = numCellsPopPre[prePop]*numCellsPopPost[postPop] if feature == 'convergence': maxPostConnMatrix[popIndsPre[prePop], popIndsPost[postPop]] = numCellsPopPost[postPop] if feature == 'divergence': maxPreConnMatrix[popIndsPre[prePop], popIndsPost[postPop]] = numCellsPopPre[prePop] - + # Calculate conn matrix print(' Calculating weights, strength, prob, delay etc matrices ...') for postGid in cellGidsPost: # for each postsyn cell print(' cell %d'%(int(postGid))) if synOrConn=='syn': - cellConns = conns[postGid] # include all synapses + cellConns = conns[postGid] # include all synapses else: cellConns = list_of_dict_unique_by_index(conns[postGid], preGidIndex) @@ -2657,30 +2673,30 @@ def list_of_dict_unique_by_key(seq, index): else: preCellGid = next((gid for gid in cellGidsPre if gid==conn[preGidIndex]), None) prePopLabel = tags[preCellGid][popIndex] if preCellGid else None - + if prePopLabel in popIndsPre: - if feature in ['weight', 'strength']: + if feature in ['weight', 'strength']: weightMatrix[popIndsPre[prePopLabel], popIndsPost[tags[postGid][popIndex]]] += conn[weightIndex] - elif feature == 'delay': - delayMatrix[popIndsPre[prePopLabel], popIndsPost[tags[postGid][popIndex]]] += 
conn[delayIndex] - countMatrix[popIndsPre[prePopLabel], popIndsPost[tags[postGid][popIndex]]] += 1 + elif feature == 'delay': + delayMatrix[popIndsPre[prePopLabel], popIndsPost[tags[postGid][popIndex]]] += conn[delayIndex] + countMatrix[popIndsPre[prePopLabel], popIndsPost[tags[postGid][popIndex]]] += 1 + + pre, post = popsPre, popsPost - pre, post = popsPre, popsPost - # Calculate matrix if grouped by numeric tag (eg. 'y') elif groupBy in sim.net.allCells[0]['tags'] and isinstance(sim.net.allCells[0]['tags'][groupBy], Number): print('plotConn from file for groupBy=[arbitrary property] not implemented yet') - return None, None, None + return None, None, None # no valid groupBy - else: + else: print('groupBy (%s) is not valid'%(str(groupBy))) return if groupBy != 'cell': - if feature == 'weight': - connMatrix = weightMatrix / countMatrix # avg weight per conn (fix to remove divide by zero warning) - elif feature == 'delay': + if feature == 'weight': + connMatrix = weightMatrix / countMatrix # avg weight per conn (fix to remove divide by zero warning) + elif feature == 'delay': connMatrix = delayMatrix / countMatrix elif feature == 'numConns': connMatrix = countMatrix @@ -2702,12 +2718,12 @@ def list_of_dict_unique_by_key(seq, index): ###################################################################################################################################################### @exception def plotConn (includePre = ['all'], includePost = ['all'], feature = 'strength', orderBy = 'gid', figSize = (10,10), groupBy = 'pop', groupByIntervalPre = None, groupByIntervalPost = None, - graphType = 'matrix', synOrConn = 'syn', synMech = None, connsFile = None, tagsFile = None, clim = None, saveData = None, saveFig = None, showFig = True): - ''' + graphType = 'matrix', synOrConn = 'syn', synMech = None, connsFile = None, tagsFile = None, clim = None, saveData = None, saveFig = None, showFig = True): + ''' Plot network connectivity - includePre 
(['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): Cells to show (default: ['all']) - includePost (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): Cells to show (default: ['all']) - - feature ('weight'|'delay'|'numConns'|'probability'|'strength'|'convergence'|'divergence'): Feature to show in connectivity matrix; + - feature ('weight'|'delay'|'numConns'|'probability'|'strength'|'convergence'|'divergence'): Feature to show in connectivity matrix; the only features applicable to groupBy='cell' are 'weight', 'delay' and 'numConns'; 'strength' = weight * probability (default: 'strength') - groupBy ('pop'|'cell'|'y'|: Show matrix for individual cells, populations, or by other numeric tag such as 'y' (default: 'pop') - groupByInterval (int or float): Interval of groupBy feature to group cells by in conn matrix, e.g. 100 to group by cortical depth in steps of 100 um (default: None) @@ -2716,15 +2732,15 @@ def plotConn (includePre = ['all'], includePost = ['all'], feature = 'strength', - synOrConn ('syn'|'conn'): Use synapses or connections; note 1 connection can have multiple synapses (default: 'syn') - figSize ((width, height)): Size of figure (default: (10,10)) - synMech (['AMPA', 'GABAA',...]): Show results only for these syn mechs (default: None) - - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; + - saveData (None|True|'fileName'): File name where to save the final data used to generate the figure; if set to True uses filename from simConfig (default: None) - - saveFig (None|True|'fileName'): File name where to save the figure; + - saveFig (None|True|'fileName'): File name where to save the figure; if set to True uses filename from simConfig (default: None) - showFig (True|False): Whether to show the figure or not (default: True) - Returns figure handles ''' - + from . 
import sim print('Plotting connectivity matrix...') @@ -2818,9 +2834,9 @@ def plotConn (includePre = ['all'], includePost = ['all'], feature = 'strength', if groupBy == 'pop': popsPre, popsPost = pre, post - from netpyne.support import stackedBarGraph + from netpyne.support import stackedBarGraph SBG = stackedBarGraph.StackedBarGrapher() - + fig = plt.figure(figsize=figSize) ax = fig.add_subplot(111) SBG.stackedBarPlot(ax, connMatrix.transpose(), colorList, xLabels=popsPost, gap = 0.1, scale=False, xlabel='postsynaptic', ylabel = feature) @@ -2839,30 +2855,30 @@ def plotConn (includePre = ['all'], includePost = ['all'], feature = 'strength', if saveData: figData = {'connMatrix': connMatrix, 'feature': feature, 'groupBy': groupBy, 'includePre': includePre, 'includePost': includePost, 'saveData': saveData, 'saveFig': saveFig, 'showFig': showFig} - + _saveFigData(figData, saveData, 'conn') - + # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'conn_'+feature+'.png' plt.savefig(filename) - # show fig + # show fig if showFig: _showFigure() - return fig, {'connMatrix': connMatrix, 'feature': feature, 'groupBy': groupBy, 'includePre': includePre, 'includePost': includePost} + return fig, {} ###################################################################################################################################################### ## Plot 2D representation of network cell positions and connections ###################################################################################################################################################### @exception -def plot2Dnet (include = ['allCells'], figSize = (12,12), view = 'xy', showConns = True, popColors = None, - tagsFile = None, saveData = None, saveFig = None, showFig = True): - ''' +def plot2Dnet (include = ['allCells'], figSize = (12,12), view = 'xy', showConns = True, popColors = None, + tagsFile = None, saveData = None, saveFig = 
None, showFig = True): + ''' Plot 2D representation of network cell positions and connections - include (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): Cells to show (default: ['all']) - showConns (True|False): Whether to show connections or not (default: True) @@ -2906,19 +2922,19 @@ def plot2Dnet (include = ['allCells'], figSize = (12,12), view = 'xy', showConns if len(missing) > 0: print("Missing:") print(missing) - return None, None, None + return None, None, None # find pre and post cells if tags: cellGids = getCellsIncludeTags(include, tags, tagsFormat) popLabels = list(set([tags[gid][popIndex] for gid in cellGids])) - + # pop and cell colors popColorsTmp = {popLabel: colorList[ipop%len(colorList)] for ipop,popLabel in enumerate(popLabels)} # dict with color for each pop if popColors: popColorsTmp.update(popColors) popColors = popColorsTmp cellColors = [popColors[tags[gid][popIndex]] for gid in cellGids] - + # cell locations posX = [tags[gid][xIndex] for gid in cellGids] # get all x positions if ycoord == 'y': @@ -2926,14 +2942,14 @@ def plot2Dnet (include = ['allCells'], figSize = (12,12), view = 'xy', showConns elif ycoord == 'z': posY = [tags[gid][zIndex] for gid in cellGids] # get all y positions else: - print('Error loading tags from file') + print('Error loading tags from file') return None else: - cells, cellGids, _ = getCellsInclude(include) + cells, cellGids, _ = getCellsInclude(include) selectedPops = [cell['tags']['pop'] for cell in cells] popLabels = [pop for pop in sim.net.allPops if pop in selectedPops] # preserves original ordering - + # pop and cell colors popColorsTmp = {popLabel: colorList[ipop%len(colorList)] for ipop,popLabel in enumerate(popLabels)} # dict with color for each pop if popColors: popColorsTmp.update(popColors) @@ -2943,24 +2959,24 @@ def plot2Dnet (include = ['allCells'], figSize = (12,12), view = 'xy', showConns # cell locations posX = [cell['tags']['x'] for cell in cells] # get all x positions 
posY = [cell['tags'][ycoord] for cell in cells] # get all y positions - + plt.scatter(posX, posY, s=60, color = cellColors) # plot cell soma positions if showConns and not tagsFile: for postCell in cells: for con in postCell['conns']: # plot connections between cells if not isinstance(con['preGid'], str) and con['preGid'] in cellGids: - posXpre,posYpre = next(((cell['tags']['x'],cell['tags'][ycoord]) for cell in cells if cell['gid']==con['preGid']), None) - posXpost,posYpost = postCell['tags']['x'], postCell['tags'][ycoord] + posXpre,posYpre = next(((cell['tags']['x'],cell['tags'][ycoord]) for cell in cells if cell['gid']==con['preGid']), None) + posXpost,posYpost = postCell['tags']['x'], postCell['tags'][ycoord] color='red' if con['synMech'] in ['inh', 'GABA', 'GABAA', 'GABAB']: color = 'blue' width = 0.1 #50*con['weight'] plt.plot([posXpre, posXpost], [posYpre, posYpost], color=color, linewidth=width) # plot line from pre to post - + plt.xlabel('x (um)') - plt.ylabel(ycoord+' (um)') - plt.xlim([min(posX)-0.05*max(posX),1.05*max(posX)]) + plt.ylabel(ycoord+' (um)') + plt.xlim([min(posX)-0.05*max(posX),1.05*max(posX)]) plt.ylim([min(posY)-0.05*max(posY),1.05*max(posY)]) fontsiz = 12 @@ -2974,27 +2990,27 @@ def plot2Dnet (include = ['allCells'], figSize = (12,12), view = 'xy', showConns if saveData: figData = {'posX': posX, 'posY': posY, 'posX': cellColors, 'posXpre': posXpre, 'posXpost': posXpost, 'posYpre': posYpre, 'posYpost': posYpost, 'include': include, 'saveData': saveData, 'saveFig': saveFig, 'showFig': showFig} - + _saveFigData(figData, saveData, '2Dnet') - + # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'2Dnet.png' plt.savefig(filename) - # show fig + # show fig if showFig: _showFigure() - return fig, {'include': include, 'posX': posX, 'posY': posY, 'posXpre': posXpre, 'posXpost': posXpost, 'posYpre': posYpre, 'posYpost': posYpost} + return fig, {} 
###################################################################################################################################################### ## Calculate number of disynaptic connections -###################################################################################################################################################### +###################################################################################################################################################### @exception -def calculateDisynaptic(includePost = ['allCells'], includePre = ['allCells'], includePrePre = ['allCells'], +def calculateDisynaptic(includePost = ['allCells'], includePre = ['allCells'], includePrePre = ['allCells'], tags=None, conns=None, tagsFile=None, connsFile=None): import json @@ -3015,9 +3031,9 @@ def calculateDisynaptic(includePost = ['allCells'], includePre = ['allCells'], i with open(connsFile, 'r') as fileObj: connsTmp = json.load(fileObj)['conns'] conns = {int(k): v for k,v in connsTmp.items()} del connsTmp - + print(' Calculating disynaptic connections...') - # loading from json files + # loading from json files if tags and conns: cellsPreGids = getCellsIncludeTags(includePre, tags) cellsPrePreGids = getCellsIncludeTags(includePrePre, tags) @@ -3034,13 +3050,13 @@ def calculateDisynaptic(includePost = ['allCells'], includePre = ['allCells'], i numDis += 1 else: - if sim.cfg.compactConnFormat: + if sim.cfg.compactConnFormat: if 'preGid' in sim.cfg.compactConnFormat: preGidIndex = sim.cfg.compactConnFormat.index('preGid') # using compact conn format (list) else: print(' Error: cfg.compactConnFormat does not include "preGid"') return -1 - else: + else: preGidIndex = 'preGid' # using long conn format (dict) _, cellsPreGids, _ = getCellsInclude(includePre) @@ -3065,7 +3081,7 @@ def calculateDisynaptic(includePost = ['allCells'], includePre = ['allCells'], i pass print(' time ellapsed (s): ', time() - start) - + return numDis @@ -3074,28 +3090,28 @@ def 
calculateDisynaptic(includePost = ['allCells'], includePre = ['allCells'], i ###################################################################################################################################################### @exception def nTE(cells1 = [], cells2 = [], spks1 = None, spks2 = None, timeRange = None, binSize = 20, numShuffle = 30): - ''' + ''' Calculate normalized transfer entropy - cells1 (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): Subset of cells from which to obtain spike train 1 (default: []) - cells2 (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): Subset of cells from which to obtain spike train 1 (default: []) - spks1 (list): Spike train 1; list of spike times; if omitted then obtains spikes from cells1 (default: None) - spks2 (list): Spike train 2; list of spike times; if omitted then obtains spikes from cells2 (default: None) - timeRange ([min, max]): Range of time to calculate nTE in ms (default: [0,cfg.duration]) - - binSize (int): Bin size used to convert spike times into histogram + - binSize (int): Bin size used to convert spike times into histogram - numShuffle (int): Number of times to shuffle spike train 1 to calculate TEshuffled; note: nTE = (TE - TEShuffled)/H(X2F|X2P) - - Returns nTE (float): normalized transfer entropy + - Returns nTE (float): normalized transfer entropy ''' from neuron import h import netpyne from . 
import sim import os - + root = os.path.dirname(netpyne.__file__) - - if 'nte' not in dir(h): - try: + + if 'nte' not in dir(h): + try: print(' Warning: support/nte.mod not compiled; attempting to compile from %s via "nrnivmodl support"'%(root)) os.system('cd ' + root + '; nrnivmodl support') from neuron import load_mechanisms @@ -3106,7 +3122,7 @@ def nTE(cells1 = [], cells2 = [], spks1 = None, spks2 = None, timeRange = None, return h.load_file(root+'/support/nte.hoc') # nTE code (also requires support/net.mod) - + if not spks1: # if doesnt contain a list of spk times, obtain from cells specified cells, cellGids, netStimPops = getCellsInclude(cells1) numNetStims = 0 @@ -3117,7 +3133,7 @@ def nTE(cells1 = [], cells2 = [], spks1 = None, spks2 = None, timeRange = None, spkts = [spkt for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids] except: spkts = [] - else: + else: spkts = [] # Add NetStim spikes @@ -3143,7 +3159,7 @@ def nTE(cells1 = [], cells2 = [], spks1 = None, spks2 = None, timeRange = None, spkts = [spkt for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids] except: spkts = [] - else: + else: spkts = [] # Add NetStim spikes @@ -3168,9 +3184,9 @@ def nTE(cells1 = [], cells2 = [], spks1 = None, spks2 = None, timeRange = None, inputVec = h.Vector() outputVec = h.Vector() histo1 = np.histogram(spks1, bins = np.arange(timeRange[0], timeRange[1], binSize)) - histoCount1 = histo1[0] + histoCount1 = histo1[0] histo2 = np.histogram(spks2, bins = np.arange(timeRange[0], timeRange[1], binSize)) - histoCount2 = histo2[0] + histoCount2 = histo2[0] inputVec.from_python(histoCount1) outputVec.from_python(histoCount2) @@ -3183,10 +3199,10 @@ def nTE(cells1 = [], cells2 = [], spks1 = None, spks2 = None, timeRange = None, ## Calculate granger causality 
###################################################################################################################################################### @exception -def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkTrain1', label2 = 'spkTrain2', timeRange = None, binSize=5, plotFig = True, +def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkTrain1', label2 = 'spkTrain2', timeRange = None, binSize=5, plotFig = True, saveData = None, saveFig = None, showFig = True): - ''' - Calculate and optionally plot Granger Causality + ''' + Calculate and optionally plot Granger Causality - cells1 (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): Subset of cells from which to obtain spike train 1 (default: []) - cells2 (['all',|'allCells','allNetStims',|,120,|,'E1'|,('L2', 56)|,('L5',[4,5,6])]): Subset of cells from which to obtain spike train 2 (default: []) - spks1 (list): Spike train 1; list of spike times; if omitted then obtains spikes from cells1 (default: None) @@ -3194,7 +3210,7 @@ def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkT - label1 (string): Label for spike train 1 to use in plot - label2 (string): Label for spike train 2 to use in plot - timeRange ([min, max]): Range of time to calculate nTE in ms (default: [0,cfg.duration]) - - binSize (int): Bin size used to convert spike times into histogram + - binSize (int): Bin size used to convert spike times into histogram - plotFig (True|False): Whether to plot a figure showing Granger Causality Fx2y and Fy2x - saveData (None|'fileName'): File name where to save the final data used to generate the figure (default: None) - saveFig (None|'fileName'): File name where to save the figure; @@ -3202,14 +3218,14 @@ def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkT - showFig (True|False): Whether to show the figure or not; if set to True uses filename from simConfig (default: None) - - 
Returns + - Returns F: list of freqs - Fx2y: causality measure from x to y - Fy2x: causality from y to x - Fxy: instantaneous causality between x and y - fig: Figure handle + Fx2y: causality measure from x to y + Fy2x: causality from y to x + Fxy: instantaneous causality between x and y + fig: Figure handle ''' - + from . import sim import numpy as np from netpyne.support.bsmart import pwcausalr @@ -3224,7 +3240,7 @@ def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkT spkts = [spkt for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids] except: spkts = [] - else: + else: spkts = [] # Add NetStim spikes @@ -3232,7 +3248,7 @@ def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkT numNetStims = 0 for netStimPop in netStimPops: if 'stims' in sim.allSimData: - cellStims = [cellStim for cell, cellStim in sim.allSimData['stims'].items() if netStimPop in cellStim] + cellStims = [cellStim for cell,cellStim in sim.allSimData['stims'].items() if netStimPop in cellStim] if len(cellStims) > 0: spktsNew = [spkt for cellStim in cellStims for spkt in cellStim[netStimPop] ] spkts.extend(spktsNew) @@ -3250,7 +3266,7 @@ def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkT spkts = [spkt for spkgid,spkt in zip(sim.allSimData['spkid'],sim.allSimData['spkt']) if spkgid in cellGids] except: spkts = [] - else: + else: spkts = [] # Add NetStim spikes @@ -3266,22 +3282,23 @@ def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkT spks2 = list(spkts) + # time range if timeRange is None: if getattr(sim, 'cfg', None): timeRange = [0,sim.cfg.duration] else: timeRange = [0, max(spks1+spks2)] - + histo1 = np.histogram(spks1, bins = np.arange(timeRange[0], timeRange[1], binSize)) - histoCount1 = histo1[0] + histoCount1 = histo1[0] histo2 = np.histogram(spks2, bins = np.arange(timeRange[0], timeRange[1], binSize)) - histoCount2 = histo2[0] + histoCount2 = 
histo2[0] + + fs = 1000/binSize + F,pp,cohe,Fx2y,Fy2x,Fxy = pwcausalr(np.array([histoCount1, histoCount2]), 1, len(histoCount1), 10, fs, fs/2) - fs = int(1000/binSize) - F,pp,cohe,Fx2y,Fy2x,Fxy = pwcausalr(np.array([histoCount1, histoCount2]), 1, len(histoCount1), 10, fs, int(fs/2)) - # plot granger fig = -1 @@ -3292,23 +3309,23 @@ def granger(cells1 = [], cells2 = [], spks1 = None, spks2 = None, label1 = 'spkT plt.xlabel('Frequency (Hz)') plt.ylabel('Granger Causality') plt.legend() - + # save figure data if saveData: - figData = {'cells1': cells1, 'cells2': cells2, 'spks1': cells1, 'spks2': cells2, 'binSize': binSize, 'Fy2x': Fy2x[0], 'Fx2y': Fx2y[0], + figData = {'cells1': cells1, 'cells2': cells2, 'spks1': cells1, 'spks2': cells2, 'binSize': binSize, 'Fy2x': Fy2x[0], 'Fx2y': Fx2y[0], 'saveData': saveData, 'saveFig': saveFig, 'showFig': showFig} - + _saveFigData(figData, saveData, '2Dnet') - + # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'2Dnet.png' plt.savefig(filename) - # show fig + # show fig if showFig: _showFigure() return fig, {'F': F, 'Fx2y': Fx2y[0], 'Fy2x': Fy2x[0], 'Fxy': Fxy[0]} @@ -3330,7 +3347,7 @@ def plotEPSPAmp(include=None, trace=None, start=0, interval=50, number=2, amp='a cells, cellGids, _ = getCellsInclude(include) gidPops = {cell['gid']: cell['tags']['pop'] for cell in cells} - if not trace: + if not trace: print('Error: Missing trace to to plot EPSP amplitudes') return step = sim.cfg.recordStep @@ -3341,9 +3358,9 @@ def plotEPSPAmp(include=None, trace=None, start=0, interval=50, number=2, amp='a vsoma = sim.allSimData[trace]['cell_'+str(gid)] for ipeak in range(number): if polarity == 'exc': - peakAbs = max(vsoma[int(start/step+(ipeak*interval/step)):int(start/step+(ipeak*interval/step)+(interval-1)/step)]) + peakAbs = max(vsoma[int(start/step+(ipeak*interval/step)):int(start/step+(ipeak*interval/step)+(interval-1)/step)]) elif polarity == 'inh': - peakAbs 
= min(vsoma[int(start/step+(ipeak*interval/step)):int(start/step+(ipeak*interval/step)+(interval-1)/step)]) + peakAbs = min(vsoma[int(start/step+(ipeak*interval/step)):int(start/step+(ipeak*interval/step)+(interval-1)/step)]) peakRel = peakAbs - vsoma[int((start-1)/step)] peaksAbs[ipeak,icell] = peakAbs peaksRel[ipeak,icell] = peakRel @@ -3359,7 +3376,7 @@ def plotEPSPAmp(include=None, trace=None, start=0, interval=50, number=2, amp='a for icell in range(len(cellGids)): peaks[:, icell] = peaksRel[:, icell] / peaksRel[0, icell] ylabel = 'EPSP amplitude ratio' - + xlabel = 'EPSP number' # plot @@ -3377,17 +3394,17 @@ def plotEPSPAmp(include=None, trace=None, start=0, interval=50, number=2, amp='a plt.plot((0, number-1), (1.0, 1.0), ':', color= 'gray') # save figure - if saveFig: + if saveFig: if isinstance(saveFig, str): filename = saveFig else: filename = sim.cfg.filename+'_'+'EPSPamp_'+amp+'.png' plt.savefig(filename) - # show fig + # show fig if showFig: _showFigure() - return peaks, fig + return fig, {'peaks': peaks} ###################################################################################################################################################### ## Plot RxD concentration @@ -3410,7 +3427,7 @@ def plotRxDConcentration(speciesLabel, regionLabel, plane='xy', showFig=True): #ax.add_artist(sb) plt.colorbar(label="$%s^+$ (mM)"%(species.name)) - # show fig + # show fig if showFig: _showFigure() - return fig, {'data': species[region].states3d[:].mean(plane2mean[plane])} + return fig, {} diff --git a/netpyne/batch.py b/netpyne/batch.py index 2d9a8a446..89501868c 100644 --- a/netpyne/batch.py +++ b/netpyne/batch.py @@ -6,28 +6,62 @@ Contributors: salvadordura@gmail.com """ - +import imp +import json +import logging import datetime +from neuron import h +from copy import copy +from netpyne import specs +from netpyne.utils import bashTemplate +from random import Random +from time import sleep, time from itertools import product from subprocess import 
Popen, PIPE -from time import sleep -import imp -from netpyne import specs -from neuron import h + pc = h.ParallelContext() # use bulletin board master/slave if pc.id()==0: pc.master_works_on_jobs(0) +# ------------------------------------------------------------------------------- # function to run single job using ParallelContext bulletin board (master/slave) +# ------------------------------------------------------------------------------- # func needs to be outside of class -def runJob(script, cfgSavePath, netParamsSavePath): +def runEvolJob(script, cfgSavePath, netParamsSavePath, simDataPath): + import os + print('\nJob in rank id: ',pc.id()) + command = 'nrniv %s simConfig=%s netParams=%s' % (script, cfgSavePath, netParamsSavePath) + + with open(simDataPath+'.run', 'w') as outf, open(simDataPath+'.err', 'w') as errf: + pid = Popen(command.split(' '), stdout=outf, stderr=errf, preexec_fn=os.setsid).pid + with open('./pids.pid', 'a') as file: + file.write(str(pid) + ' ') + +# func needs to be outside of class +def runJob(script, cfgSavePath, netParamsSavePath): print('\nJob in rank id: ',pc.id()) command = 'nrniv %s simConfig=%s netParams=%s' % (script, cfgSavePath, netParamsSavePath) print(command+'\n') proc = Popen(command.split(' '), stdout=PIPE, stderr=PIPE) print(proc.stdout.read()) + +# ------------------------------------------------------------------------------- +# function to create a folder if it does not exist +# ------------------------------------------------------------------------------- +def createFolder(folder): + import os + if not os.path.exists(folder): + try: + os.mkdir(folder) + except OSError: + print(' Could not create %s' %(folder)) + + +# ------------------------------------------------------------------------------- +# function to convert tuples to strings (avoids erro when saving/loading) +# ------------------------------------------------------------------------------- def tupleToStr (obj): #print '\nbefore:', obj if type(obj) == 
list: @@ -44,9 +78,12 @@ def tupleToStr (obj): return obj +# ------------------------------------------------------------------------------- +# Batch class +# ------------------------------------------------------------------------------- class Batch(object): - def __init__(self, cfgFile='cfg.py', netParamsFile='netParams.py', params=None, groupedParams=None, initCfg={}): + def __init__(self, cfgFile='cfg.py', netParamsFile='netParams.py', params=None, groupedParams=None, initCfg={}, seed=None): self.batchLabel = 'batch_'+str(datetime.date.today()) self.cfgFile = cfgFile self.initCfg = initCfg @@ -54,7 +91,9 @@ def __init__(self, cfgFile='cfg.py', netParamsFile='netParams.py', params=None, self.saveFolder = '/'+self.batchLabel self.method = 'grid' self.runCfg = {} + self.evolCfg = {} self.params = [] + self.seed = seed if params: for k,v in params.items(): self.params.append({'label': k, 'values': v}) @@ -62,21 +101,20 @@ def __init__(self, cfgFile='cfg.py', netParamsFile='netParams.py', params=None, for p in self.params: if p['label'] in groupedParams: p['group'] = True + def save(self, filename): import os from copy import deepcopy basename = os.path.basename(filename) folder = filename.split(basename)[0] ext = basename.split('.')[1] - + # make dir - try: - os.mkdir(folder) - except OSError: - if not os.path.exists(folder): - print(' Could not create', folder) + createFolder(folder) odict = deepcopy(self.__dict__) + if 'evolCfg' in odict: + odict['evolCfg']['fitnessFunc'] = 'removed' dataSave = {'batch': tupleToStr(odict)} if ext == 'json': import json @@ -99,7 +137,60 @@ def setCfgNestedParam(self, paramLabel, paramVal): else: setattr(self.cfg, paramLabel, paramVal) # set simConfig params + + def saveScripts(self): + import os + + # create Folder to save simulation + createFolder(self.saveFolder) + + # save Batch dict as json + targetFile = self.saveFolder+'/'+self.batchLabel+'_batch.json' + self.save(targetFile) + + # copy this batch script to folder + 
targetFile = self.saveFolder+'/'+self.batchLabel+'_batchScript.py' + os.system('cp ' + os.path.realpath(__file__) + ' ' + targetFile) + + # copy this batch script to folder, netParams and simConfig + #os.system('cp ' + self.netParamsFile + ' ' + self.saveFolder + '/netParams.py') + + netParamsSavePath = self.saveFolder+'/'+self.batchLabel+'_netParams.py' + os.system('cp ' + self.netParamsFile + ' ' + netParamsSavePath) + + os.system('cp ' + os.path.realpath(__file__) + ' ' + self.saveFolder + '/batchScript.py') + + # save initial seed + with open(self.saveFolder + '/_seed.seed', 'w') as seed_file: + if not self.seed: self.seed = int(time()) + seed_file.write(str(self.seed)) + + # import cfg + cfgModuleName = os.path.basename(self.cfgFile).split('.')[0] + cfgModule = imp.load_source(cfgModuleName, self.cfgFile) + + if hasattr(cfgModule, 'cfg'): + self.cfg = cfgModule.cfg + else: + self.cfg = cfgModule.simConfig + + self.cfg.checkErrors = False # avoid error checking during batch + + + def openFiles2SaveStats(self): + stat_file_name = '%s/%s_stats.cvs' %(self.saveFolder, self.batchLabel) + ind_file_name = '%s/%s_stats_indiv.cvs' %(self.saveFolder, self.batchLabel) + individual = open(ind_file_name, 'w') + stats = open(stat_file_name, 'w') + stats.write('#gen pop-size worst best median average std-deviation\n') + individual.write('#gen #ind fitness [candidate]\n') + return stats, individual + + def run(self): + # ------------------------------------------------------------------------------- + # Grid Search optimization + # ------------------------------------------------------------------------------- if self.method in ['grid','list']: # create saveFolder import os,glob @@ -164,7 +255,7 @@ def run(self): indexCombGroups = [(0,)] # if using pc bulletin board, initialize all workers - if self.runCfg.get('type', None) == 'mpi': + if self.runCfg.get('type', None) == 'mpi_bulletin': for iworker in range(int(pc.nhost())): pc.runworker() @@ -335,16 +426,378 @@ def 
run(self): pc.submit(runJob, self.runCfg.get('script', 'init.py'), cfgSavePath, netParamsSavePath) sleep(1) # avoid saturating scheduler + print("-"*80) + print(" Finished submitting jobs for grid parameter exploration ") + print("-"*80) + + + # ------------------------------------------------------------------------------- + # Evolutionary optimization + # ------------------------------------------------------------------------------- + elif self.method == 'evol': + import sys + import inspyred.ec as EC + + # ------------------------------------------------------------------------------- + # Evolutionary optimization: Parallel evaluation + # ------------------------------------------------------------------------------- + def evaluator(candidates, args): + import os + import signal + global ngen + ngen += 1 + total_jobs = 0 + + # options slurm, mpi + type = args.get('type', 'mpi_direct') + + # paths to required scripts + script = args.get('script', 'init.py') + netParamsSavePath = args.get('netParamsSavePath') + genFolderPath = self.saveFolder + '/gen_' + str(ngen) + + # mpi command setup + nodes = args.get('nodes', 1) + paramLabels = args.get('paramLabels', []) + coresPerNode = args.get('coresPerNode', 1) + mpiCommand = args.get('mpiCommand', 'ibrun') + numproc = nodes*coresPerNode + + # slurm setup + custom = args.get('custom', '') + folder = args.get('folder', '.') + email = args.get('email', 'a@b.c') + walltime = args.get('walltime', '00:01:00') + reservation = args.get('reservation', None) + allocation = args.get('allocation', 'csd403') # NSG account + + # fitness function + fitnessFunc = args.get('fitnessFunc') + fitnessFuncArgs = args.get('fitnessFuncArgs') + defaultFitness = args.get('defaultFitness') + + # read params or set defaults + sleepInterval = args.get('sleepInterval', 0.2) + + # create folder if it does not exist + createFolder(genFolderPath) + + # remember pids and jobids in a list + pids = [] + jobids = {} + + # create a job for each candidate 
+ for candidate_index, candidate in enumerate(candidates): + # required for slurm + sleep(sleepInterval) + + # name and path + jobName = "gen_" + str(ngen) + "_cand_" + str(candidate_index) + jobPath = genFolderPath + '/' + jobName + + # modify cfg instance with candidate values + for label, value in zip(paramLabels, candidate): + self.setCfgNestedParam(label, value) + print('set %s=%s' % (label, value)) + + #self.setCfgNestedParam("filename", jobPath) + self.cfg.simLabel = jobName + self.cfg.saveFolder = genFolderPath - # wait for pc bulletin board jobs to finish - try: - while pc.working(): - sleep(1) - #pc.done() - except: - pass - - sleep(10) # give time for last job to get on queue + # save cfg instance to file + cfgSavePath = jobPath + '_cfg.json' + self.cfg.save(cfgSavePath) + + + if type=='mpi_bulletin': + # ---------------------------------------------------------------------- + # MPI master-slaves + # ---------------------------------------------------------------------- + pc.submit(runEvolJob, script, cfgSavePath, netParamsSavePath, jobPath) + print('-'*80) + else: + # ---------------------------------------------------------------------- + # MPI job commnand + # ---------------------------------------------------------------------- + command = '%s -np %d nrniv -python -mpi %s simConfig=%s netParams=%s ' % (mpiCommand, numproc, script, cfgSavePath, netParamsSavePath) + + # ---------------------------------------------------------------------- + # run on local machine with cores + # ---------------------------------------------------------------------- + if type=='mpi_direct': + executer = '/bin/bash' + jobString = bashTemplate('mpi_direct') %(custom, folder, command) + + # ---------------------------------------------------------------------- + # run on HPC through slurm + # ---------------------------------------------------------------------- + elif type=='hpc_slurm': + executer = 'sbatch' + res = '#SBATCH --res=%s' % (reservation) if reservation else 
'' + jobString = bashTemplate('hpc_slurm') % (jobName, allocation, walltime, nodes, coresPerNode, jobPath, jobPath, email, res, custom, folder, command) + + # ---------------------------------------------------------------------- + # run on HPC through PBS + # ---------------------------------------------------------------------- + elif type=='hpc_torque': + executer = 'qsub' + queueName = args.get('queueName', 'default') + nodesppn = 'nodes=%d:ppn=%d' % (nodes, coresPerNode) + jobString = bashTemplate('hpc_torque') % (jobName, walltime, queueName, nodesppn, jobPath, jobPath, custom, command) + + # ---------------------------------------------------------------------- + # save job and run + # ---------------------------------------------------------------------- + print('Submitting job ', jobName) + print(jobString) + print('-'*80) + # save file + batchfile = '%s.sbatch' % (jobPath) + with open(batchfile, 'w') as text_file: + text_file.write("%s" % jobString) + + #with open(jobPath+'.run', 'a+') as outf, open(jobPath+'.err', 'w') as errf: + with open(jobPath+'.jobid', 'w') as outf, open(jobPath+'.err', 'w') as errf: + pids.append(Popen([executer, batchfile], stdout=outf, stderr=errf, preexec_fn=os.setsid).pid) + #proc = Popen(command.split([executer, batchfile]), stdout=PIPE, stderr=PIPE) + sleep(0.1) + #read = proc.stdout.read() + with open(jobPath+'.jobid', 'r') as outf: + read=outf.readline() + print(read) + if len(read) > 0: + jobid = int(read.split()[-1]) + jobids[candidate_index] = jobid + print('jobids', jobids) + total_jobs += 1 + sleep(0.1) + + + # ---------------------------------------------------------------------- + # gather data and compute fitness + # ---------------------------------------------------------------------- + if type == 'mpi_bulletin': + # wait for pc bulletin board jobs to finish + try: + while pc.working(): + sleep(1) + #pc.done() + except: + pass + + num_iters = 0 + jobs_completed = 0 + fitness = [None for cand in candidates] + # 
print outfilestem + print("Waiting for jobs from generation %d/%d ..." %(ngen, args.get('max_generations'))) + # print "PID's: %r" %(pids) + # start fitness calculation + while jobs_completed < total_jobs: + unfinished = [i for i, x in enumerate(fitness) if x is None ] + for candidate_index in unfinished: + try: # load simData and evaluate fitness + jobNamePath = genFolderPath + "/gen_" + str(ngen) + "_cand_" + str(candidate_index) + if os.path.isfile(jobNamePath+'.json'): + with open('%s.json'% (jobNamePath)) as file: + simData = json.load(file)['simData'] + fitness[candidate_index] = fitnessFunc(simData, **fitnessFuncArgs) + jobs_completed += 1 + print(' Candidate %d fitness = %.1f' % (candidate_index, fitness[candidate_index])) + except Exception as e: + # print + err = "There was an exception evaluating candidate %d:"%(candidate_index) + print(("%s \n %s"%(err,e))) + #pass + #print 'Error evaluating fitness of candidate %d'%(candidate_index) + num_iters += 1 + print('completed: %d' %(jobs_completed)) + if num_iters >= args.get('maxiter_wait', 5000): + print("Max iterations reached, the %d unfinished jobs will be canceled and set to default fitness" % (len(unfinished))) + for canditade_index in unfinished: + fitness[canditade_index] = defaultFitness + jobs_completed += 1 + if 'scancelUser' in kwargs: + os.system('scancel -u %s'%(kwargs['scancelUser'])) + else: + os.system('scancel %d'%(jobids[candidate_index])) # terminate unfinished job (resubmitted jobs not terminated!) 
+ sleep(args.get('time_sleep', 1)) + + # kill all processes + if type=='mpi_bulletin': + try: + with open("./pids.pid", 'r') as file: # read pids for mpi_bulletin + pids = [int(i) for i in file.read().split(' ')[:-1]] + + with open("./pids.pid", 'w') as file: # delete content + pass + for pid in pids: + try: + os.killpg(os.getpgid(pid), signal.SIGTERM) + except: + pass + except: + pass + # don't want to to this for hpcs since jobs are running on compute nodes not master + # else: + # try: + # for pid in pids: os.killpg(os.getpgid(pid), signal.SIGTERM) + # except: + # pass + # return + print("-"*80) + print(" Completed a generation ") + print("-"*80) + return fitness + + # ------------------------------------------------------------------------------- + # Evolutionary optimization: Generation of first population candidates + # ------------------------------------------------------------------------------- + def generator(random, args): + # generate initial values for candidates + return [random.uniform(l, u) for l, u in zip(args.get('lower_bound'), args.get('upper_bound'))] + # ------------------------------------------------------------------------------- + # Mutator + # ------------------------------------------------------------------------------- + @EC.variators.mutator + def nonuniform_bounds_mutation(random, candidate, args): + """Return the mutants produced by nonuniform mutation on the candidates. + .. 
Arguments: + random -- the random number generator object + candidate -- the candidate solution + args -- a dictionary of keyword arguments + Required keyword arguments in args: + Optional keyword arguments in args: + - *mutation_strength* -- the strength of the mutation, where higher + values correspond to greater variation (default 1) + """ + lower_bound = args.get('lower_bound') + upper_bound = args.get('upper_bound') + strength = args.setdefault('mutation_strength', 1) + mutant = copy(candidate) + for i, (c, lo, hi) in enumerate(zip(candidate, lower_bound, upper_bound)): + if random.random() <= 0.5: + new_value = c + (hi - c) * (1.0 - random.random() ** strength) + else: + new_value = c - (c - lo) * (1.0 - random.random() ** strength) + mutant[i] = new_value + + return mutant + # ------------------------------------------------------------------------------- + # Evolutionary optimization: Main code + # ------------------------------------------------------------------------------- + import os + # create main sim directory and save scripts + self.saveScripts() + + global ngen + ngen = -1 + + # log for simulation + logger = logging.getLogger('inspyred.ec') + logger.setLevel(logging.DEBUG) + file_handler = logging.FileHandler(self.saveFolder+'/inspyred.log', mode='a') + file_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + # create randomizer instance + rand = Random() + rand.seed(self.seed) + + # create file handlers for observers + stats_file, ind_stats_file = self.openFiles2SaveStats() + + # gather **kwargs + kwargs = {'cfg': self.cfg} + kwargs['num_inputs'] = len(self.params) + kwargs['paramLabels'] = [x['label'] for x in self.params] + kwargs['lower_bound'] = [x['values'][0] for x in self.params] + kwargs['upper_bound'] = [x['values'][1] for x in self.params] + kwargs['statistics_file'] = stats_file + 
kwargs['individuals_file'] = ind_stats_file + kwargs['netParamsSavePath'] = self.saveFolder+'/'+self.batchLabel+'_netParams.py' + + for key, value in self.evolCfg.items(): + kwargs[key] = value + if not 'maximize' in kwargs: kwargs['maximize'] = False + + for key, value in self.runCfg.items(): + kwargs[key] = value + + # if using pc bulletin board, initialize all workers + if self.runCfg.get('type', None) == 'mpi_bulletin': + for iworker in range(int(pc.nhost())): + pc.runworker() + #################################################################### + # Evolution strategy + #################################################################### + # Custom algorithm based on Krichmar's params + if self.evolCfg['evolAlgorithm'] == 'custom': + ea = EC.EvolutionaryComputation(rand) + ea.selector = EC.selectors.tournament_selection + ea.variator = [EC.variators.uniform_crossover, nonuniform_bounds_mutation] + ea.replacer = EC.replacers.generational_replacement + if not 'tournament_size' in kwargs: kwargs['tournament_size'] = 2 + if not 'num_selected' in kwargs: kwargs['num_selected'] = kwargs['pop_size'] + + # Genetic + elif self.evolCfg['evolAlgorithm'] == 'genetic': + ea = EC.GA(rand) + + # Evolution Strategy + elif self.evolCfg['evolAlgorithm'] == 'evolutionStrategy': + ea = EC.ES(rand) + + # Simulated Annealing + elif self.evolCfg['evolAlgorithm'] == 'simulatedAnnealing': + ea = EC.SA(rand) + + # Differential Evolution + elif self.evolCfg['evolAlgorithm'] == 'diffEvolution': + ea = EC.DEA(rand) + + # Estimation of Distribution + elif self.evolCfg['evolAlgorithm'] == 'estimationDist': + ea = EC.EDA(rand) + + # Particle Swarm optimization + elif self.evolCfg['evolAlgorithm'] == 'particleSwarm': + from inspyred import swarm + ea = swarm.PSO(rand) + ea.topology = swarm.topologies.ring_topology + + # Ant colony optimization (requires components) + elif self.evolCfg['evolAlgorithm'] == 'antColony': + from inspyred import swarm + if not 'components' in kwargs: raise 
ValueError("%s requires components" %(self.evolCfg['evolAlgorithm'])) + ea = swarm.ACS(rand, self.evolCfg['components']) + ea.topology = swarm.topologies.ring_topology + + else: + raise ValueError("%s is not a valid strategy" %(self.evolCfg['evolAlgorithm'])) + #################################################################### + ea.terminator = EC.terminators.generation_termination + ea.observer = [EC.observers.stats_observer, EC.observers.file_observer] + # ------------------------------------------------------------------------------- + # Run algorithm + # ------------------------------------------------------------------------------- + final_pop = ea.evolve(generator=generator, + evaluator=evaluator, + bounder=EC.Bounder(kwargs['lower_bound'],kwargs['upper_bound']), + logger=logger, + **kwargs) + + # close file + stats_file.close() + ind_stats_file.close() + + # print best and finish + print(('Best Solution: \n{0}'.format(str(max(final_pop))))) + print("-"*80) + print(" Completed evolutionary algorithm parameter optimization ") + print("-"*80) + sys.exit() diff --git a/netpyne/cell.py b/netpyne/cell.py index 7303fc022..60ebc7d07 100644 --- a/netpyne/cell.py +++ b/netpyne/cell.py @@ -238,7 +238,7 @@ def __getstate__ (self): from . 
import sim odict = self.__dict__.copy() # copy the dict since we change it - odict = sim.copyReplaceItemObj(odict, keystart='h', newval=None) # replace h objects with None so can be pickled + odict = sim.copyRemoveItemObj(odict, keystart='h') #, newval=None) # replace h objects with None so can be pickled odict = sim.copyReplaceItemObj(odict, keystart='NeuroML', newval='---Removed_NeuroML_obj---') # replace NeuroML objects with str so can be pickled return odict @@ -1180,7 +1180,7 @@ def _setConnWeights (self, params, netStimParams, secLabels): if netStimParams: scaleFactor = sim.net.params.scaleConnWeightNetStims - elif sim.net.params.scaleConnWeightModels.get(self.tags['cellModel'], None) is not None: + elif isinstance(sim.net.params.scaleConnWeightModels, dict) and sim.net.params.scaleConnWeightModels.get(self.tags['cellModel'], None) is not None: scaleFactor = sim.net.params.scaleConnWeightModels[self.tags['cellModel']] # use scale factor specific for this cell model else: scaleFactor = sim.net.params.scaleConnWeight # use global scale factor @@ -1613,7 +1613,7 @@ def _setConnWeights (self, params, netStimParams): if netStimParams: scaleFactor = sim.net.params.scaleConnWeightNetStims - elif sim.net.params.scaleConnWeightModels.get(self.tags['cellModel'], None) is not None: + elif isinstance(sim.net.params.scaleConnWeightModels, dict) and sim.net.params.scaleConnWeightModels.get(self.tags['cellModel'], None) is not None: scaleFactor = sim.net.params.scaleConnWeightModels[self.tags['cellModel']] # use scale factor specific for this cell model else: scaleFactor = sim.net.params.scaleConnWeight # use global scale factor diff --git a/netpyne/pop.py b/netpyne/pop.py index 05929431d..bfc6a8c4a 100644 --- a/netpyne/pop.py +++ b/netpyne/pop.py @@ -363,7 +363,8 @@ def __getstate__ (self): ''' Removes non-picklable h objects so can be pickled and sent via py_alltoall''' odict = self.__dict__.copy() # copy the dict since we change it odict = sim.replaceFuncObj(odict) # 
replace h objects with None so can be pickled - odict['cellModelClass'] = str(odict['cellModelClass']) + #odict['cellModelClass'] = str(odict['cellModelClass']) + del odict['cellModelClass'] del odict['rand'] return odict diff --git a/netpyne/simFuncs.py b/netpyne/simFuncs.py index 2dfbe59b5..1109c1db3 100644 --- a/netpyne/simFuncs.py +++ b/netpyne/simFuncs.py @@ -8,7 +8,7 @@ __all__.extend(['initialize', 'setNet', 'setNetParams', 'setSimCfg', 'createParallelContext', 'setupRecording', 'setupRecordLFP', 'calculateLFP', 'clearAll', 'setGlobals']) # init and setup __all__.extend(['preRun', 'runSim', 'runSimWithIntervalFunc', '_gatherAllCellTags', '_gatherAllCellConnPreGids', '_gatherCells', 'gatherData']) # run and gather __all__.extend(['saveData', 'loadSimCfg', 'loadNetParams', 'loadNet', 'loadSimData', 'loadAll', 'ijsonLoad', 'compactConnFormat', 'distributedSaveHDF5', 'loadHDF5']) # saving and loading -__all__.extend(['popAvgRates', 'id32', 'copyReplaceItemObj', 'clearObj', 'replaceItemObj', 'replaceNoneObj', 'replaceFuncObj', 'replaceDictODict', +__all__.extend(['popAvgRates', 'id32', 'copyReplaceItemObj', 'copyRemoveItemObj', 'clearObj', 'replaceItemObj', 'replaceNoneObj', 'replaceFuncObj', 'replaceDictODict', 'readCmdLineArgs', 'getCellsList', 'cellByGid','timing', 'version', 'gitChangeset', 'loadBalance','_init_stim_randomizer', 'decimalToFloat', 'unique', 'rename']) # misc/utilities @@ -277,15 +277,16 @@ def loadSimCfg (filename, data=None, setLoaded=True): ############################################################################### def loadSimData (filename, data=None): from . 
import sim + if not data: data = _loadFile(filename) print('Loading simData...') if 'simData' in data: sim.allSimData = data['simData'] - print('done') + print('Done') else: - print(6) print((' simData not found in file %s'%(filename))) + pass @@ -294,8 +295,8 @@ def loadSimData (filename, data=None): ############################################################################### def loadAll (filename, data=None, instantiate=True, createNEURONObj=True): from . import sim + if not data: data = _loadFile(filename) - loadSimCfg(filename, data=data) sim.cfg.createNEURONObj = createNEURONObj # set based on argument loadNetParams(filename, data=data) @@ -419,7 +420,7 @@ def _byteify(data, ignore_dicts = False): import json print(('Loading file %s ... ' % (filename))) with open(filename, 'r') as fileObj: - data = json.load(fileObj) + data = json.load(fileObj, object_hook=_byteify) # load mat file elif ext == 'mat': @@ -487,6 +488,7 @@ def clearAll (): sim.clearObj([cell.__dict__ if hasattr(cell, '__dict__') else cell for cell in sim.net.cells]) if 'stims' in list(sim.simData.keys()): sim.clearObj([stim for stim in sim.simData['stims']]) + for key in list(sim.simData.keys()): del sim.simData[key] for c in sim.net.cells: del c for p in sim.net.pops: del p @@ -499,6 +501,7 @@ def clearAll (): sim.clearObj([cell.__dict__ if hasattr(cell, '__dict__') else cell for cell in sim.net.allCells]) if hasattr(sim, 'allSimData') and 'stims' in list(sim.allSimData.keys()): sim.clearObj([stim for stim in sim.allSimData['stims']]) + for key in list(sim.allSimData.keys()): del sim.allSimData[key] for c in sim.net.allCells: del c for p in sim.net.allPops: del p @@ -566,6 +569,40 @@ def copyReplaceItemObj (obj, keystart, newval, objCopy='ROOT'): return objCopy +############################################################################### +### Remove item with specific key from dict or list (used to remove h objects) 
+############################################################################### +def copyRemoveItemObj (obj, keystart, objCopy='ROOT'): + if type(obj) == list: + if objCopy=='ROOT': + objCopy = [] + for item in obj: + if isinstance(item, list): + objCopy.append([]) + copyRemoveItemObj(item, keystart, objCopy[-1]) + elif isinstance(item, (dict, Dict)): + objCopy.append({}) + copyRemoveItemObj(item, keystart, objCopy[-1]) + else: + objCopy.append(item) + + elif isinstance(obj, (dict, Dict)): + if objCopy == 'ROOT': + objCopy = Dict() + for key,val in obj.items(): + if type(val) in [list]: + objCopy[key] = [] + copyRemoveItemObj(val, keystart, objCopy[key]) + elif isinstance(val, (dict, Dict)): + objCopy[key] = {} + copyRemoveItemObj(val, keystart, objCopy[key]) + elif key.startswith(keystart): + objCopy.pop(key, None) + else: + objCopy[key] = val + return objCopy + + ############################################################################### ### Rename objects ############################################################################### @@ -580,7 +617,6 @@ def rename (obj, old, new, label=None): return False - ############################################################################### ### Recursively remove items of an object (used to avoid mem leaks) ############################################################################### @@ -779,7 +815,6 @@ def readCmdLineArgs (simConfigDefault='cfg.py', netParamsDefault='netParams.py') if len(sys.argv) > 1: print('\nReading command line arguments using syntax: python file.py [simConfig=filepath] [netParams=filepath]') - cfgPath = None netParamsPath = None @@ -1421,7 +1456,8 @@ def gatherData (gatherLFP = True): if 'runTime' in sim.timingData: print((' Spikes: %i (%0.2f Hz)' % (sim.totalSpikes, sim.firingRate))) if sim.cfg.printPopAvgRates and not sim.cfg.gatherOnlySimData: - sim.allSimData['popRates'] = sim.popAvgRates() + trange = sim.cfg.printPopAvgRates if isinstance(sim.cfg.printPopAvgRates,list) else None + 
sim.allSimData['popRates'] = sim.popAvgRates(trange=trange) print((' Simulated time: %0.1f s; %i workers' % (sim.cfg.duration/1e3, sim.nhosts))) print((' Run time: %0.2f s' % (sim.timingData['runTime']))) @@ -1533,7 +1569,7 @@ def distributedSaveHDF5(): sim.compactConnFormat() conns = [[cell.gid]+conn for cell in sim.net.cells for conn in cell.conns] - conns = sim.copyReplaceItemObj(conns, keystart='h', newval=[]) + conns = sim.copyRemoveItemObj(conns, keystart='h') connFormat = ['postGid']+sim.cfg.compactConnFormat with h5py.File(sim.cfg.filename+'.h5', 'w') as hf: hf.create_dataset('conns', data = conns) @@ -1562,13 +1598,15 @@ def loadHDF5(filename): ############################################################################### ### Save data ############################################################################### -def saveData (include = None): +def saveData (include = None, filename = None): from . import sim if sim.rank == 0 and not getattr(sim.net, 'allCells', None): needGather = True else: needGather = False if needGather: gatherData() + if filename: sim.cfg.filename = filename + if sim.rank == 0: timing('start', 'saveTime') import os @@ -1610,7 +1648,9 @@ def saveData (include = None): dataSave['netpyne_changeset'] = sim.gitChangeset(show=False) if getattr(sim.net.params, 'version', None): dataSave['netParams_version'] = sim.net.params.version - if 'netParams' in include: net['params'] = replaceFuncObj(sim.net.params.__dict__) + if 'netParams' in include: + sim.net.params.__dict__.pop('_labelid', None) + net['params'] = replaceFuncObj(sim.net.params.__dict__) if 'net' in include: include.extend(['netPops', 'netCells']) if 'netCells' in include: net['cells'] = sim.net.allCells if 'netPops' in include: net['pops'] = sim.net.allPops @@ -1625,56 +1665,58 @@ def saveData (include = None): if dataSave: if sim.cfg.timestampFilename: timestamp = time() - timestampStr = datetime.fromtimestamp(timestamp).strftime('%Y%m%d_%H%M%S') - sim.cfg.filename 
= sim.cfg.filename+'-'+timestampStr - + timestampStr = '-' + datetime.fromtimestamp(timestamp).strftime('%Y%m%d_%H%M%S') + else: + timestampStr = '' + + filePath = sim.cfg.filename + timestampStr # Save to pickle file if sim.cfg.savePickle: import pickle dataSave = replaceDictODict(dataSave) - print(('Saving output as %s ... ' % (sim.cfg.filename+'.pkl'))) - with open(sim.cfg.filename+'.pkl', 'wb') as fileObj: + print(('Saving output as %s ... ' % (filePath+'.pkl'))) + with open(filePath+'.pkl', 'wb') as fileObj: pickle.dump(dataSave, fileObj) print('Finished saving!') # Save to dpk file if sim.cfg.saveDpk: import gzip - print(('Saving output as %s ... ' % (sim.cfg.filename+'.dpk'))) - fn=sim.cfg.filename #.split('.') - gzip.open(fn, 'wb').write(pk.dumps(dataSave)) # write compressed string + print(('Saving output as %s ... ' % (filePath+'.dpk'))) + #fn=filePath #.split('.') + gzip.open(filePath, 'wb').write(pk.dumps(dataSave)) # write compressed string print('Finished saving!') # Save to json file if sim.cfg.saveJson: import json #dataSave = replaceDictODict(dataSave) # not required since json saves as dict - print(('Saving output as %s ... ' % (sim.cfg.filename+'.json '))) - with open(sim.cfg.filename+'.json', 'w') as fileObj: + print(('Saving output as %s ... ' % (filePath+'.json '))) + with open(filePath+'.json', 'w') as fileObj: json.dump(dataSave, fileObj) print('Finished saving!') # Save to mat file if sim.cfg.saveMat: from scipy.io import savemat - print(('Saving output as %s ... ' % (sim.cfg.filename+'.mat'))) - savemat(sim.cfg.filename+'.mat', tupleToList(replaceNoneObj(dataSave))) # replace None and {} with [] so can save in .mat format + print(('Saving output as %s ... 
' % (filePath+'.mat'))) + savemat(filePath+'.mat', tupleToList(replaceNoneObj(dataSave))) # replace None and {} with [] so can save in .mat format print('Finished saving!') # Save to HDF5 file (uses very inefficient hdf5storage module which supports dicts) if sim.cfg.saveHDF5: dataSaveUTF8 = _dict2utf8(replaceNoneObj(dataSave)) # replace None and {} with [], and convert to utf import hdf5storage - print(('Saving output as %s... ' % (sim.cfg.filename+'.hdf5'))) - hdf5storage.writes(dataSaveUTF8, filename=sim.cfg.filename+'.hdf5') + print(('Saving output as %s... ' % (filePath+'.hdf5'))) + hdf5storage.writes(dataSaveUTF8, filename=filePath+'.hdf5') print('Finished saving!') # Save to CSV file (currently only saves spikes) if sim.cfg.saveCSV: if 'simData' in dataSave: import csv - print(('Saving output as %s ... ' % (sim.cfg.filename+'.csv'))) - writer = csv.writer(open(sim.cfg.filename+'.csv', 'wb')) + print(('Saving output as %s ... ' % (filePath+'.csv'))) + writer = csv.writer(open(filePath+'.csv', 'wb')) for dic in dataSave['simData']: for values in dic: writer.writerow(values) @@ -1710,7 +1752,7 @@ def saveData (include = None): # return full path import os - return os.getcwd()+'/'+sim.cfg.filename + return os.getcwd() + '/' + filePath else: print('Nothing to save') diff --git a/netpyne/specs.py b/netpyne/specs.py index c6808dd5d..70d995c41 100644 --- a/netpyne/specs.py +++ b/netpyne/specs.py @@ -424,7 +424,7 @@ def __init__(self, netParamsDict=None): ## General connectivity parameters self.scaleConnWeight = 1 # Connection weight scale factor (NetStims not included) self.scaleConnWeightNetStims = 1 # Connection weight scale factor for NetStims - self.scaleConnWeightModels = {} # Connection weight scale factor for each cell model eg. {'Izhi2007': 0.1, 'Friesen': 0.02} + self.scaleConnWeightModels = False # Connection weight scale factor for each cell model eg. 
{'Izhi2007': 0.1, 'Friesen': 0.02} self.defaultWeight = 1 # default connection weight self.defaultDelay = 1 # default connection delay (ms) self.defaultThreshold = 10 # default Netcon threshold (mV) diff --git a/netpyne/utils.py b/netpyne/utils.py index 9b841afc1..6e8300bb3 100644 --- a/netpyne/utils.py +++ b/netpyne/utils.py @@ -5,7 +5,7 @@ Contributors: salvador dura@gmail.com """ -import os, sys +import os, sys, signal from numbers import Number from neuron import h import importlib @@ -502,7 +502,7 @@ def importConnFromExcel (fileName, sheetName): def ValidateFunction(strFunc, netParamsVars): ''' returns True if "strFunc" can be evaluated''' - + from math import exp, log, sqrt, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, pi, e rand = h.Random() stringFuncRandMethods = ['binomial', 'discunif', 'erlang', 'geometric', 'hypergeo', 'lognormal', 'negexp', 'normal', 'poisson', 'uniform', 'weibull'] @@ -530,3 +530,118 @@ def ValidateFunction(strFunc, netParamsVars): return True except: return False + +def createScript(fname, netParams, simConfig): + import sys + import json + from netpyne import specs + + def replace(string): + # convert bools and null from json to python + return string.replace('true', 'True').replace('false', 'False').replace('null', '""') + + def remove(dictionary): + # remove reserved keys such as __str__, __dict__ + if isinstance(dictionary, dict): + for key, value in list(dictionary.items()): + if key.startswith('__'): + dictionary.pop(key) + else: + remove(value) + + def addAttrToScript(attr, value, obj_name, class_instance, file): + # write line of netpyne code if is different from default value + if not hasattr(class_instance, attr) or value!=getattr(class_instance, attr): + file.write(obj_name + '.' 
+ attr + ' = ' + replace(json.dumps(value, indent=4)) + '\n') + + def header(title, spacer='-'): + # writes a header for the section + return '\n# ' + title.upper() + ' ' + spacer*(77-len(title)) + '\n' + + if isinstance(netParams, specs.NetParams): + # convert netpyne.specs.netParams class to dict class + netParams = netParams.todict() + if isinstance(simConfig, specs.SimConfig): + simConfig = simConfig.todict() + + # remove reserved keys like __str__, __dict__ + remove(netParams) + remove(simConfig) + + # network parameters + params = ['popParams' , 'cellParams', 'synMechParams'] + params += ['connParams', 'stimSourceParams', 'stimTargetParams'] + + try : + with open(fname if fname.endswith('.py') else fname+'.py', 'w') as file: + file.write('from netpyne import specs, sim\n') + file.write(header('documentation')) + file.write("''' Please visit: https://www.netpyne.org '''\n") + file.write("# Python script automatically generated by NetPyNE from netParams and simConfig objects\n") + file.write(header('script', spacer='=')) + file.write('netParams = specs.NetParams()\n') + file.write('simConfig = specs.SimConfig()\n') + + file.write(header('single valued attributes')) + for key, value in list(netParams.items()): + if key not in params: + addAttrToScript(key, value, 'netParams', specs.NetParams(), file) + + file.write(header('network attributes')) + for param in params: + for key, value in list(netParams[param].items()): + file.write("netParams." 
+ param + "['" + key + "'] = " + replace(json.dumps(value, indent=4))+ '\n') + + file.write(header('network configuration')) + for key, value in list(simConfig.items()): + addAttrToScript(key, value, 'simConfig', specs.SimConfig(), file) + + file.write(header('create simulate analyze network')) + file.write('sim.createSimulateAnalyze(netParams=netParams, simConfig=simConfig)\n') + file.write(header('end script', spacer='=')) + + print(("script saved on " + fname)) + + except: + print(('error saving file: %s' %(sys.exc_info()[1]))) + +def bashTemplate(template): + ''' return the bash commands required by template for batch simulation''' + + if template=='mpi_direct': + return """#!/bin/bash +%s +cd %s +%s + """ + elif template=='hpc_slurm': + return """#!/bin/bash +#SBATCH --job-name=%s +#SBATCH -A %s +#SBATCH -t %s +#SBATCH --nodes=%d +#SBATCH --ntasks-per-node=%d +#SBATCH -o %s.run +#SBATCH -e %s.err +#SBATCH --mail-user=%s +#SBATCH --mail-type=end +%s +%s +source ~/.bashrc +cd %s +%s +wait + """ + elif template=='hpc_torque': + return """#!/bin/bash +#PBS -N %s +#PBS -l walltime=%s +#PBS -q %s +#PBS -l %s +#PBS -o %s.run +#PBS -e %s.err +%s +cd $PBS_O_WORKDIR +echo $PBS_O_WORKDIR +%s + """ \ No newline at end of file diff --git a/py2to3.py b/py2to3.py index 4128f7b16..1d38a6e9f 100644 --- a/py2to3.py +++ b/py2to3.py @@ -2,7 +2,7 @@ from subprocess import call py2_root = '../netpyne_temp' -py2_branch = 'master' +py2_branch = 'development' folders = ['netpyne', 'doc', 'examples'] files = ['CHANGES.md', 'README.md', 'sdnotes.org', '.gitignore'] diff --git a/sdnotes.org b/sdnotes.org index 2f7a35679..5a8182778 100644 --- a/sdnotes.org +++ b/sdnotes.org @@ -7614,7 +7614,7 @@ def genLFP (lrec,lx,ly,lz,elx,ely,elz): vlfp.mul(1.0/(4.0*pi*sigma)) return vlfp -* 18dec02 Implementing RxD +* 17Dec02 Implementing RxD ** celParams data model (inside cellParams) - cellParams['rule1'] (dict) -- ['rxd'] (dict)