From 218dbaa10da00aa17f527ed31ef5e4cc0f8f0246 Mon Sep 17 00:00:00 2001 From: Peter C Petersen Date: Wed, 7 Mar 2018 18:27:39 -0500 Subject: [PATCH 01/35] Create README.md --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..b09873e --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ +# KilosortWrapper +Allows you to load Kilosort from a .xml and a .dat file compatible with Neurosuite + +Settings + + +Features +Skip channels: To skip dead channels, select the skip function in Neuroscope or NDManager +Define probe layouts: The wrapper now supports staggered probes and poly 3 and poly 5 probe layouts... + From 21bdbce476107c5b22e8cfe39747dc6dd72d30af Mon Sep 17 00:00:00 2001 From: Peter C Petersen Date: Wed, 7 Mar 2018 18:33:54 -0500 Subject: [PATCH 02/35] Update README.md --- README.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b09873e..bc4d8fc 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,16 @@ # KilosortWrapper Allows you to load Kilosort from a .xml and a .dat file compatible with Neurosuite -Settings +## Settings +Settings are defined in the StandardSettings file - -Features +## Features Skip channels: To skip dead channels, select the skip function in Neuroscope or NDManager -Define probe layouts: The wrapper now supports staggered probes and poly 3 and poly 5 probe layouts... +Define probe layouts: The wrapper now supports probes with staggered, poly 3 and poly 5 probe layouts... +Allows you to save the output from Kilosort to a sub directory. + +## Outputs +The Kilosort wrapper allows you to save the output in Neurosuite compatible files or for Phy. +### Phy +Saved a channel groups file with information about which shanks the channels are asigned to. From f9ab1d2610380f94b64e170ab71df4e827b7df12 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Fri, 9 Mar 2018 17:30:33 -0500 Subject: [PATCH 03/35] Support for multiple configuration files. SSD drive defined in the wrapper instead of in the config file A new folder for configuration files now exist and you can specificy which config file to use with a third input to the KilosortWrapper --- .../KilosortConfiguration_Omid.m | 8 +- KiloSortWrapper.m | 43 ++++++++-- KilosortConfiguration.m | 82 +++++++++++++++++++ createChannelMapFile_KSW.m | 4 +- 4 files changed, 122 insertions(+), 15 deletions(-) rename StandardConfig_KSW.m => ConfigurationFiles/KilosortConfiguration_Omid.m (95%) mode change 100755 => 100644 create mode 100644 KilosortConfiguration.m diff --git a/StandardConfig_KSW.m b/ConfigurationFiles/KilosortConfiguration_Omid.m old mode 100755 new mode 100644 similarity index 95% rename from StandardConfig_KSW.m rename to ConfigurationFiles/KilosortConfiguration_Omid.m index 9481262..193e277 --- a/StandardConfig_KSW.m +++ b/ConfigurationFiles/KilosortConfiguration_Omid.m @@ -36,7 +36,7 @@ % end ops.nt0 = round(1.6*ops.fs/1000); % window width in samples. 1.6ms at 20kH corresponds to 32 samples -ops.nNeighPC = min([12 ops.Nchan]); % visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12) +ops.nNeighPC = min([16 ops.Nchan]); % visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12) ops.nNeigh = min([16 ops.Nchan]); % visualization only (Phy): number of neighboring templates to retain projections of (16) % options for channel whitening @@ -52,7 +52,7 @@ ops.Nrank = 3; % matrix rank of spike template model (3) ops.nfullpasses = 6; % number of complete passes through data during optimization (6) ops.maxFR = 40000; % maximum number of spikes to extract per batch (20000) -ops.fshigh = 300; % frequency for high pass filtering +ops.fshigh = 500; % frequency for high pass filtering ops.fslow = 8000; % frequency for low pass filtering (optional) ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection ops.scaleproc = 200; % int16 scaling of whitened data @@ -61,7 +61,7 @@ % the following options can improve/deteriorate results. % when multiple values are provided for an option, the first two are beginning and ending anneal values, % the third is the value used in the final pass. -ops.Th = [6 12 12]; % threshold for detecting spikes on template-filtered data ([6 12 12]) +ops.Th = [6 10 10]; % threshold for detecting spikes on template-filtered data ([6 12 12]) ops.lam = [12 40 40]; % large means amplitudes are forced around the mean ([10 30 30]) ops.nannealpasses = 4; % should be less than nfullpasses (4) ops.momentum = 1./[20 800]; % start with high momentum and anneal (1./[20 1000]) @@ -71,7 +71,7 @@ % options for initializing spikes from data ops.initialize = 'fromData'; %'fromData' or 'no' -ops.spkTh = -6; % spike threshold in standard deviations (4) +ops.spkTh = -5; % spike threshold in standard deviations (4) ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) ops.maskMaxChannels = 8; % how many channels to mask up/down ([5]) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 32ca34e..f94f11e 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -1,4 +1,4 @@ -function savepath = KiloSortWrapper(basepath,basename) +function savepath = KiloSortWrapper(basepath,basename,config_version) % Creates channel map from Neuroscope xml files, runs KiloSort and % writes output data in the Neuroscope/Klusters format. % StandardConfig.m should be in the path or copied to the local folder @@ -30,12 +30,21 @@ % addpath(genpath('gitrepositories/npy-matlab')) % path to npy-matlab scripts %% If function is called without argument -if nargin == 0 - [~,basename] = fileparts(cd); - basepath = cd; -elseif nargin == 1 - [~,basename] = fileparts(basepath); - basepath = cd; +switch nargin + case 0 + [~,basename] = fileparts(cd); + basepath = cd; + case 1 + [~,basename] = fileparts(basepath); + basepath = cd; + case 2 + [~,basename] = fileparts(basepath); + basepath = cd; + case 3 + if isempty(basepath) + [~,basename] = fileparts(cd); + basepath = cd; + end end cd(basepath) @@ -43,12 +52,28 @@ disp('Creating ChannelMapFile') createChannelMapFile_KSW(basepath,'staggered'); -%% default options are in parenthesis after the comment +%% Loading configurations XMLFilePath = fullfile(basepath, [basename '.xml']); % if exist(fullfile(basepath,'StandardConfig.m'),'file') %this should actually be unnecessary % addpath(basepath); % end -ops = StandardConfig_KSW(XMLFilePath); +if nargin < 3 + disp('Running Kilosort with standard settings') + ops = KilosortConfiguration(XMLFilePath); +else + disp('Running Kilosort with user specific settings') + config_string = str2func(['KilosortConfiguration_' config_version]); + ops = config_string(XMLFilePath); + clear config_string; +end + +%% % Defining SSD location if any +if isdir('G:\Kilosort') + disp('Creating a temporary dat file on the SSD drive') + ops.fproc = ['G:\Kilosort\temp_wh.dat']; +else + ops.fproc = fullfile(rootpath,'temp_wh.dat'); +end %% if ops.GPU diff --git a/KilosortConfiguration.m b/KilosortConfiguration.m new file mode 100644 index 0000000..3069df0 --- /dev/null +++ b/KilosortConfiguration.m @@ -0,0 +1,82 @@ +function ops = KilosortConfiguration(XMLfile) + +% Loads xml parameters (Neuroscope) +xml = LoadXml(XMLfile); +% Define rootpath +rootpath = fileparts(XMLfile); + +ops.GPU = 1; % whether to run this code on an Nvidia GPU (much faster, mexGPUall first) +ops.parfor = 1; % whether to use parfor to accelerate some parts of the algorithm +ops.verbose = 1; % whether to print command line progress +ops.showfigures = 0; % whether to plot figures during optimization +ops.datatype = 'dat'; % binary ('dat', 'bin') or 'openEphys' +ops.fbinary = [XMLfile(1:end-3) 'dat']; % will be created for 'openEphys' + +ops.root = rootpath; % 'openEphys' only: where raw files are +ops.fs = xml.SampleRate; % sampling rate + +load(fullfile(rootpath,'chanMap.mat')) +ops.NchanTOT = length(connected); % total number of channels + +ops.Nchan = sum(connected>1e-6); % number of active channels + +templatemultiplier = 8; +ops.Nfilt = ops.Nchan*templatemultiplier - mod(ops.Nchan*templatemultiplier,32); % number of filters to use (2-4 times more than Nchan, should be a multiple of 32) +% if ops.Nfilt > 2024; +% ops.Nfilt = 2024; +% elseif ops.Nfilt == 0 +% ops.Nfilt = 32; +% end +ops.nt0 = round(1.6*ops.fs/1000); % window width in samples. 1.6ms at 20kH corresponds to 32 samples + +ops.nNeighPC = min([16 ops.Nchan]); % visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12) +ops.nNeigh = min([16 ops.Nchan]); % visualization only (Phy): number of neighboring templates to retain projections of (16) + +% options for channel whitening +ops.whitening = 'full'; % type of whitening (default 'full', for 'noSpikes' set options for spike detection below) +ops.nSkipCov = 1; % compute whitening matrix from every N-th batch (1) +ops.whiteningRange = min([64 ops.Nchan]); % how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32) + +% define the channel map as a filename (string) or simply an array +ops.chanMap = fullfile(rootpath,'chanMap.mat'); % make this file using createChannelMapFile.m +ops.criterionNoiseChannels = 0.00001; % fraction of "noise" templates allowed to span all channel groups (see createChannelMapFile for more info). + +% other options for controlling the model and optimization +ops.Nrank = 3; % matrix rank of spike template model (3) +ops.nfullpasses = 6; % number of complete passes through data during optimization (6) +ops.maxFR = 40000; % maximum number of spikes to extract per batch (20000) +ops.fshigh = 500; % frequency for high pass filtering +ops.fslow = 8000; % frequency for low pass filtering (optional) +ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection +ops.scaleproc = 200; % int16 scaling of whitened data +ops.NT = 4*32*1028+ ops.ntbuff;% this is the batch size (try decreasing if out of memory) for GPU should be multiple of 32 + ntbuff + +% the following options can improve/deteriorate results. +% when multiple values are provided for an option, the first two are beginning and ending anneal values, +% the third is the value used in the final pass. +ops.Th = [6 10 10]; % threshold for detecting spikes on template-filtered data ([6 12 12]) +ops.lam = [12 40 40]; % large means amplitudes are forced around the mean ([10 30 30]) +ops.nannealpasses = 4; % should be less than nfullpasses (4) +ops.momentum = 1./[20 800]; % start with high momentum and anneal (1./[20 1000]) +ops.shuffle_clusters = 1; % allow merges and splits during optimization (1) +ops.mergeT = .1; % upper threshold for merging (.1) +ops.splitT = .1; % lower threshold for splitting (.1) + +% options for initializing spikes from data +ops.initialize = 'fromData'; %'fromData' or 'no' +ops.spkTh = -5; % spike threshold in standard deviations (4) +ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) +ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) +ops.maskMaxChannels = 8; % how many channels to mask up/down ([5]) +ops.crit = .65; % upper criterion for discarding spike repeates (0.65) +ops.nFiltMax = 80000; % maximum "unique" spikes to consider (10000) + +% load predefined principal components (visualization only (Phy): used for features) +dd = load('PCspikes2.mat'); % you might want to recompute this from your own data +ops.wPCA = dd.Wi(:,1:7); % PCs + +% options for posthoc merges (under construction) +ops.fracse = 0.1; % binning step along discriminant axis for posthoc merges (in units of sd) +ops.epu = Inf; +ops.ForceMaxRAMforDat = 15000000000; % maximum RAM the algorithm will try to use; on Windows it will autodetect. +end diff --git a/createChannelMapFile_KSW.m b/createChannelMapFile_KSW.m index 2d11045..77ed426 100644 --- a/createChannelMapFile_KSW.m +++ b/createChannelMapFile_KSW.m @@ -85,9 +85,9 @@ function createChannelMapFile_Local(basepath,electrode_type) x(1:extrachannels) = 18*(-1).^[1:extrachannels]; y(find(x == 2*18)) = [1:length(find(x == 2*18))]*-28; - y(find(x == 18)) = [1:length(find(x == 18))]*-28+14; + y(find(x == 18)) = [1:length(find(x == 18))]*-28-14; y(find(x == 0)) = [1:length(find(x == 0))]*-28; - y(find(x == -18)) = [1:length(find(x == -18))]*-28+14; + y(find(x == -18)) = [1:length(find(x == -18))]*-28-14; y(find(x == 2*-18)) = [1:length(find(x == 2*-18))]*-28; x = x+a*200; From ac4c964e86c6c057a6e239c1818144c535893e7d Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Fri, 9 Mar 2018 17:42:51 -0500 Subject: [PATCH 04/35] Cleaning up outdated dependencies --- createChannelMapFile.m | 55 --------- fitTemplates.m | 244 ------------------------------------ fullMPMU.m | 273 ----------------------------------------- master_file.m | 55 --------- npy-matlab | 1 - 5 files changed, 628 deletions(-) delete mode 100755 createChannelMapFile.m delete mode 100755 fitTemplates.m delete mode 100755 fullMPMU.m delete mode 100755 master_file.m delete mode 160000 npy-matlab diff --git a/createChannelMapFile.m b/createChannelMapFile.m deleted file mode 100755 index e49144d..0000000 --- a/createChannelMapFile.m +++ /dev/null @@ -1,55 +0,0 @@ -function createChannelMapFile(basepath,basename) -% create a channel map file - -if ~exist('basepath','var') - basepath = cd; -end - - - - -XMLfile = [basepath '/' basename '.xml']; - - [par, rxml] = LoadXml(XMLfile); - -% add bad channel-handling - -Nchannels = par.nChannels; - -connected = true(Nchannels, 1); -chanMap = 1:Nchannels; -chanMap0ind = chanMap - 1; - -% xcoords = ones(Nchannels,1); -% ycoords = [1:Nchannels]'; -xcoords = []; -ycoords = []; -for a= 1:length(par.AnatGrps)%being super lazy and making this map with loops - x = []; - y = []; - tchannels = par.AnatGrps(a).Channels; - for i =1:length(tchannels) -% if ~ismember(tchannels(i),badchannels) - x(i) = length(tchannels)-i; - y(i) = -i*1; - if mod(i,2) - x(i) = -x(i); - end -% end - end - x = x+a*100; - xcoords = cat(1,xcoords,x(:)); - ycoords = cat(1,ycoords,y(:)); -end - -kcoords = zeros(Nchannels,1); -for a= 1:length(par.AnatGrps) - kcoords(par.AnatGrps(a).Channels+1) = a; -end - - -save(fullfile(basepath,'chanMap.mat'), ... - 'chanMap','connected', 'xcoords', 'ycoords', 'kcoords', 'chanMap0ind') - -%% -% \ No newline at end of file diff --git a/fitTemplates.m b/fitTemplates.m deleted file mode 100755 index 1797b13..0000000 --- a/fitTemplates.m +++ /dev/null @@ -1,244 +0,0 @@ -function rez = fitTemplates(rez, DATA, uproj) - -nt0 = rez.ops.nt0; -rez.ops.nt0min = ceil(20 * nt0/61); - -ops = rez.ops; - -rng('default'); -rng(1); - -Nbatch = rez.temp.Nbatch; -Nbatch_buff = rez.temp.Nbatch_buff; - -Nfilt = ops.Nfilt; %256+128; - -ntbuff = ops.ntbuff; -NT = ops.NT; - -Nrank = ops.Nrank; -Th = ops.Th; -maxFR = ops.maxFR; - -Nchan = ops.Nchan; - -batchstart = 0:NT:NT*(Nbatch-Nbatch_buff); - -delta = NaN * ones(Nbatch, 1); -iperm = randperm(Nbatch); - -switch ops.initialize - case 'fromData' - WUinit = optimizePeaks(ops,uproj);%does a scaled kmeans - dWU = WUinit(:,:,1:Nfilt); - % dWU = alignWU(dWU); - otherwise - initialize_waves0; - ipck = randperm(size(Winit,2), Nfilt); - W = []; - U = []; - for i = 1:Nrank - W = cat(3, W, Winit(:, ipck)/Nrank); - U = cat(3, U, Uinit(:, ipck)); - end - W = alignW(W, ops); - - dWU = zeros(nt0, Nchan, Nfilt, 'single'); - for k = 1:Nfilt - Ut = permute(U(:,k,:),[1 3 2]); - wu = squeeze(W(:,k,:)) * Ut'; - newnorm = sum(wu(:).^2).^.5; - W(:,k,:) = W(:,k,:)/newnorm; - - dWU(:,:,k) = 10 * wu; - end - WUinit = dWU; -end -[W, U, mu, UtU, nu] = decompose_dWU(ops, dWU, Nrank, rez.ops.kcoords); -W0 = W; -W0(NT, 1) = 0; -fW = fft(W0, [], 1); -fW = conj(fW); - -nspikes = zeros(Nfilt, Nbatch); -lam = ones(Nfilt, 1, 'single'); - -freqUpdate = 100 * 4; -iUpdate = 1:freqUpdate:Nbatch; - - -dbins = zeros(100, Nfilt); -dsum = 0; -miniorder = repmat(iperm, 1, ops.nfullpasses); -% miniorder = repmat([1:Nbatch Nbatch:-1:1], 1, ops.nfullpasses/2); - -i = 1; % first iteration - -epu = ops.epu; - - -%% -% pmi = exp(-1./exp(linspace(log(ops.momentum(1)), log(ops.momentum(2)), Nbatch*ops.nannealpasses))); -pmi = exp(-1./linspace(1/ops.momentum(1), 1/ops.momentum(2), Nbatch*ops.nannealpasses)); -% pmi = exp(-linspace(ops.momentum(1), ops.momentum(2), Nbatch*ops.nannealpasses)); - -% pmi = linspace(ops.momentum(1), ops.momentum(2), Nbatch*ops.nannealpasses); -Thi = linspace(ops.Th(1), ops.Th(2), Nbatch*ops.nannealpasses); -if ops.lam(1)==0 - lami = linspace(ops.lam(1), ops.lam(2), Nbatch*ops.nannealpasses); -else - lami = exp(linspace(log(ops.lam(1)), log(ops.lam(2)), Nbatch*ops.nannealpasses)); -end - -if Nbatch_buff1 && ismember(rem(i,Nbatch), iUpdate) %&& i>Nbatch - dWU = gather_try(dWU); - - % break bimodal clusters and remove low variance clusters - if ops.shuffle_clusters &&... - i>Nbatch && rem(rem(i,Nbatch), 4*400)==1 % iNfilt; - j = Nfilt -9; - end - plot(log(1+NSP(j + [0:1:9])), mu(j+ [0:1:9]), 'o'); - xlabel('log of number of spikes') - ylabel('amplitude of template') - hold all - end - axis tight; - title(sprintf('%d ', nswitch)); - subplot(2,2,2) - plot(W(:,:,1)) - title('timecourses of top PC') - - subplot(2,2,3) - imagesc(U(:,:,1)) - title('spatial mask of top PC') - - drawnow - end - % break if last iteration reached - if i>Nbatch * ops.nfullpasses; break; end - - % record the error function for this iteration - rez.errall(ceil(i/freqUpdate)) = nanmean(delta); - - end - - % select batch and load from RAM or disk - ibatch = miniorder(i); - if ibatch>Nbatch_buff - offset = 2 * ops.Nchan*batchstart(ibatch-Nbatch_buff); - fseek(fid, offset, 'bof'); - dat = fread(fid, [NT ops.Nchan], '*int16'); - else - dat = DATA(:,:,Nbatch_buff); - end - - % move data to GPU and scale it - if ops.GPU - dataRAW = gpuArray(dat); - else - dataRAW = dat; - end - dataRAW = single(dataRAW); - dataRAW = dataRAW / ops.scaleproc; - - % project data in low-dim space - data = dataRAW * U(:,:); - - if ops.GPU - % run GPU code to get spike times and coefficients - [dWU, ~, id, x,Cost, nsp] = ... - mexMPregMU(Params,dataRAW,W,data,UtU,mu, lam .* (20./mu).^2, dWU, nu); - else - [dWU, ~, id, x,Cost, nsp] = ... - mexMPregMUcpu(Params,dataRAW,fW,data,UtU,mu, lam .* (20./mu).^2, dWU, nu, ops); - end - - dbins = .9975 * dbins; % this is a hard-coded forgetting factor, needs to become an option - if ~isempty(id) - % compute numbers of spikes - nsp = gather_try(nsp(:)); - nspikes(:, ibatch) = nsp; - - % bin the amplitudes of the spikes - xround = min(max(1, int32(x)), 100); - - dbins(xround + id * size(dbins,1)) = dbins(xround + id * size(dbins,1)) + 1; - - % estimate cost function at this time step - delta(ibatch) = sum(Cost)/1e3; - end - - % update status - if ops.verbose && rem(i,20)==1 - nsort = sort(round(sum(nspikes,2)), 'descend'); - fprintf(repmat('\b', 1, numel(msg))); - msg = sprintf('Time %2.2f, batch %d/%d, mu %2.2f, neg-err %2.6f, NTOT %d, n100 %d, n200 %d, n300 %d, n400 %d\n', ... - toc, i,Nbatch* ops.nfullpasses,nanmean(mu(:)), nanmean(delta), round(sum(nsort)), ... - nsort(min(size(W,2), 100)), nsort(min(size(W,2), 200)), ... - nsort(min(size(W,2), 300)), nsort(min(size(W,2), 400))); - fprintf(msg); - end - - % increase iteration counter - i = i+1; -end - -% close the data file if it has been used -if Nbatch_buff100); - cr = mWtW .* (vld * vld'); - cr(isnan(cr)) = 0; - [~, iNgsort] = sort(cr, 1, 'descend'); - - % save full similarity score - rez.simScore = cr; - maskTT = zeros(Nfilt, 'single'); - rez.iNeigh = iNgsort(1:nNeigh, :); - for i = 1:Nfilt - maskTT(rez.iNeigh(:,i),i) = 1; - end -end -if ~isempty(ops.nNeighPC) - nNeighPC = ops.nNeighPC; - load PCspikes - ixt = round(linspace(1, size(Wi,1), ops.nt0)); - Wi = Wi(ixt, 1:3); - rez.cProjPC = zeros(5e6, 3*nNeighPC, 'single'); - - % sort best channels - [~, iNch] = sort(abs(U(:,:,1)), 1, 'descend'); - maskPC = zeros(Nchan, Nfilt, 'single'); - rez.iNeighPC = iNch(1:nNeighPC, :); - for i = 1:Nfilt - maskPC(rez.iNeighPC(:,i),i) = 1; - end - maskPC = repmat(maskPC, 3, 1); -end - -irun = 0; -i1nt0 = int32([1:nt0])'; -%% -LAM = lam .* (20./mu).^2; - -NT = ops.NT; -batchstart = 0:NT:NT*(Nbatch-Nbatch_buff); - -for ibatch = 1:Nbatch - if ibatch>Nbatch_buff - offset = 2 * ops.Nchan*batchstart(ibatch-Nbatch_buff); % - ioffset; - fseek(fid, offset, 'bof'); - dat = fread(fid, [NT ops.Nchan], '*int16'); - else - dat = DATA(:,:,Nbatch_buff); - end - if ops.GPU - dataRAW = gpuArray(dat); - else - dataRAW = dat; - end - dataRAW = single(dataRAW); - dataRAW = dataRAW / ops.scaleproc; - - % project data in low-dim space - if ops.GPU - data = gpuArray.zeros(NT, Nfilt, Nrank, 'single'); - else - data = zeros(NT, Nfilt, Nrank, 'single'); - end - for irank = 1:Nrank - data(:,:,irank) = dataRAW * U(:,:,irank); - end - data = reshape(data, NT, Nfilt*Nrank); - - if ops.GPU - [st, id, x, errC, PCproj] ... - = mexMPmuFEAT(Params,data,W,WtW, mu, lam .* (20./mu).^2, nu); - else - [st, id, x, errC, PCproj]= cpuMPmuFEAT(Params,data,fW,WtW, mu, lam .* (20./mu).^2, nu, ops); - end - - if ~isempty(st) - if ~isempty(ops.nNeighPC) - % PCA coefficients - inds = repmat(st', nt0, 1) + repmat(i1nt0, 1, numel(st)); - try datSp = dataRAW(inds(:), :); - catch - datSp = dataRAW(inds(:), :); - end - datSp = reshape(datSp, [size(inds) Nchan]); - coefs = reshape(Wi' * reshape(datSp, nt0, []), size(Wi,2), numel(st), Nchan); - coefs = reshape(permute(coefs, [3 1 2]), [], numel(st)); - coefs = coefs .* maskPC(:, id+1); - iCoefs = reshape(find(maskPC(:, id+1)>0), 3*nNeighPC, []); - rez.cProjPC(irun + (1:numel(st)), :) = gather(coefs(iCoefs)'); - end - if ~isempty(ops.nNeigh) - % template coefficients - % transform coefficients - PCproj = bsxfun(@rdivide, ... - bsxfun(@plus, PCproj, LAM.*mu), sqrt(1+LAM)); - - PCproj = maskTT(:, id+1) .* PCproj; - iPP = reshape(find(maskTT(:, id+1)>0), nNeigh, []); - rez.cProj(irun + (1:numel(st)), :) = PCproj(iPP)'; - end - % increment number of spikes - irun = irun + numel(st); - - if ibatch==1; - ioffset = 0; - else - ioffset = ops.ntbuff; - end - st = st - ioffset; - - % nspikes2(1:size(W,2)+1, ibatch) = histc(id, 0:1:size(W,2)); - STT = cat(2, ops.nt0min + double(st) +(NT-ops.ntbuff)*(ibatch-1), ... - double(id)+1, double(x), ibatch*ones(numel(x),1)); - st3 = cat(1, st3, STT); - end - if rem(ibatch,100)==1 -% nsort = sort(sum(nspikes2,2), 'descend'); - fprintf(repmat('\b', 1, numel(msg))); - msg = sprintf('Time %2.2f, batch %d/%d, NTOT %d\n', ... - toc, ibatch,Nbatch, size(st3,1)); - fprintf(msg); - - end -end -%% -[~, isort] = sort(st3(:,1), 'ascend'); -st3 = st3(isort,:); - -rez.st3 = st3; -if ~isempty(ops.nNeighPC) - % re-sort coefficients for projections - rez.cProjPC(irun+1:end, :) = []; - rez.cProjPC = reshape(rez.cProjPC, size(rez.cProjPC,1), [], 3); - rez.cProjPC = rez.cProjPC(isort, :,:); - for ik = 1:Nfilt - iSp = rez.st3(:,2)==ik; - OneToN = 1:nNeighPC; - [~, isortNeigh] = sort(rez.iNeighPC(:,ik), 'ascend'); - OneToN(isortNeigh) = OneToN; - rez.cProjPC(iSp, :,:) = rez.cProjPC(iSp, OneToN, :); - end - - rez.cProjPC = permute(rez.cProjPC, [1 3 2]); -end -if ~isempty(ops.nNeigh) - rez.cProj(irun+1:end, :) = []; - rez.cProj = rez.cProj(isort, :); - - % re-index the template coefficients - for ik = 1:Nfilt - iSp = rez.st3(:,2)==ik; - OneToN = 1:nNeigh; - [~, isortNeigh] = sort(rez.iNeigh(:,ik), 'ascend'); - OneToN(isortNeigh) = OneToN; - rez.cProj(iSp, :) = rez.cProj(iSp, OneToN); - end -end - - -%% -% rez.ops = ops; -rez.W = W; -rez.U = U; -rez.mu = mu; - -rez.t2p = []; -for i = 1:Nfilt - wav0 = W(:,i,1); - wav0 = my_conv(wav0', .5)'; - [~, itrough] = min(wav0); - [~, t2p] = max(wav0(itrough:end)); - rez.t2p(i,1) = t2p; - rez.t2p(i,2) = itrough; -end - -rez.nbins = histc(rez.st3(:,2), .5:1:Nfilt+1); - -[~, rez.ypos] = max(rez.U(:,:,1), [], 1); -if Nbatch_buff Date: Fri, 30 Mar 2018 17:00:59 -0400 Subject: [PATCH 05/35] New conversion script: Kilosort2Neurosuite The exportion function now filters the dat before waveform extraction, and these waveforms are used to extract the PCAs. The standard config saves the xml structure to the Kilosort ops structure, which is inherited in the rez structure. Correcting the input to the KilosortWrapper --- .../KilosortConfiguration_Omid.m | 13 +- ConvertKilosort2Neurosuite_KSW.m | 2 +- KiloSortWrapper.m | 11 +- Kilosort2Neurosuite.m | 265 ++++++++++++++++++ KilosortConfiguration.m | 13 +- 5 files changed, 286 insertions(+), 18 deletions(-) create mode 100644 Kilosort2Neurosuite.m diff --git a/ConfigurationFiles/KilosortConfiguration_Omid.m b/ConfigurationFiles/KilosortConfiguration_Omid.m index 193e277..95e015d 100644 --- a/ConfigurationFiles/KilosortConfiguration_Omid.m +++ b/ConfigurationFiles/KilosortConfiguration_Omid.m @@ -61,8 +61,8 @@ % the following options can improve/deteriorate results. % when multiple values are provided for an option, the first two are beginning and ending anneal values, % the third is the value used in the final pass. -ops.Th = [6 10 10]; % threshold for detecting spikes on template-filtered data ([6 12 12]) -ops.lam = [12 40 40]; % large means amplitudes are forced around the mean ([10 30 30]) +ops.Th = [5 6 6]; % threshold for detecting spikes on template-filtered data ([6 12 12]) +ops.lam = [10 20 20]; % large means amplitudes are forced around the mean ([10 30 30]) ops.nannealpasses = 4; % should be less than nfullpasses (4) ops.momentum = 1./[20 800]; % start with high momentum and anneal (1./[20 1000]) ops.shuffle_clusters = 1; % allow merges and splits during optimization (1) @@ -70,13 +70,13 @@ ops.splitT = .1; % lower threshold for splitting (.1) % options for initializing spikes from data -ops.initialize = 'fromData'; %'fromData' or 'no' -ops.spkTh = -5; % spike threshold in standard deviations (4) +ops.initialize = 'no'; %'fromData' or 'no' +ops.spkTh = 4; % spike threshold in standard deviations (4) ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) ops.maskMaxChannels = 8; % how many channels to mask up/down ([5]) ops.crit = .65; % upper criterion for discarding spike repeates (0.65) -ops.nFiltMax = 80000; % maximum "unique" spikes to consider (10000) +ops.nFiltMax = 10000; % maximum "unique" spikes to consider (10000) % load predefined principal components (visualization only (Phy): used for features) dd = load('PCspikes2.mat'); % you might want to recompute this from your own data @@ -86,4 +86,7 @@ ops.fracse = 0.1; % binning step along discriminant axis for posthoc merges (in units of sd) ops.epu = Inf; ops.ForceMaxRAMforDat = 15000000000; % maximum RAM the algorithm will try to use; on Windows it will autodetect. + +% Saving xml content to ops strucuture +ops.xml = xml; end diff --git a/ConvertKilosort2Neurosuite_KSW.m b/ConvertKilosort2Neurosuite_KSW.m index 28f9bff..192b833 100755 --- a/ConvertKilosort2Neurosuite_KSW.m +++ b/ConvertKilosort2Neurosuite_KSW.m @@ -22,7 +22,7 @@ function ConvertKilosort2Neurosuite_KSW(rez) [~,basename] = fileparts(cd); basepath = cd; end -if ~exist('rez','var'); +if ~exist('rez','var') load(fullfile(basepath,'rez.mat')) end diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index f94f11e..2405eeb 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -36,10 +36,11 @@ basepath = cd; case 1 [~,basename] = fileparts(basepath); - basepath = cd; case 2 - [~,basename] = fileparts(basepath); - basepath = cd; + if isempty(basepath) + [~,basename] = fileparts(basepath); + basepath = cd; + end case 3 if isempty(basepath) [~,basename] = fileparts(cd); @@ -118,8 +119,8 @@ rezToPhy_KSW(rez); %% save python results file for Klusters -% disp('Converting to Klusters format') -% ConvertKilosort2Neurosuite_KSW(rez); +disp('Converting to Klusters format') +Kilosort2Neurosuite(rez) %% Remove temporary file delete(ops.fproc); diff --git a/Kilosort2Neurosuite.m b/Kilosort2Neurosuite.m new file mode 100644 index 0000000..bfa86d7 --- /dev/null +++ b/Kilosort2Neurosuite.m @@ -0,0 +1,265 @@ +function Kilosort2Neurosuite(rez) +% Converts KiloSort output (.rez structure) to Neurosuite files: fet,res,clu,spk files. +% Based on the GPU enable filter from Kilosort and fractions from Brendon +% Watson's code for saving Neurosuite files. +% +% 1) Waveforms are extracted from the dat file via GPU enabled filters. +% 2) Features are calculated in parfor loops. +% +% Inputs: +% rez - rez structure from Kilosort +% +% By Peter Petersen 2018 +% petersen.peter@gmail.com + +t1 = tic; +spikeTimes = uint64(rez.st3(:,1)); % uint64 +spikeTemplates = uint32(rez.st3(:,2)); % uint32 % template id for each spike +kcoords = rez.ops.kcoords; +basename = rez.ops.basename; + +Nchan = rez.ops.Nchan; +samples = rez.ops.nt0; + +templates = zeros(Nchan, size(rez.W,1), rez.ops.Nfilt, 'single'); +for iNN = 1:rez.ops.Nfilt + templates(:,:,iNN) = squeeze(rez.U(:,iNN,:)) * squeeze(rez.W(:,iNN,:))'; +end +amplitude_max_channel = []; +for i = 1:size(templates,3) + [~,amplitude_max_channel(i)] = max(range(templates(:,:,i)')); +end + +template_kcoords = kcoords(amplitude_max_channel); +kcoords2 = unique(template_kcoords); +ia = []; +for i = 1:length(kcoords2) + kcoords3 = kcoords2(i); + disp(['-Loading data for spike group ', num2str(kcoords3)]) + template_index = find(template_kcoords == kcoords3); + ia{i} = find(ismember(spikeTemplates,template_index)); +end + +rez.ia = ia; +toc(t1) +disp('Saving .clu files to disk (cluster indexes)') +for i = 1:length(kcoords2) + kcoords3 = kcoords2(i); + disp(['-Saving .clu file for group ', num2str(kcoords3)]) + tclu = spikeTemplates(ia{i}); + tclu = cat(1,length(unique(tclu)),double(tclu)); + cluname = fullfile([basename '.clu.' num2str(kcoords3)]); + fid=fopen(cluname,'w'); + fprintf(fid,'%.0f\n',tclu); + fclose(fid); + clear fid +end +toc(t1) + +disp('Saving .res files to disk (spike times)') +for i = 1:length(kcoords2) + kcoords3 = kcoords2(i); + tspktimes = spikeTimes(ia{i}); + disp(['-Saving .res file for group ', num2str(kcoords3)]) + resname = fullfile([basename '.res.' num2str(kcoords3)]); + fid=fopen(resname,'w'); + fprintf(fid,'%.0f\n',tspktimes); + fclose(fid); + clear fid +end +toc(t1) + +disp('Extracting waveforms') +waveforms_all = Kilosort_ExtractWaveforms(rez); +toc(t1) + +disp('Saving .spk files to disk (waveforms)') +for i = 1:length(kcoords2) + disp(['-Saving .spk for group ', num2str(kcoords2(i))]) + fid=fopen([basename,'.spk.',num2str(kcoords2(i))],'w'); + fwrite(fid,waveforms_all{i}(:),'int16'); + fclose(fid); +end +toc(t1) + +disp('Computing PCAs') +% Starting parpool if stated in the Kilosort settings +if (rez.ops.parfor & isempty(gcp('nocreate'))); parpool; end + +for i = 1:length(kcoords2) + kcoords3 = kcoords2(i); + disp(['-Computing PCAs for group ', num2str(kcoords3)]) + PCAs_global = zeros(3,sum(kcoords==kcoords3),length(ia{i})); + waveforms = waveforms_all{i}; + + waveforms2 = reshape(waveforms,[size(waveforms,1)*size(waveforms,2),size(waveforms,3)]); + wranges = int64(range(waveforms2,1)); + wpowers = int64(sum(waveforms2.^2,1)/size(waveforms2,1)/100); + + % Calculating PCAs in parallel if stated in ops.parfor + if isempty(gcp('nocreate')) + for k = 1:size(waveforms,1) + PCAs_global(:,k,:) = pca(zscore(permute(waveforms(k,:,:),[2,3,1]),[],2),'NumComponents',3)'; + end + else + parfor k = 1:size(waveforms,1) + PCAs_global(:,k,:) = pca(zscore(permute(waveforms(k,:,:),[2,3,1]),[],2),'NumComponents',3)'; + end + end + disp(['-Saving .fet files for group ', num2str(kcoords3)]) + PCAs_global2 = reshape(PCAs_global,size(PCAs_global,1)*size(PCAs_global,2),size(PCAs_global,3)); + factor = (2^15)./max(abs(PCAs_global2')); + PCAs_global2 = int64(PCAs_global2 .* factor'); + + fid=fopen([basename,'.fet.',num2str(kcoords3)],'w'); + Fet = double([PCAs_global2; wranges; wpowers; spikeTimes(ia{i})']); + nFeatures = size(Fet, 1); + formatstring = '%d'; + for ii=2:nFeatures + formatstring = [formatstring,'\t%d']; + end + formatstring = [formatstring,'\n']; + + fprintf(fid, '%d\n', nFeatures); + fprintf(fid,formatstring,Fet); + fclose(fid); +end +toc(t1) +disp('Complete!') + + + + function waveforms_all = Kilosort_ExtractWaveforms(rez) + % Extracts waveforms from a dat file using GPU enable filters. + % Based on the GPU enable filter from Kilosort. + % All settings and content are extracted from the rez input structure + % + % Inputs: + % rez - rez structure from Kilosort + % + % Outputs: + % waveforms_all - structure with extracted waveforms + + % Extracting content from the .rez file + ops = rez.ops; + NT = ops.NT; + d = dir(ops.fbinary); + NchanTOT = ops.NchanTOT; + chanMap = ops.chanMap; + chanMapConn = chanMap(rez.connected>1e-6); + kcoords = ops.kcoords; + ia = rez.ia; + spikeTimes = rez.st3(:,1); + + if ispc + dmem = memory; + memfree = dmem.MemAvailableAllArrays/8; + memallocated = min(ops.ForceMaxRAMforDat, dmem.MemAvailableAllArrays) - memfree; + memallocated = max(0, memallocated); + else + memallocated = ops.ForceMaxRAMforDat; + end + ops.ForceMaxRAMforDat = 10000000000; + memallocated = ops.ForceMaxRAMforDat; + nint16s = memallocated/2; + + NTbuff = NT + 4*ops.ntbuff; + Nbatch = ceil(d.bytes/2/NchanTOT /(NT-ops.ntbuff)); + Nbatch_buff = floor(4/5 * nint16s/ops.Nchan /(NT-ops.ntbuff)); % factor of 4/5 for storing PCs of spikes + Nbatch_buff = min(Nbatch_buff, Nbatch); + + DATA =zeros(NT, NchanTOT,Nbatch_buff,'int16'); + + if isfield(ops,'fslow')&&ops.fslow1e-6); % number of active channels templatemultiplier = 8; -ops.Nfilt = ops.Nchan*templatemultiplier - mod(ops.Nchan*templatemultiplier,32); % number of filters to use (2-4 times more than Nchan, should be a multiple of 32) -% if ops.Nfilt > 2024; -% ops.Nfilt = 2024; -% elseif ops.Nfilt == 0 -% ops.Nfilt = 32; -% end +ops.Nfilt = ops.Nchan*templatemultiplier - mod(ops.Nchan*templatemultiplier,32); % number of filters to use (2-4 times more than Nchan, should be a multiple of 32) + ops.nt0 = round(1.6*ops.fs/1000); % window width in samples. 1.6ms at 20kH corresponds to 32 samples ops.nNeighPC = min([16 ops.Nchan]); % visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12) @@ -64,7 +60,7 @@ % options for initializing spikes from data ops.initialize = 'fromData'; %'fromData' or 'no' -ops.spkTh = -5; % spike threshold in standard deviations (4) +ops.spkTh = -4; % spike threshold in standard deviations (4) ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) ops.maskMaxChannels = 8; % how many channels to mask up/down ([5]) @@ -79,4 +75,7 @@ ops.fracse = 0.1; % binning step along discriminant axis for posthoc merges (in units of sd) ops.epu = Inf; ops.ForceMaxRAMforDat = 15000000000; % maximum RAM the algorithm will try to use; on Windows it will autodetect. + +% Saving xml content to ops strucuture +ops.xml = xml; end From 4ccc3734514ea25a0ed0049b5488f58b9a5da537 Mon Sep 17 00:00:00 2001 From: Peter C Petersen Date: Sat, 31 Mar 2018 22:14:31 -0400 Subject: [PATCH 06/35] Update README.md --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bc4d8fc..3fc0921 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,17 @@ # KilosortWrapper Allows you to load Kilosort from a .xml and a .dat file compatible with Neurosuite +## Installation +Download and add the KilosortWrapper directory to your matlab path. + ## Settings -Settings are defined in the StandardSettings file +Most settings are defined in the KilosortConfiguration.m file. Some general settings are defined in the KilosortWrapper file, including: + +* Path to SSD +* Process in subdirectory +Supply a config version input, to use another config file (configuration files should be stored in the ConfigurationFiles folder). + ## Features Skip channels: To skip dead channels, select the skip function in Neuroscope or NDManager Define probe layouts: The wrapper now supports probes with staggered, poly 3 and poly 5 probe layouts... From 10f830cee4c44f487dc87de60ec1c95e16aeeb76 Mon Sep 17 00:00:00 2001 From: Peter C Petersen Date: Sat, 31 Mar 2018 22:19:27 -0400 Subject: [PATCH 07/35] Update README.md --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3fc0921..01b62a0 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,10 @@ Define probe layouts: The wrapper now supports probes with staggered, poly 3 and Allows you to save the output from Kilosort to a sub directory. ## Outputs -The Kilosort wrapper allows you to save the output in Neurosuite compatible files or for Phy. +The Kilosort wrapper allows you to save the output in Neurosuite and Phy compatible files. -### Phy -Saved a channel groups file with information about which shanks the channels are asigned to. +### Phy (rezToPhy_KSW) + + +### Neurosuite (Kilosort2Neurosuite) +Creates all classical files used in the Neurosuite format. For this the dat file is filtered, waveforms are extracted and global PCA features are calculated. From dee83952078d08b862fe4436ace75b86e03025d9 Mon Sep 17 00:00:00 2001 From: Brendon Watson Date: Sun, 8 Apr 2018 22:32:35 -0400 Subject: [PATCH 08/35] Bug fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - KilosortWrapper.m: Changed a call to “rootpath” to ‘basepath’, since footpath has not been defined - Kilosort2Neurosuite: Fixed issue with not importing xml info from ops --- KiloSortWrapper.m | 2 +- Kilosort2Neurosuite.m | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 2405eeb..8876a89 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -73,7 +73,7 @@ disp('Creating a temporary dat file on the SSD drive') ops.fproc = ['G:\Kilosort\temp_wh.dat']; else - ops.fproc = fullfile(rootpath,'temp_wh.dat'); + ops.fproc = fullfile(basepath,'temp_wh.dat'); end %% diff --git a/Kilosort2Neurosuite.m b/Kilosort2Neurosuite.m index bfa86d7..07cb08f 100644 --- a/Kilosort2Neurosuite.m +++ b/Kilosort2Neurosuite.m @@ -178,6 +178,7 @@ function Kilosort2Neurosuite(rez) if isfield(ops,'xml') disp('Loading xml from rez for probe layout') + xml = ops.xml; elseif exist(fullfile(ops.root,[ops.basename,'.xml']))==2 disp('Loading xml for probe layout from root folder') xml = LoadXml(fullfile(ops.root,[ops.basename,'.xml'])); @@ -189,8 +190,8 @@ function Kilosort2Neurosuite(rez) waveforms_all = []; kcoords2 = unique(ops.kcoords); - channel_order = []; - indicesTokeep = []; + channel_order = {}; + indicesTokeep = {}; for i = 1:length(kcoords2) kcoords3 = kcoords2(i); waveforms_all{i} = zeros(sum(kcoords==kcoords3),ops.nt0,size(rez.ia{i},1)); From ced23757764dd5c2f72f1b87ec6bea3adaa085e1 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Thu, 3 May 2018 18:17:27 -0400 Subject: [PATCH 09/35] changes to the input call, configuration file and channelmap file Changed input structure of KilosortWrapper to varargin to allow for more smooth handling of multiple inputs. Added specification of export of files in the configuration. Checks that .xml and .dat files exist in directory specified. Reads in the correct xml file. --- KiloSortWrapper.m | 73 +++++++++++++++++++------------------- KilosortConfiguration.m | 5 +++ createChannelMapFile_KSW.m | 9 +++-- 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 2405eeb..3cce11a 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -1,4 +1,4 @@ -function savepath = KiloSortWrapper(basepath,basename,config_version) +function savepath = KiloSortWrapper(varargin) % Creates channel map from Neuroscope xml files, runs KiloSort and % writes output data in the Neuroscope/Klusters format. % StandardConfig.m should be in the path or copied to the local folder @@ -14,9 +14,10 @@ % INPUTS % basepath path to the folder containing the data % basename file basenames (of the dat and xml files) +% config_version % % Dependencies: KiloSort (https://github.com/cortex-lab/KiloSort) - +% % Copyright (C) 2016 Brendon Watson and the Buzsakilab % % This program is free software; you can redistribute it and/or modify @@ -25,40 +26,37 @@ % (at your option) any later version. disp('Running Kilosort spike sorting with the Buzsaki lab wrapper') -%% Addpath if needed -% addpath(genpath('gitrepositories/KiloSort')) % path to kilosort folder -% addpath(genpath('gitrepositories/npy-matlab')) % path to npy-matlab scripts - %% If function is called without argument -switch nargin - case 0 - [~,basename] = fileparts(cd); - basepath = cd; - case 1 - [~,basename] = fileparts(basepath); - case 2 - if isempty(basepath) - [~,basename] = fileparts(basepath); - basepath = cd; - end - case 3 - if isempty(basepath) - [~,basename] = fileparts(cd); - basepath = cd; - end -end +p = inputParser; +basepath = cd; +[~,basename] = fileparts(basepath); + +addParameter(p,'basepath',basepath,@ischar) +addParameter(p,'basename',basename,@ischar) +parse(p,varargin{:}) +basepath = p.Results.basepath; +basename = p.Results.basename; cd(basepath) +% Checking if dat and xml files exist +if ~exist(fullfile(basepath,[basename,'.xml'])) + warning('KilosortWrapper %s.xml file not in path %s',basename,basepath); + return +elseif ~exist(fullfile(basepath,[basename,'.dat'])) + warning('KilosortWrapper %s.dat file not in path %s',basename,basepath) + return +end + %% Creates a channel map file disp('Creating ChannelMapFile') -createChannelMapFile_KSW(basepath,'staggered'); +createChannelMapFile_KSW(basepath,basename,'staggered'); %% Loading configurations XMLFilePath = fullfile(basepath, [basename '.xml']); % if exist(fullfile(basepath,'StandardConfig.m'),'file') %this should actually be unnecessary % addpath(basepath); % end -if nargin < 3 +if ~exist('config_version') disp('Running Kilosort with standard settings') ops = KilosortConfiguration(XMLFilePath); else @@ -79,7 +77,7 @@ %% if ops.GPU disp('Initializing GPU') - gpuDevice(1); % initialize GPU (will erase any existing GPU arrays) + gpudev = gpuDevice(1); % initialize GPU (will erase any existing GPU arrays) end if strcmp(ops.datatype , 'openEphys') ops = convertOpenEphysToRawBInary(ops); % convert data, only for OpenEphys @@ -114,15 +112,18 @@ % rez = merge_posthoc2(rez); save(fullfile(savepath, 'rez.mat'), 'rez', '-v7.3'); -%% save python results file for Phy -disp('Converting to Phy format') -rezToPhy_KSW(rez); - -%% save python results file for Klusters -disp('Converting to Klusters format') -Kilosort2Neurosuite(rez) - -%% Remove temporary file +%% export python results file for Phy +if ops.export.phy + disp('Converting to Phy format') + rezToPhy_KSW(rez); +end +%% export Neurosuite files +if ops.export.neurosuite + disp('Converting to Klusters format') + Kilosort2Neurosuite(rez) +end +%% Remove temporary file and resetting GPU delete(ops.fproc); +reset(gpudev) +gpuDevice([]) disp('Kilosort Processing complete') - diff --git a/KilosortConfiguration.m b/KilosortConfiguration.m index 063011f..fd5a669 100644 --- a/KilosortConfiguration.m +++ b/KilosortConfiguration.m @@ -78,4 +78,9 @@ % Saving xml content to ops strucuture ops.xml = xml; + +% Specify if the output should be exported to Phy and/or Neurosuite +ops.export.phy = 1; +ops.export.neurosuite = 1; + end diff --git a/createChannelMapFile_KSW.m b/createChannelMapFile_KSW.m index 77ed426..b7a1acd 100644 --- a/createChannelMapFile_KSW.m +++ b/createChannelMapFile_KSW.m @@ -1,4 +1,4 @@ -function createChannelMapFile_Local(basepath,electrode_type) +function createChannelMapFile_Local(basepath,basename,electrode_type) % Original function by Brendon and Sam % electrode_type: Two options at this point: 'staggered' or 'neurogrid' % create a channel map file @@ -6,8 +6,11 @@ function createChannelMapFile_Local(basepath,electrode_type) if ~exist('basepath','var') basepath = cd; end -d = dir('*.xml'); -[par,rxml] = LoadXml(fullfile(basepath,d(1).name)); +if ~exist('basename','var') + [~,basename] = fileparts(basepath); +end + +[par,rxml] = LoadXml(fullfile(basepath,[basename,'.xml'])); xml_electrode_type = rxml.child(1).child(4).value; switch(xml_electrode_type) case 'staggered' From 14be828097d4abf8e12a41093f054a2370bee744 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Sun, 6 May 2018 22:08:28 -0400 Subject: [PATCH 10/35] GPU selection to varargin GPU selection to varargin --- KiloSortWrapper.m | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 3cce11a..fd9d5e9 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -1,22 +1,23 @@ function savepath = KiloSortWrapper(varargin) % Creates channel map from Neuroscope xml files, runs KiloSort and -% writes output data in the Neuroscope/Klusters format. -% StandardConfig.m should be in the path or copied to the local folder +% writes output data to Neurosuite format or Phy. % -% USAGE +% USAGE % -% KiloSortWrapper() -% Should be run from the data folder, and file basenames are the -% same as the name as current directory +% KiloSortWrapper() +% Should be run from the data folder, and file basenames are the +% same as the name as current directory % -% KiloSortWrapper(basepath,basenmae) +% KiloSortWrapper(varargin) % -% INPUTS -% basepath path to the folder containing the data -% basename file basenames (of the dat and xml files) -% config_version +% INPUTS +% basepath path to the folder containing the data +% basename file basenames (of the dat and xml files) +% config Specify a configuration file to use from the +% ConfigurationFiles folder. e.g. 'Omid' +% GPU_id Specify the GPU id % -% Dependencies: KiloSort (https://github.com/cortex-lab/KiloSort) +% Dependencies: KiloSort (https://github.com/cortex-lab/KiloSort) % % Copyright (C) 2016 Brendon Watson and the Buzsakilab % @@ -33,12 +34,17 @@ addParameter(p,'basepath',basepath,@ischar) addParameter(p,'basename',basename,@ischar) +addParameter(p,'GPU_id',1,@isnumeric) + parse(p,varargin{:}) + basepath = p.Results.basepath; basename = p.Results.basename; +GPU_id = p.Results.GPU_id; + cd(basepath) -% Checking if dat and xml files exist +%% Checking if dat and xml files exist if ~exist(fullfile(basepath,[basename,'.xml'])) warning('KilosortWrapper %s.xml file not in path %s',basename,basepath); return @@ -56,7 +62,7 @@ % if exist(fullfile(basepath,'StandardConfig.m'),'file') %this should actually be unnecessary % addpath(basepath); % end -if ~exist('config_version') +if ~exist('config') disp('Running Kilosort with standard settings') ops = KilosortConfiguration(XMLFilePath); else @@ -77,7 +83,7 @@ %% if ops.GPU disp('Initializing GPU') - gpudev = gpuDevice(1); % initialize GPU (will erase any existing GPU arrays) + gpudev = gpuDevice(GPU_id); % initialize GPU (will erase any existing GPU arrays) end if strcmp(ops.datatype , 'openEphys') ops = convertOpenEphysToRawBInary(ops); % convert data, only for OpenEphys From fd8576e23b32b3d0eab4f465225bce767b89e533 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Sun, 13 May 2018 12:53:32 -0400 Subject: [PATCH 11/35] Improved the Export function Improved the Export function. Replaed the forloop for each spike with vectorized indexing. --- Kilosort2Neurosuite.m | 65 +++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/Kilosort2Neurosuite.m b/Kilosort2Neurosuite.m index 07cb08f..c03d4f7 100644 --- a/Kilosort2Neurosuite.m +++ b/Kilosort2Neurosuite.m @@ -1,13 +1,17 @@ function Kilosort2Neurosuite(rez) % Converts KiloSort output (.rez structure) to Neurosuite files: fet,res,clu,spk files. % Based on the GPU enable filter from Kilosort and fractions from Brendon -% Watson's code for saving Neurosuite files. +% Watson's code for saving Neurosuite files. + +% The script has a high memory usage as all waveforms are loaded into +% memory at the same time. If you experience a memory error, increase +% your swap/cashe file, and increase the amount of memory MATLAB can use. % % 1) Waveforms are extracted from the dat file via GPU enabled filters. % 2) Features are calculated in parfor loops. % % Inputs: -% rez - rez structure from Kilosort +% rez - rez structure from Kilosort % % By Peter Petersen 2018 % petersen.peter@gmail.com @@ -35,17 +39,19 @@ function Kilosort2Neurosuite(rez) ia = []; for i = 1:length(kcoords2) kcoords3 = kcoords2(i); - disp(['-Loading data for spike group ', num2str(kcoords3)]) + if mod(i,4)==1; fprintf('\n'); end + fprintf(['Loading data for spike group ', num2str(kcoords3),'. ']) template_index = find(template_kcoords == kcoords3); ia{i} = find(ismember(spikeTemplates,template_index)); end - rez.ia = ia; -toc(t1) -disp('Saving .clu files to disk (cluster indexes)') +fprintf('\n'); toc(t1) + +fprintf('\nSaving .clu files to disk (cluster indexes)') for i = 1:length(kcoords2) kcoords3 = kcoords2(i); - disp(['-Saving .clu file for group ', num2str(kcoords3)]) + if mod(i,4)==1; fprintf('\n'); end + fprintf(['Saving .clu file for group ', num2str(kcoords3),'. ']) tclu = spikeTemplates(ia{i}); tclu = cat(1,length(unique(tclu)),double(tclu)); cluname = fullfile([basename '.clu.' num2str(kcoords3)]); @@ -54,41 +60,44 @@ function Kilosort2Neurosuite(rez) fclose(fid); clear fid end -toc(t1) +fprintf('\n'); toc(t1) -disp('Saving .res files to disk (spike times)') +fprintf('\nSaving .res files to disk (spike times)') for i = 1:length(kcoords2) kcoords3 = kcoords2(i); tspktimes = spikeTimes(ia{i}); - disp(['-Saving .res file for group ', num2str(kcoords3)]) + if mod(i,4)==1; fprintf('\n'); end + fprintf(['Saving .res file for group ', num2str(kcoords3),'. ']) resname = fullfile([basename '.res.' num2str(kcoords3)]); fid=fopen(resname,'w'); fprintf(fid,'%.0f\n',tspktimes); fclose(fid); clear fid end -toc(t1) +fprintf('\n'); toc(t1) -disp('Extracting waveforms') +fprintf('\nExtracting waveforms\n') waveforms_all = Kilosort_ExtractWaveforms(rez); -toc(t1) +fprintf('\n'); toc(t1) -disp('Saving .spk files to disk (waveforms)') +fprintf('\nSaving .spk files to disk (waveforms)') for i = 1:length(kcoords2) - disp(['-Saving .spk for group ', num2str(kcoords2(i))]) + if mod(i,4)==1; fprintf('\n'); end + fprintf(['Saving .spk for group ', num2str(kcoords2(i)),'. ']) fid=fopen([basename,'.spk.',num2str(kcoords2(i))],'w'); fwrite(fid,waveforms_all{i}(:),'int16'); fclose(fid); end -toc(t1) +fprintf('\n'); toc(t1) -disp('Computing PCAs') +fprintf('\nComputing PCAs') % Starting parpool if stated in the Kilosort settings if (rez.ops.parfor & isempty(gcp('nocreate'))); parpool; end for i = 1:length(kcoords2) kcoords3 = kcoords2(i); - disp(['-Computing PCAs for group ', num2str(kcoords3)]) + if mod(i,2)==1; fprintf('\n'); end + fprintf(['Computing PCAs for group ', num2str(kcoords3),'. ']) PCAs_global = zeros(3,sum(kcoords==kcoords3),length(ia{i})); waveforms = waveforms_all{i}; @@ -106,7 +115,7 @@ function Kilosort2Neurosuite(rez) PCAs_global(:,k,:) = pca(zscore(permute(waveforms(k,:,:),[2,3,1]),[],2),'NumComponents',3)'; end end - disp(['-Saving .fet files for group ', num2str(kcoords3)]) + fprintf(['Saving .fet files for group ', num2str(kcoords3),'. ']) PCAs_global2 = reshape(PCAs_global,size(PCAs_global,1)*size(PCAs_global,2),size(PCAs_global,3)); factor = (2^15)./max(abs(PCAs_global2')); PCAs_global2 = int64(PCAs_global2 .* factor'); @@ -124,10 +133,8 @@ function Kilosort2Neurosuite(rez) fprintf(fid,formatstring,Fet); fclose(fid); end -toc(t1) -disp('Complete!') - - +fprintf('\n'); toc(t1) +fprintf('\nComplete!') function waveforms_all = Kilosort_ExtractWaveforms(rez) % Extracts waveforms from a dat file using GPU enable filters. @@ -250,17 +257,15 @@ function Kilosort2Neurosuite(rez) for i = 1:length(kcoords2) kcoords3 = kcoords2(i); ch_subset = find(kcoords==kcoords3); - temp = find(ismember(spikeTimes(ia{i}), [ops.nt0/2:size(DATA,1)-ops.nt0/2]+ dat_offset)); + temp = find(ismember(spikeTimes(ia{i}), [ops.nt0/2+1:size(DATA,1)-ops.nt0/2] + dat_offset)); temp2 = spikeTimes(ia{i}(temp))-dat_offset; + startIndicies = temp2-ops.nt0/2+1; stopIndicies = temp2+ops.nt0/2; - channel_indexes = ch_subset(indicesTokeep{i}); -% waveforms_all{i}(:,:,temp) = arrayfun(@(i,j,k) DATA(i:j,k), startIndicies, stopIndicies, channel_indexes); - for ii = 1:length(temp) - waveforms_all{i}(:,:,temp(ii)) = DATA(temp2(ii)-ops.nt0/2+1:temp2(ii)+ops.nt0/2,ch_subset(indicesTokeep{i}))'; - end + X = cumsum(accumarray(cumsum([1;stopIndicies(:)-startIndicies(:)+1]),[startIndicies(:);0]-[0;stopIndicies(:)]-1)+1); + X = X(1:end-1); + waveforms_all{i}(:,:,temp) = reshape(DATA(X,ch_subset(indicesTokeep{i}))',length(ch_subset),ops.nt0,[]); end end - fprintf('\n Extraction of waveforms complete \n') end end From b88433390ad8f7981e97558e44125447e7be9b5f Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Wed, 16 May 2018 19:34:12 -0400 Subject: [PATCH 12/35] fixed bug related to shank indexing when skipping channels. fixed bug related to shank indexing when skipping channels. Skipped channels are now defined by SpikeGroup and not the skip command (compatible with Klusters). --- Kilosort2Neurosuite.m | 7 +++++-- createChannelMapFile_KSW.m | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/Kilosort2Neurosuite.m b/Kilosort2Neurosuite.m index c03d4f7..790ca13 100644 --- a/Kilosort2Neurosuite.m +++ b/Kilosort2Neurosuite.m @@ -199,13 +199,16 @@ function Kilosort2Neurosuite(rez) channel_order = {}; indicesTokeep = {}; + connected_index = zeros(size(rez.connected)); + connected_index(rez.connected)=1:length(chanMapConn); + for i = 1:length(kcoords2) kcoords3 = kcoords2(i); waveforms_all{i} = zeros(sum(kcoords==kcoords3),ops.nt0,size(rez.ia{i},1)); if exist('xml') - channel_order{i} = xml.AnatGrps(i).Channels+1; + channel_order = xml.AnatGrps(i).Channels+1; ch_subset = find(kcoords==kcoords3); - [~,indicesTokeep{i},~] = intersect(channel_order{i},ch_subset); + [~,indicesTokeep{i},~] = intersect(connected_index(channel_order),ch_subset); [~,indicesTokeep{i}] = sort(indicesTokeep{i}); end end diff --git a/createChannelMapFile_KSW.m b/createChannelMapFile_KSW.m index b7a1acd..2db5d4f 100644 --- a/createChannelMapFile_KSW.m +++ b/createChannelMapFile_KSW.m @@ -127,9 +127,13 @@ function createChannelMapFile_Local(basepath,basename,electrode_type) connected = true(Nchannels, 1); % Removing dead channels by the skip parameter in the xml +% order = [par.AnatGrps.Channels]; +% skip = find([par.AnatGrps.Skip]); +% connected(order(skip)+1) = false; + order = [par.AnatGrps.Channels]; -skip = find([par.AnatGrps.Skip]); -connected(order(skip)+1) = false; +skip2 = find(~ismember([par.AnatGrps.Channels], [par.SpkGrps.Channels])); % finds the indices of the channels that are not part of SpkGrps +connected(order(skip2)+1) = false; chanMap = 1:Nchannels; chanMap0ind = chanMap - 1; From 7fe65b3fb7ca69bf0ead82311ab121119a9883d7 Mon Sep 17 00:00:00 2001 From: Peter Petersen Date: Thu, 17 May 2018 14:25:11 -0400 Subject: [PATCH 13/35] Update README.md --- README.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 01b62a0..aa64c14 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,22 @@ # KilosortWrapper -Allows you to load Kilosort from a .xml and a .dat file compatible with Neurosuite +Allows you to load Kilosort from a .xml and a .dat file compatible with Neurosuite. ## Installation -Download and add the KilosortWrapper directory to your matlab path. +Download and add the KilosortWrapper directory to your Matlab path. ## Settings Most settings are defined in the KilosortConfiguration.m file. Some general settings are defined in the KilosortWrapper file, including: * Path to SSD -* Process in subdirectory +* CreateSubdirectory: Allows you to save the output files from Kilosort to a sub directory (labeled by data and time). -Supply a config version input, to use another config file (configuration files should be stored in the ConfigurationFiles folder). +You can supply a config version input to use another config file (configuration files should be stored in the ConfigurationFiles folder). ## Features -Skip channels: To skip dead channels, select the skip function in Neuroscope or NDManager -Define probe layouts: The wrapper now supports probes with staggered, poly 3 and poly 5 probe layouts... -Allows you to save the output from Kilosort to a sub directory. +Skip channels: To skip dead channels, synchronize the anatomical groups and the spike groups in Neuroscope and remove the dead channels in the spike groups. The synchronization is necessary for maintaining the correct waveform layout in Phy. +Define probe layouts: The wrapper now supports probes with staggered, poly3 and poly5 probe layouts. Open your xml file and define your probe layout in the Notes field (General information). Kilosort assumes a staggered probe layout without any input. + +CreateSubdirectory: Allows you to save the output files from Kilosort to a sub directory (labeled by data and time). ## Outputs The Kilosort wrapper allows you to save the output in Neurosuite and Phy compatible files. From 473b6ebd253f7cd6739d0b8779c84f0f9ef928fa Mon Sep 17 00:00:00 2001 From: Brendon Watson Date: Fri, 8 Jun 2018 01:33:32 -0400 Subject: [PATCH 14/35] Fix to spk writing code in Neurosuite output Error in case of spike groups without a template in the subfunction Kilosort2Neurosute/Kilosort_ExtractWaveforms Fixed by only calling to extract values for spike groups that have at least one template. This follows code for kcoords2 towards top of Kilosort2Neurosute.m itself --- Kilosort2Neurosuite.m | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Kilosort2Neurosuite.m b/Kilosort2Neurosuite.m index 07cb08f..8c29f2b 100644 --- a/Kilosort2Neurosuite.m +++ b/Kilosort2Neurosuite.m @@ -188,8 +188,10 @@ function Kilosort2Neurosuite(rez) fid = fopen(ops.fbinary, 'r'); waveforms_all = []; - kcoords2 = unique(ops.kcoords); - +% kcoords2 = unique(ops.kcoords); + template_kcoords = kcoords(amplitude_max_channel); + kcoords2 = unique(template_kcoords); + channel_order = {}; indicesTokeep = {}; for i = 1:length(kcoords2) From edf1d9861ceacf3c875db4849f415c36240f55e1 Mon Sep 17 00:00:00 2001 From: Brendon Watson Date: Tue, 19 Jun 2018 17:43:45 -0400 Subject: [PATCH 15/35] Fixing a few errors that rarely crop up Put in a couple of "if" statement to handle cases --- KiloSortWrapper.m | 3 ++- createChannelMapFile_KSW.m | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 5023b66..de70d0d 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -62,7 +62,8 @@ % if exist(fullfile(basepath,'StandardConfig.m'),'file') %this should actually be unnecessary % addpath(basepath); % end -if ~exist('config') +ec = exist('config'); +if ec ~= 1 disp('Running Kilosort with standard settings') ops = KilosortConfiguration(XMLFilePath); else diff --git a/createChannelMapFile_KSW.m b/createChannelMapFile_KSW.m index 2db5d4f..dc7c1e7 100644 --- a/createChannelMapFile_KSW.m +++ b/createChannelMapFile_KSW.m @@ -132,8 +132,10 @@ function createChannelMapFile_Local(basepath,basename,electrode_type) % connected(order(skip)+1) = false; order = [par.AnatGrps.Channels]; -skip2 = find(~ismember([par.AnatGrps.Channels], [par.SpkGrps.Channels])); % finds the indices of the channels that are not part of SpkGrps -connected(order(skip2)+1) = false; +if isfield(par,'SpkGrps') + skip2 = find(~ismember([par.AnatGrps.Channels], [par.SpkGrps.Channels])); % finds the indices of the channels that are not part of SpkGrps + connected(order(skip2)+1) = false; +end chanMap = 1:Nchannels; chanMap0ind = chanMap - 1; From e4b367786c4d85d65e2704f9b0866035b6132eb3 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Thu, 4 Oct 2018 17:13:10 -0400 Subject: [PATCH 16/35] Updates to handle inputs and Export function --- .../KilosortConfiguration_Cesar.m | 92 ++++++ KiloSortWrapper.m | 40 ++- Kilosort2Neurosuite.m | 20 +- KilosortConfiguration.m | 8 +- Phy2Neurosuite.m | 278 ++++++++++++++++++ private/KiloSortLinuxDir.m | 91 ++++++ 6 files changed, 510 insertions(+), 19 deletions(-) create mode 100644 ConfigurationFiles/KilosortConfiguration_Cesar.m create mode 100644 Phy2Neurosuite.m create mode 100644 private/KiloSortLinuxDir.m diff --git a/ConfigurationFiles/KilosortConfiguration_Cesar.m b/ConfigurationFiles/KilosortConfiguration_Cesar.m new file mode 100644 index 0000000..95e015d --- /dev/null +++ b/ConfigurationFiles/KilosortConfiguration_Cesar.m @@ -0,0 +1,92 @@ +function ops = StandardConfig_KSW(XMLfile) + +% Loads xml parameters (Neuroscope) +xml = LoadXml(XMLfile); +% Define rootpath +rootpath = fileparts(XMLfile); + +ops.GPU = 1; % whether to run this code on an Nvidia GPU (much faster, mexGPUall first) +ops.parfor = 1; % whether to use parfor to accelerate some parts of the algorithm +ops.verbose = 1; % whether to print command line progress +ops.showfigures = 0; % whether to plot figures during optimization +ops.datatype = 'dat'; % binary ('dat', 'bin') or 'openEphys' +ops.fbinary = [XMLfile(1:end-3) 'dat']; % will be created for 'openEphys' + +%Should get rid of this... +if isdir('G:\Kilosort') + disp('Creating a temporary dat file on the SSD drive') + ops.fproc = ['G:\Kilosort\temp_wh.dat']; +else + ops.fproc = fullfile(rootpath,'temp_wh.dat'); +end +ops.root = rootpath; % 'openEphys' only: where raw files are +ops.fs = xml.SampleRate; % sampling rate + +load(fullfile(rootpath,'chanMap.mat')) +ops.NchanTOT = length(connected); % total number of channels + +ops.Nchan = sum(connected>1e-6); % number of active channels + +templatemultiplier = 8; +ops.Nfilt = ops.Nchan*templatemultiplier - mod(ops.Nchan*templatemultiplier,32); % number of filters to use (2-4 times more than Nchan, should be a multiple of 32) +% if ops.Nfilt > 2024; +% ops.Nfilt = 2024; +% elseif ops.Nfilt == 0 +% ops.Nfilt = 32; +% end +ops.nt0 = round(1.6*ops.fs/1000); % window width in samples. 1.6ms at 20kH corresponds to 32 samples + +ops.nNeighPC = min([16 ops.Nchan]); % visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12) +ops.nNeigh = min([16 ops.Nchan]); % visualization only (Phy): number of neighboring templates to retain projections of (16) + +% options for channel whitening +ops.whitening = 'full'; % type of whitening (default 'full', for 'noSpikes' set options for spike detection below) +ops.nSkipCov = 1; % compute whitening matrix from every N-th batch (1) +ops.whiteningRange = min([64 ops.Nchan]); % how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32) + +% define the channel map as a filename (string) or simply an array +ops.chanMap = fullfile(rootpath,'chanMap.mat'); % make this file using createChannelMapFile.m +ops.criterionNoiseChannels = 0.00001; % fraction of "noise" templates allowed to span all channel groups (see createChannelMapFile for more info). + +% other options for controlling the model and optimization +ops.Nrank = 3; % matrix rank of spike template model (3) +ops.nfullpasses = 6; % number of complete passes through data during optimization (6) +ops.maxFR = 40000; % maximum number of spikes to extract per batch (20000) +ops.fshigh = 500; % frequency for high pass filtering +ops.fslow = 8000; % frequency for low pass filtering (optional) +ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection +ops.scaleproc = 200; % int16 scaling of whitened data +ops.NT = 4*32*1028+ ops.ntbuff;% this is the batch size (try decreasing if out of memory) for GPU should be multiple of 32 + ntbuff + +% the following options can improve/deteriorate results. +% when multiple values are provided for an option, the first two are beginning and ending anneal values, +% the third is the value used in the final pass. +ops.Th = [5 6 6]; % threshold for detecting spikes on template-filtered data ([6 12 12]) +ops.lam = [10 20 20]; % large means amplitudes are forced around the mean ([10 30 30]) +ops.nannealpasses = 4; % should be less than nfullpasses (4) +ops.momentum = 1./[20 800]; % start with high momentum and anneal (1./[20 1000]) +ops.shuffle_clusters = 1; % allow merges and splits during optimization (1) +ops.mergeT = .1; % upper threshold for merging (.1) +ops.splitT = .1; % lower threshold for splitting (.1) + +% options for initializing spikes from data +ops.initialize = 'no'; %'fromData' or 'no' +ops.spkTh = 4; % spike threshold in standard deviations (4) +ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) +ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) +ops.maskMaxChannels = 8; % how many channels to mask up/down ([5]) +ops.crit = .65; % upper criterion for discarding spike repeates (0.65) +ops.nFiltMax = 10000; % maximum "unique" spikes to consider (10000) + +% load predefined principal components (visualization only (Phy): used for features) +dd = load('PCspikes2.mat'); % you might want to recompute this from your own data +ops.wPCA = dd.Wi(:,1:7); % PCs + +% options for posthoc merges (under construction) +ops.fracse = 0.1; % binning step along discriminant axis for posthoc merges (in units of sd) +ops.epu = Inf; +ops.ForceMaxRAMforDat = 15000000000; % maximum RAM the algorithm will try to use; on Windows it will autodetect. + +% Saving xml content to ops strucuture +ops.xml = xml; +end diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 5023b66..87e4c05 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -64,20 +64,35 @@ % end if ~exist('config') disp('Running Kilosort with standard settings') - ops = KilosortConfiguration(XMLFilePath); + ops = KiloSortConfiguration(XMLFilePath); else disp('Running Kilosort with user specific settings') - config_string = str2func(['KilosortConfiguration_' config_version]); + config_string = str2func(['KiloSortConfiguration_' config_version]); ops = config_string(XMLFilePath); clear config_string; end %% % Defining SSD location if any -if isdir('G:\Kilosort') - disp('Creating a temporary dat file on the SSD drive') - ops.fproc = ['G:\Kilosort\temp_wh.dat']; +SSD_path = 'G:\Kilosort'; + +if isunix + fname = KiloSortLinuxDir(basename,basepath,gpuDeviceNum); + ops.fproc = fname; else - ops.fproc = fullfile(basepath,'temp_wh.dat'); + if isdir(SSD_path) + FileObj = java.io.File(SSD_path); + free_bytes = FileObj.getFreeSpace; + dat_file = dir(fullfile(basepath,[basename,'.dat'])); + if dat_file.bytes*1.11e-6); @@ -199,17 +203,17 @@ function Kilosort2Neurosuite(rez) channel_order = {}; indicesTokeep = {}; - connected_index = zeros(size(rez.connected)); - connected_index(rez.connected)=1:length(chanMapConn); +% connected_index = zeros(size(rez.connected)); +% connected_index(rez.connected)=1:length(chanMapConn); for i = 1:length(kcoords2) kcoords3 = kcoords2(i); waveforms_all{i} = zeros(sum(kcoords==kcoords3),ops.nt0,size(rez.ia{i},1)); if exist('xml') - channel_order = xml.AnatGrps(i).Channels+1; - ch_subset = find(kcoords==kcoords3); - [~,indicesTokeep{i},~] = intersect(connected_index(channel_order),ch_subset); - [~,indicesTokeep{i}] = sort(indicesTokeep{i}); + [channel_order,channel_index] = sort(xml.AnatGrps(kcoords2(i)).Channels+1); + [~,indicesTokeep{i},~] = intersect(chanMapConn,channel_order); + + %indicesTokeep{i} = connected_index(indicesTokeep{i}); end end @@ -259,7 +263,7 @@ function Kilosort2Neurosuite(rez) % Saves the waveforms occuring within each batch for i = 1:length(kcoords2) kcoords3 = kcoords2(i); - ch_subset = find(kcoords==kcoords3); +% ch_subset = 1:length(chanMapConn); temp = find(ismember(spikeTimes(ia{i}), [ops.nt0/2+1:size(DATA,1)-ops.nt0/2] + dat_offset)); temp2 = spikeTimes(ia{i}(temp))-dat_offset; @@ -267,7 +271,7 @@ function Kilosort2Neurosuite(rez) stopIndicies = temp2+ops.nt0/2; X = cumsum(accumarray(cumsum([1;stopIndicies(:)-startIndicies(:)+1]),[startIndicies(:);0]-[0;stopIndicies(:)]-1)+1); X = X(1:end-1); - waveforms_all{i}(:,:,temp) = reshape(DATA(X,ch_subset(indicesTokeep{i}))',length(ch_subset),ops.nt0,[]); + waveforms_all{i}(:,:,temp) = reshape(DATA(X,indicesTokeep{i})',size(indicesTokeep{i},1),ops.nt0,[]); end end end diff --git a/KilosortConfiguration.m b/KilosortConfiguration.m index fd5a669..8985fde 100644 --- a/KilosortConfiguration.m +++ b/KilosortConfiguration.m @@ -45,12 +45,12 @@ ops.fslow = 8000; % frequency for low pass filtering (optional) ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection ops.scaleproc = 200; % int16 scaling of whitened data -ops.NT = 4*32*1028+ ops.ntbuff;% this is the batch size (try decreasing if out of memory) for GPU should be multiple of 32 + ntbuff +ops.NT = 32*1028+ ops.ntbuff;% this is the batch size (try decreasing if out of memory) for GPU should be multiple of 32 + ntbuff % the following options can improve/deteriorate results. % when multiple values are provided for an option, the first two are beginning and ending anneal values, % the third is the value used in the final pass. -ops.Th = [6 10 10]; % threshold for detecting spikes on template-filtered data ([6 12 12]) +ops.Th = [6 12 12]; % threshold for detecting spikes on template-filtered data ([6 12 12]) ops.lam = [12 40 40]; % large means amplitudes are forced around the mean ([10 30 30]) ops.nannealpasses = 4; % should be less than nfullpasses (4) ops.momentum = 1./[20 800]; % start with high momentum and anneal (1./[20 1000]) @@ -59,7 +59,7 @@ ops.splitT = .1; % lower threshold for splitting (.1) % options for initializing spikes from data -ops.initialize = 'fromData'; %'fromData' or 'no' +ops.initialize = 'no'; %'fromData' or 'no' ops.spkTh = -4; % spike threshold in standard deviations (4) ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) @@ -81,6 +81,6 @@ % Specify if the output should be exported to Phy and/or Neurosuite ops.export.phy = 1; -ops.export.neurosuite = 1; +ops.export.neurosuite = 0; end diff --git a/Phy2Neurosuite.m b/Phy2Neurosuite.m new file mode 100644 index 0000000..d597a46 --- /dev/null +++ b/Phy2Neurosuite.m @@ -0,0 +1,278 @@ +function Phy2Neurosuite(basepath,clustering_path) +% Converts Phy output (NPY files) to Neurosuite files: fet, res, clu, spk files. +% Based on the GPU enable filter from Kilosort and fractions from Brendon +% Watson's code for saving Neurosuite files. + +% The script has a high memory usage as all waveforms are loaded into +% memory at the same time. If you experience a memory error, increase +% your swap/cashe file, and increase the amount of memory MATLAB can use. +% +% 1) Waveforms are extracted from the dat file via GPU enabled filters. +% 2) Features are calculated in parfor loops. +% +% Inputs: +% path - rez structure from Kilosort +% +% By Peter Petersen 2018 +% petersen.peter@gmail.com + +t1 = tic; +cd(clustering_path) +load('rez.mat') +rez.ops.root = clustering_path; +basename = rez.ops.basename; +rez.ops.fbinary = fullfile(basepath, [basename,'.dat']); +rez.ops.fshigh = 500; + +spikeTimes = uint64(rez.st3(:,1)); % uint64 +spikeTemplates = double(readNPY(fullfile(clustering_path, 'spike_clusters.npy'))); +spike_clusters = unique(spikeTemplates); +cluster_ids = readNPY(fullfile(clustering_path, 'cluster_ids.npy')); +template_kcoords = readNPY(fullfile(clustering_path, 'shanks.npy')); + +kcoords = rez.ops.kcoords; + +Nchan = rez.ops.Nchan; +samples = rez.ops.nt0; + +kcoords2 = unique(template_kcoords); +ia = []; +for i = 1:length(kcoords2) + kcoords3 = kcoords2(i); + if mod(i,4)==1; fprintf('\n'); end + fprintf(['Loading data for spike group ', num2str(kcoords3),'. ']) + template_index = cluster_ids(find(template_kcoords == kcoords3)); + ia{i} = find(ismember(spikeTemplates,template_index)); +end +rez.ia = ia; +fprintf('\n'); toc(t1) + +fprintf('\nSaving .clu files to disk (cluster indexes)') +for i = 1:length(kcoords2) + kcoords3 = kcoords2(i); + if mod(i,4)==1; fprintf('\n'); end + fprintf(['Saving .clu file for group ', num2str(kcoords3),'. ']) + tclu = spikeTemplates(ia{i}); + tclu = cat(1,length(unique(tclu)),double(tclu)); + cluname = fullfile([basename '.clu.' num2str(kcoords3)]); + fid=fopen(cluname,'w'); + fprintf(fid,'%.0f\n',tclu); + fclose(fid); + clear fid +end +fprintf('\n'); toc(t1) + +fprintf('\nSaving .res files to disk (spike times)') +for i = 1:length(kcoords2) + kcoords3 = kcoords2(i); + tspktimes = spikeTimes(ia{i}); + if mod(i,4)==1; fprintf('\n'); end + fprintf(['Saving .res file for group ', num2str(kcoords3),'. ']) + resname = fullfile([basename '.res.' num2str(kcoords3)]); + fid=fopen(resname,'w'); + fprintf(fid,'%.0f\n',tspktimes); + fclose(fid); + clear fid +end +fprintf('\n'); toc(t1) + +fprintf('\nExtracting waveforms\n') +waveforms_all = Kilosort_ExtractWaveforms(rez); +fprintf('\n'); toc(t1) + +fprintf('\nSaving .spk files to disk (waveforms)') +for i = 1:length(kcoords2) + if mod(i,4)==1; fprintf('\n'); end + fprintf(['Saving .spk for group ', num2str(kcoords2(i)),'. ']) + fid=fopen([basename,'.spk.',num2str(kcoords2(i))],'w'); + fwrite(fid,waveforms_all{i}(:),'int16'); + fclose(fid); +end +fprintf('\n'); toc(t1) + +fprintf('\nComputing PCAs') +% Starting parpool if stated in the Kilosort settings +if (rez.ops.parfor & isempty(gcp('nocreate'))); parpool; end + +for i = 1:length(kcoords2) + kcoords3 = kcoords2(i); + if mod(i,2)==1; fprintf('\n'); end + fprintf(['Computing PCAs for group ', num2str(kcoords3),'. ']) + PCAs_global = zeros(3,sum(kcoords==kcoords3),length(ia{i})); + waveforms = waveforms_all{i}; + + waveforms2 = reshape(waveforms,[size(waveforms,1)*size(waveforms,2),size(waveforms,3)]); + wranges = int64(range(waveforms2,1)); + wpowers = int64(sum(waveforms2.^2,1)/size(waveforms2,1)/100); + + % Calculating PCAs in parallel if stated in ops.parfor + if isempty(gcp('nocreate')) + for k = 1:size(waveforms,1) + PCAs_global(:,k,:) = pca(zscore(permute(waveforms(k,:,:),[2,3,1]),[],2),'NumComponents',3)'; + end + else + parfor k = 1:size(waveforms,1) + PCAs_global(:,k,:) = pca(zscore(permute(waveforms(k,:,:),[2,3,1]),[],2),'NumComponents',3)'; + end + end + fprintf(['Saving .fet files for group ', num2str(kcoords3),'. ']) + PCAs_global2 = reshape(PCAs_global,size(PCAs_global,1)*size(PCAs_global,2),size(PCAs_global,3)); + factor = (2^15)./max(abs(PCAs_global2')); + PCAs_global2 = int64(PCAs_global2 .* factor'); + + fid=fopen([basename,'.fet.',num2str(kcoords3)],'w'); + Fet = double([PCAs_global2; wranges; wpowers; spikeTimes(ia{i})']); + nFeatures = size(Fet, 1); + formatstring = '%d'; + for ii=2:nFeatures + formatstring = [formatstring,'\t%d']; + end + formatstring = [formatstring,'\n']; + + fprintf(fid, '%d\n', nFeatures); + fprintf(fid,formatstring,Fet); + fclose(fid); +end +fprintf('\n'); toc(t1) +fprintf('\nComplete!') + + function waveforms_all = Kilosort_ExtractWaveforms(rez) + % Extracts waveforms from a dat file using GPU enable filters. + % Based on the GPU enable filter from Kilosort. + % All settings and content are extracted from the rez input structure + % + % Inputs: + % rez - rez structure from Kilosort + % + % Outputs: + % waveforms_all - structure with extracted waveforms + + % Extracting content from the .rez file + ops = rez.ops; + NT = ops.NT; + if exist('ops.fbinary') == 0 + warning(['Binary file does not exist: ', ops.fbinary]) + end + d = dir(ops.fbinary); + + NchanTOT = ops.NchanTOT; + chanMap = ops.chanMap; + chanMapConn = chanMap(rez.connected>1e-6); + kcoords = ops.kcoords; + ia = rez.ia; + spikeTimes = rez.st3(:,1); + + if ispc + dmem = memory; + memfree = dmem.MemAvailableAllArrays/8; + memallocated = min(ops.ForceMaxRAMforDat, dmem.MemAvailableAllArrays) - memfree; + memallocated = max(0, memallocated); + else + memallocated = ops.ForceMaxRAMforDat; + end + ops.ForceMaxRAMforDat = 10000000000; + memallocated = ops.ForceMaxRAMforDat; + nint16s = memallocated/2; + + NTbuff = NT + 4*ops.ntbuff; + Nbatch = ceil(d.bytes/2/NchanTOT /(NT-ops.ntbuff)); + Nbatch_buff = floor(4/5 * nint16s/ops.Nchan /(NT-ops.ntbuff)); % factor of 4/5 for storing PCs of spikes + Nbatch_buff = min(Nbatch_buff, Nbatch); + + DATA =zeros(NT, NchanTOT,Nbatch_buff,'int16'); + + if isfield(ops,'fslow')&&ops.fslow .5) + %save 500MB on the SSD, can be decreased + + [~,b] = max(freespaceSSD-datsize); + + fname = [mountSSD{b} '/temp_wh_' num2str(gpuDeviceNum) '.dat']; +elseif any( (freespaceHD-datsize) > .5) + [~,b] = max(freespaceSSD-datsize); + + fname = [mountHD{b} '/temp_wh_' num2str(gpuDeviceNum) '.dat']; +else + error('NO DISK SPACE') + +end + + + + + + +end \ No newline at end of file From 5ffe0a8b5c5b0934bd1501ffd3f5cfc72941b128 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Mon, 26 Nov 2018 16:52:57 -0500 Subject: [PATCH 17/35] Updates --- KiloSortWrapper.m | 2 +- Phy2Neurosuite.m | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 87e4c05..34cd912 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -73,7 +73,7 @@ end %% % Defining SSD location if any -SSD_path = 'G:\Kilosort'; +SSD_path = 'K:\Kilosort'; if isunix fname = KiloSortLinuxDir(basename,basepath,gpuDeviceNum); diff --git a/Phy2Neurosuite.m b/Phy2Neurosuite.m index d597a46..af5813e 100644 --- a/Phy2Neurosuite.m +++ b/Phy2Neurosuite.m @@ -18,13 +18,26 @@ function Phy2Neurosuite(basepath,clustering_path) t1 = tic; cd(clustering_path) -load('rez.mat') +if exist('rez.mat') + load('rez.mat') + spikeTimes = uint64(rez.st3(:,1)); % uint64 + basename = rez.ops.basename; +elseif exist(fullfile(basepath,'ops.mat')) + rez = []; + load(fullfile(basepath,'ops.mat')) + rez.ops = ops; + spikeTimes = readNPY(fullfile(clustering_path, 'spike_times.npy')); +% load(fullfile(basepath,'chanMap.mat')) + rez.connected = ones(1,ops.NchanTOT); + basename = bz_BasenameFromBasepath(basepath); +else + disp('No rez.mat or ops.mat file exist!') +end + rez.ops.root = clustering_path; -basename = rez.ops.basename; rez.ops.fbinary = fullfile(basepath, [basename,'.dat']); rez.ops.fshigh = 500; -spikeTimes = uint64(rez.st3(:,1)); % uint64 spikeTemplates = double(readNPY(fullfile(clustering_path, 'spike_clusters.npy'))); spike_clusters = unique(spikeTemplates); cluster_ids = readNPY(fullfile(clustering_path, 'cluster_ids.npy')); From afc5422489423460dabfea60d8138f480692d89f Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Mon, 17 Dec 2018 15:20:02 -0500 Subject: [PATCH 18/35] Autoclustering script for phy --- KilosortConfiguration.m | 4 +- Phy2Neurosuite.m | 2 +- PhyAutoClustering.m | 219 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 222 insertions(+), 3 deletions(-) create mode 100644 PhyAutoClustering.m diff --git a/KilosortConfiguration.m b/KilosortConfiguration.m index 8985fde..01a3524 100644 --- a/KilosortConfiguration.m +++ b/KilosortConfiguration.m @@ -20,7 +20,7 @@ ops.Nchan = sum(connected>1e-6); % number of active channels -templatemultiplier = 8; +templatemultiplier = 8; % 8 times more templates created than Nchan ops.Nfilt = ops.Nchan*templatemultiplier - mod(ops.Nchan*templatemultiplier,32); % number of filters to use (2-4 times more than Nchan, should be a multiple of 32) ops.nt0 = round(1.6*ops.fs/1000); % window width in samples. 1.6ms at 20kH corresponds to 32 samples @@ -50,7 +50,7 @@ % the following options can improve/deteriorate results. % when multiple values are provided for an option, the first two are beginning and ending anneal values, % the third is the value used in the final pass. -ops.Th = [6 12 12]; % threshold for detecting spikes on template-filtered data ([6 12 12]) +ops.Th = [6 10 10]; % threshold for detecting spikes on template-filtered data ([6 12 12]) ops.lam = [12 40 40]; % large means amplitudes are forced around the mean ([10 30 30]) ops.nannealpasses = 4; % should be less than nfullpasses (4) ops.momentum = 1./[20 800]; % start with high momentum and anneal (1./[20 1000]) diff --git a/Phy2Neurosuite.m b/Phy2Neurosuite.m index af5813e..371e892 100644 --- a/Phy2Neurosuite.m +++ b/Phy2Neurosuite.m @@ -163,7 +163,7 @@ function Phy2Neurosuite(basepath,clustering_path) % Extracting content from the .rez file ops = rez.ops; NT = ops.NT; - if exist('ops.fbinary') == 0 + if exist(ops.fbinary) == 0 warning(['Binary file does not exist: ', ops.fbinary]) end d = dir(ops.fbinary); diff --git a/PhyAutoClustering.m b/PhyAutoClustering.m new file mode 100644 index 0000000..a2d70a1 --- /dev/null +++ b/PhyAutoClustering.m @@ -0,0 +1,219 @@ +function PhyAutoClustering(clusteringpath,varargin) +% AutoClustering automtically cleans Kilosort output in phy format defined +% by a clusteringpath. +% +% INPUT: +% clusteringpath: char +% +% optional: +% AutoClustering(clusteringpath,elec,dim) +% where dim is the number of channels in electro group (if not +% defined, will read the first line of the fet file +% +% AutoClustering is meant to clean the output of KlustaKwik. The first +% thing it does is to separate electrical artifacts and MUA from putative +% isolated units. To do so, it sorts out units which have no clear +% refractory period (based on Hill, Mehta and Kleinfeld, J Neurosci., +% 2012). Threshold can be set in the parameter section of this file +% ("Rogue spike threshold"). Then, it separates electrical +% artifats from MUA based on the assumption that electrical artifacts are +% highly correlated on the different channels: the average waveform of at +% least one channel has to be different from the across-channel average +% waveform by a certrain amount of total variance (can be set in the +% parameter section, "Deviation from average spike threshold") +% +% +% Once the program has determined which of the clusters are putative +% isolated units, it tries to merge them based on waveform similarity +% (mahalanobis distance) and quality of the refractory period in the new +% merged cluster (or "Inter Common Spike Interval" from MS Fee et al. J +% Neurosci. Meth., 1996) +% +% Original script by Adrien Peyrache, 2012. +% Many modifications for Phy processing pipeline by +% Yuta Senzai and Peter Petersen + + +% if ~isempty(varargin) +% dim = varargin{1}; +% dim = dim(:); +% if any(double(int16(dim))~=dim) +% error('Number of dimensions must be an integer') +% end +% +% if size(dim,1) ~= length(elec) && length(dim) ~=1 +% error('Number of dimensions must be a vector of the same length as electrode vector or a single value') +% end +% if length(dim) == 1 +% dim = dim*ones(length(elec),1); +% end +% else +% dim = zeros(length(elec),1); +% end + +% Refractory period in msec +tR = 1.5; % 1.5 +% Censored period in msec (specific to the spike detection process) +tC = 0.85; +% Rogue spike threshold (for MUA); value between 0 an 1 +%rogThres = 0.25; +rogThres = 0.33; + +% Relative deviation (from 0 to 1) from average spike threshold (for electrical artifacts) +%devThres = 0.25; +% =1000 => bypass it +% devThres = 1000; +rThres = 0.7; +mprThres = 2; + +% Artifact removal threshold +amplitude_thr = 50; +mahal_thr = 18; + +% Load spike timing +cd(clusteringpath) +dirname = ['PhyAutoClustering_', datestr(clock,'yyyy-mm-dd_HHMMSS')]; +mkdir(dirname) +copyfile(fullfile(clusteringpath, 'spike_clusters.npy'), fullfile(clusteringpath, dirname,'spike_clusters.npy')) +if exist(fullfile(clusteringpath, 'cluster_group.tsv')) + copyfile(fullfile(clusteringpath, 'cluster_group.tsv'), fullfile(clusteringpath, dirname, 'cluster_group.tsv')) +end + +clu = readNPY(fullfile(clusteringpath, 'spike_clusters.npy')); +clu = double(clu); +cids = unique(clu); + +wav_all_orig = readNPY(fullfile(clusteringpath,'templates.npy')); +wav_all_orig2 = permute(wav_all_orig,[2,3,1]); +ch_indx = []; +for i = 1:size(wav_all_orig2,3) + [~,ch_indx(i)] = max(max(wav_all_orig2(:,:,i))-min(wav_all_orig2(:,:,i))); +end +channel_shanks = readNPY(fullfile(clusteringpath, 'channel_shanks.npy')); + +ch_indx2 = {}; +for j = unique(channel_shanks) + ch_indx2{j} = find(channel_shanks == j); +end + +% Removing spikes with large artifacts +disp('Removing spikes with large artifacts') +spike_amplitudes = readNPY(fullfile(clusteringpath, 'amplitudes.npy')); + +spike_amplitudes = nanconv(spike_amplitudes',ones(1,20),'edge'); +indx = find(spike_amplitudes>amplitude_thr); +disp([num2str(length(indx)),' artifacts detected']) + +clu2 = clu; +artifact_clusters = []; +for j = unique(channel_shanks) + indx22 = find(ismember(ch_indx(clu(indx)+1),ch_indx2{j})); + clu2(indx(indx22)) = max(clu)+j; + artifact_clusters = [artifact_clusters,max(clu)+j]; +end +clu = clu2; +spike_PCAs = double(readNPY(fullfile(clusteringpath, 'pc_features.npy'))); + +% Mahal artifact removal +disp('Removing outliers by Mahalanobis theshold...') +spike_clusters = clu; +mahal_outlier_clusters = []; +spikes_removed = 0; +for i = 1:length(cids) + cluster_id = cids(i); + indexes = find(spike_clusters==cluster_id); + if length(indexes)>100 + indexes1 = spike_PCAs(indexes,:,:); + indexes2 = reshape(indexes1,[size(indexes1,1),size(indexes1,2)*size(indexes1,3)]); + test2 = mahal(indexes2,indexes2); + test3 = find(test2>mahal_thr^2); + mahal_outlier_clusters = [mahal_outlier_clusters,max(spike_clusters)+1]; + spikes_removed = spikes_removed+length(test3); + spike_clusters(indexes(test3)) = mahal_outlier_clusters(end); + end +end +clu = spike_clusters; +disp([num2str(length(mahal_outlier_clusters)),' units cleaned by Mahal outlier detection. Spikes removed: ',num2str(spikes_removed)]) +writeNPY(uint32(clu), fullfile(clusteringpath, 'spike_clusters.npy')); + +% Loading rez.mat for sampling rate +disp('Loading rez.mat') +load(fullfile(clusteringpath,'rez.mat')) +sr = rez.ops.fs; + +res_int = readNPY(fullfile(clusteringpath,'spike_times.npy')); +res = double(res_int)/sr; + +wav_all = wav_all_orig2; + +disp('Classifying noise/mua') +meanR = []; +fractRogue = []; +maxPwRatio = []; +for ii=1:length(cids) + spktime = res(clu==cids(ii)); + if ~isempty(spktime) + % dim = channel_shanks(ch_indx(cids(ii)+1)); + wav = squeeze(wav_all(13:end,:,ii)); + wav = wav(:,find(any(wav))); + dim = size(wav,2); + + [R,~] = corrcoef(wav); + meanR_cur = (sum(sum(R)) - dim) /(dim*(dim-1)); + meanR = [meanR; meanR_cur]; + + maxPwRatio_cur = max(abs(wav(11,:)))/mean(abs(wav(11,:))); + maxPwRatio = [maxPwRatio; maxPwRatio_cur]; + + [ccgR,t] = CCG(spktime,ones(size(spktime)),'binsize',.0005,'duration',.06); + indx3 = find(t > -0.0015 & t < 0.0015); + spkRef = mean(ccgR(indx3)); % refractory period: -1.5ms to 1.5ms + spkMean = mean(ccgR(round(indx3(1)/2):indx3(1)-1)); + % l = FractionRogueSpk(spktime,tR,tC); + l = spkRef/spkMean; + fractRogue = [fractRogue;l]; + else + maxPwRatio = [maxPwRatio; 0]; + meanR = [meanR; 0]; + fractRogue = [fractRogue;0]; + end +end + +% Here we compute # of spike per cell. Some code for the errormatrix fails +% when the cluster is defined by only a few samples. We'll put a +% threshopld a bit later on the total # of spikes. +h = hist(clu,unique(clu)); +h = h(:); +h = h(1:length(meanR)); +% Definition of cluster 0 (noiseIx) and cluster 1 (muaIx) +% Outliers of total spike power (putative electrical artifacts) not imlemented yet +noiseIx = find((meanR >= rThres & maxPwRatio < mprThres)|h<100); +muaIx = find(fractRogue>rogThres & ~(meanR >= rThres & maxPwRatio < mprThres) & h>=100); +goodIx = find(fractRogue<=rogThres & ~(meanR >= rThres & maxPwRatio < mprThres) & h>=100); % 100 or samlenum + +% Saving clusters to cluster_group.tsv +fid = fopen(fullfile(clusteringpath,'cluster_group.tsv'),'w'); +fwrite(fid, sprintf('cluster_id\t%s\r\n', 'group')); +for ii=1:length(cids) + if any(clu==cids(ii)) + if any(goodIx==ii) +% fwrite(fid, sprintf('%d\t%s\r\n', cids(ii), 'good')); + elseif any(muaIx==ii) + fwrite(fid, sprintf('%d\t%s\r\n', cids(ii), 'mua')); + elseif any(noiseIx==ii) + fwrite(fid, sprintf('%d\t%s\r\n', cids(ii), 'noise')); + end + end +end +for jj = 1:length(artifact_clusters) + fwrite(fid, sprintf('%d\t%s\r\n', artifact_clusters(jj), 'artifacts')); +end +mahal_outlier_clusters = unique(mahal_outlier_clusters); +for jj = 1:length(mahal_outlier_clusters) + fwrite(fid, sprintf('%d\t%s\r\n', mahal_outlier_clusters(jj), 'mua')); +end +fclose(fid); + +save(fullfile(clusteringpath,'autoclusta_params.mat'),'meanR','maxPwRatio','fractRogue','noiseIx','muaIx'); +disp('AutoClustering complete.') +end From 41ffe04c8ec15745fe1010bbd53322719bf44ed1 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Tue, 8 Jan 2019 15:30:50 -0500 Subject: [PATCH 19/35] bugfixes. --- KiloSortWrapper.m | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 1f91b28..c4d3630 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -65,7 +65,7 @@ ec = exist('config'); if ec ~= 1 disp('Running Kilosort with standard settings') - ops = KiloSortConfiguration(XMLFilePath); + ops = KilosortConfiguration(XMLFilePath); else disp('Running Kilosort with user specific settings') config_string = str2func(['KiloSortConfiguration_' config_version]); @@ -89,10 +89,10 @@ ops.fproc = fullfile(SSD_path, [basename,'_temp_wh.dat']); else warning('Not sufficient space on SSD drive. Creating local dat file instead') - ops.fproc = fullfile(rootpath,'temp_wh.dat'); + ops.fproc = fullfile(basepath,'temp_wh.dat'); end else - ops.fproc = fullfile(rootpath,'temp_wh.dat'); + ops.fproc = fullfile(basepath,'temp_wh.dat'); end end From fe511e1b54c145627ca2b61adfd700a1f2548956 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Mon, 21 Jan 2019 20:12:50 -0500 Subject: [PATCH 20/35] Create nanconv.m Added dependency for AutoClustering script --- private/nanconv.m | 114 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 private/nanconv.m diff --git a/private/nanconv.m b/private/nanconv.m new file mode 100644 index 0000000..d077433 --- /dev/null +++ b/private/nanconv.m @@ -0,0 +1,114 @@ +function c = nanconv(a, k, varargin) +% NANCONV Convolution in 1D or 2D ignoring NaNs. +% C = NANCONV(A, K) convolves A and K, correcting for any NaN values +% in the input vector A. The result is the same size as A (as though you +% called 'conv' or 'conv2' with the 'same' shape). +% +% C = NANCONV(A, K, 'param1', 'param2', ...) specifies one or more of the following: +% 'edge' - Apply edge correction to the output. +% 'noedge' - Do not apply edge correction to the output (default). +% 'nanout' - The result C should have NaNs in the same places as A. +% 'nonanout' - The result C should have ignored NaNs removed (default). +% Even with this option, C will have NaN values where the +% number of consecutive NaNs is too large to ignore. +% '2d' - Treat the input vectors as 2D matrices (default). +% '1d' - Treat the input vectors as 1D vectors. +% This option only matters if 'a' or 'k' is a row vector, +% and the other is a column vector. Otherwise, this +% option has no effect. +% +% NANCONV works by running 'conv2' either two or three times. The first +% time is run on the original input signals A and K, except all the +% NaN values in A are replaced with zeros. The 'same' input argument is +% used so the output is the same size as A. The second convolution is +% done between a matrix the same size as A, except with zeros wherever +% there is a NaN value in A, and ones everywhere else. The output from +% the first convolution is normalized by the output from the second +% convolution. This corrects for missing (NaN) values in A, but it has +% the side effect of correcting for edge effects due to the assumption of +% zero padding during convolution. When the optional 'noedge' parameter +% is included, the convolution is run a third time, this time on a matrix +% of all ones the same size as A. The output from this third convolution +% is used to restore the edge effects. The 'noedge' parameter is enabled +% by default so that the output from 'nanconv' is identical to the output +% from 'conv2' when the input argument A has no NaN values. +% +% See also conv, conv2 +% +% AUTHOR: Benjamin Kraus (bkraus@bu.edu, ben@benkraus.com) +% Copyright (c) 2013, Benjamin Kraus +% $Id: nanconv.m 4861 2013-05-27 03:16:22Z bkraus $ + +% Process input arguments +for arg = 1:nargin-2 + switch lower(varargin{arg}) + case 'edge'; edge = true; % Apply edge correction + case 'noedge'; edge = false; % Do not apply edge correction + case {'same','full','valid'}; shape = varargin{arg}; % Specify shape + case 'nanout'; nanout = true; % Include original NaNs in the output. + case 'nonanout'; nanout = false; % Do not include NaNs in the output. + case {'2d','is2d'}; is1D = false; % Treat the input as 2D + case {'1d','is1d'}; is1D = true; % Treat the input as 1D + end +end + +% Apply default options when necessary. +if(exist('edge','var')~=1); edge = false; end +if(exist('nanout','var')~=1); nanout = false; end +if(exist('is1D','var')~=1); is1D = false; end +if(exist('shape','var')~=1); shape = 'same'; +elseif(~strcmp(shape,'same')) + error([mfilename ':NotImplemented'],'Shape ''%s'' not implemented',shape); +end + +% Get the size of 'a' for use later. +sza = size(a); + +% If 1D, then convert them both to columns. +% This modification only matters if 'a' or 'k' is a row vector, and the +% other is a column vector. Otherwise, this argument has no effect. +if(is1D); + if(~isvector(a) || ~isvector(k)) + error('MATLAB:conv:AorBNotVector','A and B must be vectors.'); + end + a = a(:); k = k(:); +end + +% Flat function for comparison. +o = ones(size(a)); + +% Flat function with NaNs for comparison. +on = ones(size(a)); + +% Find all the NaNs in the input. +n = isnan(a); + +% Replace NaNs with zero, both in 'a' and 'on'. +a(n) = 0; +on(n) = 0; + +% Check that the filter does not have NaNs. +if(any(isnan(k))); + error([mfilename ':NaNinFilter'],'Filter (k) contains NaN values.'); +end + +% Calculate what a 'flat' function looks like after convolution. +if(any(n(:)) || edge) + flat = conv2(on,k,shape); +else flat = o; +end + +% The line above will automatically include a correction for edge effects, +% so remove that correction if the user does not want it. +if(any(n(:)) && ~edge); flat = flat./conv2(o,k,shape); end + +% Do the actual convolution +c = conv2(a,k,shape)./flat; + +% If requested, replace output values with NaNs corresponding to input. +if(nanout); c(n) = NaN; end + +% If 1D, convert back to the original shape. +if(is1D && sza(1) == 1); c = c.'; end + +end From 351c94a7a02edbd51f73e05050b44e37b175761b Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Mon, 18 Feb 2019 08:36:40 -0500 Subject: [PATCH 21/35] Bug fixes related to linux specific code --- KiloSortWrapper.m | 29 ++++++++++++----------------- Phy2Neurosuite.m | 32 +++++++++++++++++++++++--------- 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index c4d3630..44a2440 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -73,27 +73,22 @@ clear config_string; end -%% % Defining SSD location if any +%% % Define SSD location if any. Comment the line if no SSD is present SSD_path = 'K:\Kilosort'; -if isunix - fname = KiloSortLinuxDir(basename,basepath,gpuDeviceNum); - ops.fproc = fname; -else - if isdir(SSD_path) - FileObj = java.io.File(SSD_path); - free_bytes = FileObj.getFreeSpace; - dat_file = dir(fullfile(basepath,[basename,'.dat'])); - if dat_file.bytes*1.11e-6); - kcoords = ops.kcoords; + chanMapConn = readNPY(fullfile('channel_map.npy'))+1; + kcoords = readNPY(fullfile('channel_shanks.npy')); % ops.kcoords; ia = rez.ia; spikeTimes = rez.st3(:,1); @@ -212,7 +226,7 @@ function Phy2Neurosuite(basepath,clustering_path) fid = fopen(ops.fbinary, 'r'); waveforms_all = []; - kcoords2 = unique(ops.kcoords); + kcoords2 = unique(kcoords); % ops.kcoords channel_order = {}; indicesTokeep = {}; @@ -223,7 +237,7 @@ function Phy2Neurosuite(basepath,clustering_path) kcoords3 = kcoords2(i); waveforms_all{i} = zeros(sum(kcoords==kcoords3),ops.nt0,size(rez.ia{i},1)); if exist('xml') - [channel_order,channel_index] = sort(xml.AnatGrps(kcoords2(i)).Channels+1); + [channel_order,channel_index] = sort(xml.SpkGrps(kcoords2(i)).Channels+1); [~,indicesTokeep{i},~] = intersect(chanMapConn,channel_order); %indicesTokeep{i} = connected_index(indicesTokeep{i}); From 1b8d7e3083c0b49ada48ab14732d49ca6cc69e1f Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Mon, 18 Mar 2019 20:50:35 -0400 Subject: [PATCH 22/35] Create loadClusteringData.m Adding function to load clustered data from either Phy or Neurosuite. This removes dependencies on other repositories. Spike data is loaded into memory using the buzcode compatible output format. A mat file is further created that allows for faster data loading in the subsequent calls. --- loadClusteringData.m | 189 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 loadClusteringData.m diff --git a/loadClusteringData.m b/loadClusteringData.m new file mode 100644 index 0000000..5b0bfc4 --- /dev/null +++ b/loadClusteringData.m @@ -0,0 +1,189 @@ +function spikes = loadClusteringData(baseName,clusteringMethod,clusteringPath,varargin) +% load clustered data from multiple pipelines [Phy, Klustakwik/Neurosuite] +% Buzcode compatible output. Saves output to a basename.spikes.cellinfo.mat file +% baseName: basename of the recording +% clusteringMethod: clustering method to handle different pipelines: ['phy','klustakwik'/'neurosuite'] +% clusteringPath: Path to the clustered data +% See description of varargin below + +% by Peter Petersen +% petersen.peter@gmail.com + +p = inputParser; +addParameter(p,'shanks',nan,@isnumeric); % shanks: Loading only a subset of shanks (only applicable to Klustakwik) +addParameter(p,'raw_clusters',false,@islogical); % raw_clusters: Load only a subset of clusters (might not work anymore as I have not used it for a long time) +addParameter(p,'forceReload',false,@islogical); % Reload spikes from original format? +addParameter(p,'saveMat',true,@islogical); % Save spikes to mat file? +addParameter(p,'getWaveforms',true,@islogical); % Get average waveforms? Only in effect for neurosuite/klustakwik format +parse(p,varargin{:}) + +shanks = p.Results.shanks; +raw_clusters = p.Results.raw_clusters; +forceReload = p.Results.forceReload; +saveMat = p.Results.saveMat; +getWaveforms = p.Results.getWaveforms; + +if exist(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat'])) & ~forceReload + load(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat'])) + if isfield(spikes,'ts') && (~isfield(spikes,'processinginfo') || (isfield(spikes,'processinginfo') && spikes.processinginfo.version < 3 && strcmp(spikes.processinginfo.function,'loadClusteringData') )) + forceReload = true; + disp('spikes.mat structure not up to date. Reloading spikes.') + else + disp('Loading existing spikes file') + end +end + +if forceReload + switch lower(clusteringMethod) + case {'klustakwik', 'neurosuite'} + disp('Loading Klustakwik clustered data') + unit_nb = 0; + spikes = []; + shanks_new = []; + if isnan(shanks) + fileList = dir(fullfile(clusteringPath,[baseName,'.res.*'])); + fileList = {fileList.name}; + for i = 1:length(fileList) + temp = strsplit(fileList{i},'.res.'); + shanks_new = [shanks_new,str2num(temp{2})]; + end + shanks = sort(shanks_new); + end + for shank = shanks + disp(['Loading shank #' num2str(shank) '/' num2str(length(shanks)) ]) + if ~raw_clusters + xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); + cluster_index = load(fullfile(clusteringPath, [baseName '.clu.' num2str(shank)])); + time_stamps = load(fullfile(clusteringPath,[baseName '.res.' num2str(shank)])); + if getWaveforms + fname = fullfile(clusteringPath,[baseName '.spk.' num2str(shank)]); + f = fopen(fname,'r'); + waveforms = 0.000195 * double(fread(f,'int16')); + samples = size(waveforms,1)/size(time_stamps,1); + electrodes = size(xml.ElecGp{shank},2); + waveforms = reshape(waveforms, [electrodes,samples/electrodes,length(waveforms)/samples]); + end + else + cluster_index = load(fullfile(clusteringPath, 'OriginalClus', [baseName '.clu.' num2str(shank)])); + time_stamps = load(fullfile(clusteringPath, 'OriginalClus', [baseName '.res.' num2str(shank)])); + end + cluster_index = cluster_index(2:end); + nb_clusters = unique(cluster_index); + nb_clusters2 = nb_clusters(nb_clusters > 1); + for i = 1:length(nb_clusters2) + unit_nb = unit_nb +1; + spikes.ts{unit_nb} = time_stamps(cluster_index == nb_clusters2(i)); + spikes.times{unit_nb} = spikes.ts{unit_nb}/xml.SampleRate; + spikes.shankID(unit_nb) = shank; + spikes.UID(unit_nb) = unit_nb; + spikes.cluID(unit_nb) = nb_clusters2(i); + spikes.cluster_index(unit_nb) = nb_clusters2(i); + spikes.total(unit_nb) = length(spikes.ts{unit_nb}); + if getWaveforms + spikes.filtWaveform_all{unit_nb} = mean(waveforms(:,:,cluster_index == nb_clusters2(i)),3); + spikes.filtWaveform_all_std{unit_nb} = permute(std(permute(waveforms(:,:,cluster_index == nb_clusters2(i)),[3,1,2])),[2,3,1]); + [~,index1] = max(max(spikes.filtWaveform_all{unit_nb}') - min(spikes.filtWaveform_all{unit_nb}')); + spikes.maxWaveformCh(unit_nb) = xml.ElecGp{shank}(index1); % index 0; + spikes.maxWaveformCh1(unit_nb) = xml.ElecGp{shank}(index1)+1; % index 1; + spikes.filtWaveform{unit_nb} = spikes.filtWaveform_all{unit_nb}(index1,:); + spikes.filtWaveform_std{unit_nb} = spikes.filtWaveform_all_std{unit_nb}(index1,:); + spikes.peakVoltage(unit_nb) = max(spikes.filtWaveform{unit_nb}) - min(spikes.filtWaveform{unit_nb}); + end + end + end + + clear cluster_index time_stamps + + case 'phy' + disp('Loading Phy clustered data') + xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); + spike_cluster_index = readNPY(fullfile(clusteringPath, 'spike_clusters.npy')); + spike_times = readNPY(fullfile(clusteringPath, 'spike_times.npy')); + spike_amplitudes = readNPY(fullfile(clusteringPath, 'amplitudes.npy')); + spike_clusters = unique(spike_cluster_index); + filename1 = fullfile(clusteringPath,'cluster_group.tsv'); + filename2 = fullfile(clusteringPath,'cluster_groups.csv'); + if exist(fullfile(clusteringPath, 'cluster_ids.npy')) + cluster_ids = readNPY(fullfile(clusteringPath, 'cluster_ids.npy')); + unit_shanks = readNPY(fullfile(clusteringPath, 'shanks.npy')); + peak_channel = readNPY(fullfile(clusteringPath, 'peak_channel.npy'))+1; + end + + if exist(filename1) == 2 + filename = filename1; + elseif exist(filename2) == 2 + filename = filename2; + else + disp('Phy: No cluster group file found') + end + delimiter = '\t'; + startRow = 2; + formatSpec = '%f%s%[^\n\r]'; + fileID = fopen(filename,'r'); + dataArray = textscan(fileID, formatSpec, 'Delimiter', delimiter, 'HeaderLines' ,startRow-1, 'ReturnOnError', false); + fclose(fileID); + spikes = []; + j = 1; + for i = 1:length(dataArray{1}) + if raw_clusters == 0 + if strcmp(dataArray{2}{i},'good') + if sum(spike_cluster_index == dataArray{1}(i))>0 + spikes.ids{j} = find(spike_cluster_index == dataArray{1}(i)); + spikes.ts{j} = double(spike_times(spikes.ids{j})); + spikes.times{j} = spikes.ts{j}/xml.SampleRate; + spikes.cluID(j) = dataArray{1}(i); + spikes.UID(j) = j; + if exist('cluster_ids') + cluster_id = find(cluster_ids == spikes.cluID(j)); + spikes.shankID(j) = double(unit_shanks(cluster_id)); + spikes.maxWaveformCh1(j) = double(peak_channel(cluster_id)); % index 1; + spikes.maxWaveformCh(j) = double(peak_channel(cluster_id))-1; % index 0; + end + spikes.total(j) = length(spikes.ts{j}); + spikes.amplitudes{j} = double(spike_amplitudes(spikes.ids{j})); + j = j+1; + end + end + else + spikes.ids{j} = find(spike_cluster_index == dataArray{1}(i)); + spikes.ts{j} = double(spike_times(spikes.ids{j})); + spikes.cluID(j) = dataArray{1}(i); + spikes.UID(j) = j; + spikes.amplitudes{j} = double(spike_amplitudes(spikes.ids{j}))'; + j = j+1; + end + end + case 'klustaViewa' + disp('Loading KlustaViewa clustered data') + units_to_exclude = []; + [spikes,~] = ImportKwikFile(baseName,clusteringPath,shanks,0,units_to_exclude); + end + + spikes.sessionName = baseName; + + % Generate spindices matrics + spikes.numcells = length(spikes.UID); + for cc = 1:spikes.numcells + groups{cc}=spikes.UID(cc).*ones(size(spikes.times{cc})); + end + + if spikes.numcells>0 + alltimes = cat(1,spikes.times{:}); groups = cat(1,groups{:}); %from cell to array + [alltimes,sortidx] = sort(alltimes); groups = groups(sortidx); %sort both + spikes.spindices = [alltimes groups]; + end + + % Attaching info about how the spikes structure was generated + spikes.processinginfo.function = 'loadClusteringData'; + spikes.processinginfo.version = 3.1; + spikes.processinginfo.date = now; + spikes.processinginfo.params.forceReload = forceReload; + spikes.processinginfo.params.shanks = shanks; + spikes.processinginfo.params.raw_clusters = raw_clusters; + spikes.processinginfo.params.getWaveforms = getWaveforms; + + % Saving output to a buzcode compatible spikes file. + if saveMat + save(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat']),'spikes') + end +end From e1377a8a3777a7b7a53ec7a9e9c4559a9284b91a Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Sun, 24 Mar 2019 11:18:24 -0400 Subject: [PATCH 23/35] Autoclustering implemented in the wrapper Now all settings are done with the input check, including SSD_path and CreateSubdirectory. The autoclustering is turned off by default. --- KiloSortWrapper.m | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 44a2440..2d84016 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -27,6 +27,7 @@ % (at your option) any later version. disp('Running Kilosort spike sorting with the Buzsaki lab wrapper') + %% If function is called without argument p = inputParser; basepath = cd; @@ -35,12 +36,18 @@ addParameter(p,'basepath',basepath,@ischar) addParameter(p,'basename',basename,@ischar) addParameter(p,'GPU_id',1,@isnumeric) +addParameter(p,'SSD_path','K:\Kilosort',@ischar) +addParameter(p,'CreateSubdirectory',1,@isnumeric) +addParameter(p,'performAutoCluster',0,@isnumeric) parse(p,varargin{:}) basepath = p.Results.basepath; basename = p.Results.basename; GPU_id = p.Results.GPU_id; +SSD_path = p.Results.SSD_path; +CreateSubdirectory = p.Results.CreateSubdirectory; +performAutoCluster = p.Results.performAutoCluster; cd(basepath) @@ -74,8 +81,6 @@ end %% % Define SSD location if any. Comment the line if no SSD is present -SSD_path = 'K:\Kilosort'; - if isdir(SSD_path) FileObj = java.io.File(SSD_path); free_bytes = FileObj.getFreeSpace; @@ -113,7 +118,7 @@ %% posthoc merge templates (under construction) % save matlab results file -CreateSubdirectory = 1; + if CreateSubdirectory timestamp = ['Kilosort_' datestr(clock,'yyyy-mm-dd_HHMMSS')]; savepath = fullfile(basepath, timestamp); @@ -134,6 +139,10 @@ disp('Converting to Phy format') rezToPhy_KSW(rez); end +% AutoCluster the Phy output +if performAutoCluster + PhyAutoClustering(savepath); +end %% export Neurosuite files if ops.export.neurosuite @@ -155,3 +164,4 @@ reset(gpudev) gpuDevice([]) disp('Kilosort Processing complete') + From d19fb86c993d289a67786626b634b2214f06ef92 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Tue, 26 Mar 2019 11:17:03 -0400 Subject: [PATCH 24/35] Small changes and bug fixes --- KiloSortWrapper.m | 50 ++++++++++++++++++----------------------- KilosortConfiguration.m | 2 +- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 2d84016..fea826f 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -4,18 +4,12 @@ % % USAGE % -% KiloSortWrapper() -% Should be run from the data folder, and file basenames are the -% same as the name as current directory +% KiloSortWrapper +% Run from data folder. File basenames must be the +% same as the name as current folder % % KiloSortWrapper(varargin) -% -% INPUTS -% basepath path to the folder containing the data -% basename file basenames (of the dat and xml files) -% config Specify a configuration file to use from the -% ConfigurationFiles folder. e.g. 'Omid' -% GPU_id Specify the GPU id +% Check varargin description below when input parameters are parsed % % Dependencies: KiloSort (https://github.com/cortex-lab/KiloSort) % @@ -25,20 +19,21 @@ % it under the terms of the GNU General Public License as published by % the Free Software Foundation; either version 2 of the License, or % (at your option) any later version. -disp('Running Kilosort spike sorting with the Buzsaki lab wrapper') +disp('Running Kilosort spike sorting with the Buzsaki lab wrapper') -%% If function is called without argument +%% Parsing inputs p = inputParser; basepath = cd; [~,basename] = fileparts(basepath); -addParameter(p,'basepath',basepath,@ischar) -addParameter(p,'basename',basename,@ischar) -addParameter(p,'GPU_id',1,@isnumeric) -addParameter(p,'SSD_path','K:\Kilosort',@ischar) -addParameter(p,'CreateSubdirectory',1,@isnumeric) -addParameter(p,'performAutoCluster',0,@isnumeric) +addParameter(p,'basepath',basepath,@ischar) % path to the folder containing the data +addParameter(p,'basename',basename,@ischar) % file basenames (of the dat and xml files) +addParameter(p,'GPU_id',1,@isnumeric) % Specify the GPU_id +addParameter(p,'SSD_path','K:\Kilosort',@ischar) % Path to SSD disk. Make it empty to disable SSD +addParameter(p,'CreateSubdirectory',1,@isnumeric) % Puts the Kilosort output into a subfolder +addParameter(p,'performAutoCluster',0,@isnumeric) % Performs PhyAutoCluster once Kilosort is complete when exporting to Phy +addParameter(p,'config','',@ischar) % Specify a configuration file to use from the ConfigurationFiles folder. e.g. 'Omid' parse(p,varargin{:}) @@ -48,6 +43,7 @@ SSD_path = p.Results.SSD_path; CreateSubdirectory = p.Results.CreateSubdirectory; performAutoCluster = p.Results.performAutoCluster; +config = p.Results.config; cd(basepath) @@ -66,11 +62,8 @@ %% Loading configurations XMLFilePath = fullfile(basepath, [basename '.xml']); -% if exist(fullfile(basepath,'StandardConfig.m'),'file') %this should actually be unnecessary -% addpath(basepath); -% end -ec = exist('config'); -if ec ~= 1 + +if isempty(config) disp('Running Kilosort with standard settings') ops = KilosortConfiguration(XMLFilePath); else @@ -80,7 +73,7 @@ clear config_string; end -%% % Define SSD location if any. Comment the line if no SSD is present +%% % Checks SSD location for sufficient space if isdir(SSD_path) FileObj = java.io.File(SSD_path); free_bytes = FileObj.getFreeSpace; @@ -138,10 +131,11 @@ if ops.export.phy disp('Converting to Phy format') rezToPhy_KSW(rez); -end -% AutoCluster the Phy output -if performAutoCluster - PhyAutoClustering(savepath); + + % AutoClustering the Phy output + if performAutoCluster + PhyAutoClustering(savepath); + end end %% export Neurosuite files diff --git a/KilosortConfiguration.m b/KilosortConfiguration.m index 01a3524..778819f 100644 --- a/KilosortConfiguration.m +++ b/KilosortConfiguration.m @@ -59,7 +59,7 @@ ops.splitT = .1; % lower threshold for splitting (.1) % options for initializing spikes from data -ops.initialize = 'no'; %'fromData' or 'no' +ops.initialize = 'fromData'; %'fromData' or 'no' ops.spkTh = -4; % spike threshold in standard deviations (4) ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1]) ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6]) From c04b6c8e8bf4626f73b971bca356a76994a5e5e4 Mon Sep 17 00:00:00 2001 From: Brendon Watson Date: Fri, 29 Mar 2019 23:04:13 -0400 Subject: [PATCH 25/35] Fixing bugs for non-buzsaki users Hard drive stuff rez2phy call Waveform extraction issue --- Phy2Neurosuite.m | 30 +++++++++++++++++++----------- private/KiloSortLinuxDir.m | 2 -- rezToPhy_KSW.m | 2 +- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/Phy2Neurosuite.m b/Phy2Neurosuite.m index fee8eb2..8a57fdc 100644 --- a/Phy2Neurosuite.m +++ b/Phy2Neurosuite.m @@ -18,10 +18,10 @@ function Phy2Neurosuite(basepath,clustering_path) % petersen.peter@gmail.com t1 = tic; -cd(clustering_path) - -if exist('rez.mat') - load('rez.mat') +% cd(clustering_path) +rezpath = fullfile(clustering_path,'rez.mat'); +if exist(rezpath) + load(rezpath) spikeTimes = uint64(rez.st3(:,1)); % uint64 if isfield(rez.ops,'basename') basename = rez.ops.basename; @@ -235,12 +235,16 @@ function Phy2Neurosuite(basepath,clustering_path) for i = 1:length(kcoords2) kcoords3 = kcoords2(i); - waveforms_all{i} = zeros(sum(kcoords==kcoords3),ops.nt0,size(rez.ia{i},1)); - if exist('xml') - [channel_order,channel_index] = sort(xml.SpkGrps(kcoords2(i)).Channels+1); - [~,indicesTokeep{i},~] = intersect(chanMapConn,channel_order); - - %indicesTokeep{i} = connected_index(indicesTokeep{i}); + if i<=length(rez.ia)%case where no clus in last group... like if last group was non-ephys + waveforms_all{i} = zeros(sum(kcoords==kcoords3),ops.nt0,size(rez.ia{i},1)); + if exist('xml') + [channel_order,channel_index] = sort(xml.AnatGrps(kcoords2(i)).Channels+1); + [~,indicesTokeep{i},~] = intersect(chanMapConn,channel_order); + + %indicesTokeep{i} = connected_index(indicesTokeep{i}); + end + else + kcoords2(i) = []; end end @@ -272,7 +276,11 @@ function Phy2Neurosuite(basepath,clustering_path) buff(:, nsampcurr+1:NTbuff) = repmat(buff(:,nsampcurr), 1, NTbuff-nsampcurr); end if ops.GPU - dataRAW = gpuArray(buff); + try%control for if gpu is busy + dataRAW = gpuArray(buff); + catch + dataRAW = buff; + end else dataRAW = buff; end diff --git a/private/KiloSortLinuxDir.m b/private/KiloSortLinuxDir.m index 765c876..e8f5759 100644 --- a/private/KiloSortLinuxDir.m +++ b/private/KiloSortLinuxDir.m @@ -63,14 +63,12 @@ mountSSD = mnt(SSD); freespaceSSD = freespace(SSD); - mountHD = mnt(~SSD); freespaceHD = freespace(~SSD); if any( (freespaceSSD-datsize) > .5) %save 500MB on the SSD, can be decreased - [~,b] = max(freespaceSSD-datsize); fname = [mountSSD{b} '/temp_wh_' num2str(gpuDeviceNum) '.dat']; diff --git a/rezToPhy_KSW.m b/rezToPhy_KSW.m index 68f67f7..5d42ee6 100644 --- a/rezToPhy_KSW.m +++ b/rezToPhy_KSW.m @@ -1,5 +1,5 @@ function [spikeTimes, clusterIDs, amplitudes, templates, templateFeatures, ... - templateFeatureInds, pcFeatures, pcFeatureInds] = rezToPhy(rez,savepath) + templateFeatureInds, pcFeatures, pcFeatureInds] = rezToPhy_KSW(rez,savepath) % pull out results from kilosort's rez to either return to workspace or to % save in the appropriate format for the phy GUI to run on. If you provide % a savePath it should be a folder, and you will need to have npy-matlab From 450ee6d0380580f5b8e11e876741db7fec6e8af2 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Tue, 16 Apr 2019 11:01:28 -0400 Subject: [PATCH 26/35] Documentation updates and CCG dependencies --- KiloSortWrapper.m | 6 +- PhyAutoClustering.m | 80 ++++++++--------- loadClusteringData.m | 6 +- private/CCG.m | 159 +++++++++++++++++++++++++++++++++ private/CCGHeart.c | 203 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 409 insertions(+), 45 deletions(-) create mode 100644 private/CCG.m create mode 100644 private/CCGHeart.c diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index fea826f..eca8427 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -12,6 +12,10 @@ % Check varargin description below when input parameters are parsed % % Dependencies: KiloSort (https://github.com/cortex-lab/KiloSort) +% +% The AutoClustering requires the CCGHeart to be compile. +% Go to the private folder of the wrapper and type: +% mex -O CCGHeart.c % % Copyright (C) 2016 Brendon Watson and the Buzsakilab % @@ -32,7 +36,7 @@ addParameter(p,'GPU_id',1,@isnumeric) % Specify the GPU_id addParameter(p,'SSD_path','K:\Kilosort',@ischar) % Path to SSD disk. Make it empty to disable SSD addParameter(p,'CreateSubdirectory',1,@isnumeric) % Puts the Kilosort output into a subfolder -addParameter(p,'performAutoCluster',0,@isnumeric) % Performs PhyAutoCluster once Kilosort is complete when exporting to Phy +addParameter(p,'performAutoCluster',0,@isnumeric) % Performs PhyAutoCluster once Kilosort is complete when exporting to Phy. addParameter(p,'config','',@ischar) % Specify a configuration file to use from the ConfigurationFiles folder. e.g. 'Omid' parse(p,varargin{:}) diff --git a/PhyAutoClustering.m b/PhyAutoClustering.m index a2d70a1..4397e34 100644 --- a/PhyAutoClustering.m +++ b/PhyAutoClustering.m @@ -1,55 +1,51 @@ function PhyAutoClustering(clusteringpath,varargin) -% AutoClustering automtically cleans Kilosort output in phy format defined -% by a clusteringpath. -% % INPUT: % clusteringpath: char % -% optional: +% Optional: % AutoClustering(clusteringpath,elec,dim) % where dim is the number of channels in electro group (if not % defined, will read the first line of the fet file % -% AutoClustering is meant to clean the output of KlustaKwik. The first -% thing it does is to separate electrical artifacts and MUA from putative -% isolated units. To do so, it sorts out units which have no clear -% refractory period (based on Hill, Mehta and Kleinfeld, J Neurosci., -% 2012). Threshold can be set in the parameter section of this file -% ("Rogue spike threshold"). Then, it separates electrical -% artifats from MUA based on the assumption that electrical artifacts are -% highly correlated on the different channels: the average waveform of at -% least one channel has to be different from the across-channel average -% waveform by a certrain amount of total variance (can be set in the -% parameter section, "Deviation from average spike threshold") +% Requirements: +% CCGHeart has to be compiled. Go to the private folder of the wrapper and type: +% mex -O CCGHeart.c +% % +% PhyAutoClustering is cleaning the output of Kilosort and labels the units accordingly: +% 1. Removing spikes with large artifacts: +% Uses the amplitude vector and removes spikes with an amplitude larger +% than amplitude_thr, where the spikes are convoluted to get time points +% with greater general amplitude than the amplitude_thr. Artifact spikes +% are grouped and labeled 'artifacts'. % -% Once the program has determined which of the clusters are putative -% isolated units, it tries to merge them based on waveform similarity -% (mahalanobis distance) and quality of the refractory period in the new -% merged cluster (or "Inter Common Spike Interval" from MS Fee et al. J -% Neurosci. Meth., 1996) +% 2. Mahal artifact removal +% Uses the private PCAs and removes any spikes with a larger mahal +% distance than mahal_thr. Removes spikes are labeled as 'mua'. +% 3. Determines MUA (labeled 'mua') +% +% 4. Removes noise artifacts (labeled 'noise') +% Sorts out units which have no clear refractory period (based on Hill, +% Mehta and Kleinfeld, J Neurosci., 2012). Threshold can be set in the +% parameter section of this file ("Rogue spike threshold"). Then, it +% separates electrical artifats from MUA based on the assumption that +% electrical artifacts are highly correlated on the different channels: +% the average waveform of at least one channel has to be different from +% the across-channel average waveform by a certrain amount of total +% variance (can be set in the parameter section, "Deviation from average +% spike threshold") (including units with less than 100 spikes): % -% Original script by Adrien Peyrache, 2012. -% Many modifications for Phy processing pipeline by -% Yuta Senzai and Peter Petersen - - -% if ~isempty(varargin) -% dim = varargin{1}; -% dim = dim(:); -% if any(double(int16(dim))~=dim) -% error('Number of dimensions must be an integer') -% end -% -% if size(dim,1) ~= length(elec) && length(dim) ~=1 -% error('Number of dimensions must be a vector of the same length as electrode vector or a single value') -% end -% if length(dim) == 1 -% dim = dim*ones(length(elec),1); -% end -% else -% dim = zeros(length(elec),1); -% end +% 5. Merging potential units based on CCGs +% Once the program has determined which of the clusters are putative +% isolated units, it tries to merge them based on waveform similarity +% (mahalanobis distance) and quality of the refractory period in the new +% merged cluster (or "Inter Common Spike Interval" from MS Fee et al. +% JNeurosci. Meth., 1996). +% +% By Adrien Peyrache, Peter Petersen & Yuta Senzai + + + % Refractory period in msec tR = 1.5; % 1.5 @@ -191,7 +187,7 @@ function PhyAutoClustering(clusteringpath,varargin) muaIx = find(fractRogue>rogThres & ~(meanR >= rThres & maxPwRatio < mprThres) & h>=100); goodIx = find(fractRogue<=rogThres & ~(meanR >= rThres & maxPwRatio < mprThres) & h>=100); % 100 or samlenum -% Saving clusters to cluster_group.tsv +% Saving clusters to cluster_group.tsv (Phy) fid = fopen(fullfile(clusteringpath,'cluster_group.tsv'),'w'); fwrite(fid, sprintf('cluster_id\t%s\r\n', 'group')); for ii=1:length(cids) diff --git a/loadClusteringData.m b/loadClusteringData.m index 5b0bfc4..c26cb4a 100644 --- a/loadClusteringData.m +++ b/loadClusteringData.m @@ -11,8 +11,8 @@ p = inputParser; addParameter(p,'shanks',nan,@isnumeric); % shanks: Loading only a subset of shanks (only applicable to Klustakwik) -addParameter(p,'raw_clusters',false,@islogical); % raw_clusters: Load only a subset of clusters (might not work anymore as I have not used it for a long time) -addParameter(p,'forceReload',false,@islogical); % Reload spikes from original format? +addParameter(p,'raw_clusters',false,@islogical); % raw_clusters: Load only a subset of clusters (might not work anymore as it has not been tested for a long time) +addParameter(p,'forceReload',false,@islogical); % Reload spikes from original format and resave the .spikes.mat file? addParameter(p,'saveMat',true,@islogical); % Save spikes to mat file? addParameter(p,'getWaveforms',true,@islogical); % Get average waveforms? Only in effect for neurosuite/klustakwik format parse(p,varargin{:}) @@ -31,6 +31,8 @@ else disp('Loading existing spikes file') end +else + forceReload = true; end if forceReload diff --git a/private/CCG.m b/private/CCG.m new file mode 100644 index 0000000..8315236 --- /dev/null +++ b/private/CCG.m @@ -0,0 +1,159 @@ +%CCG - Compute multiple cross- and auto-correlograms +% +% USAGE +% +% [ccg,t] = CCG(times,groups,) +% +% times times of all events +% (alternate) - can be {Ncells} array of [Nspikes] +% spiketimes for each cell +% NOTE: spiketimes in SECONDS. +% groups group IDs for each event in time list (should be +% integers 1:nGroups) +% (alternate) - [] +% optional list of property-value pairs (see table below) +% +% ========================================================================= +% Properties Values +% ------------------------------------------------------------------------- +% 'binSize' bin size in s (default = 0.01) +% 'duration' duration in s of each xcorrelogram (default = 2) +% 'norm' normalization of the CCG, 'counts' or 'rate' (DL added 8/1/17) +% 'counts' gives raw event/spike count, +% 'rate' returns CCG in units of spks/second (default: counts) +% ========================================================================= +% +% +% OUTPUT +% ccg [t x ngroups x ngroups] matrix where ccg(t,i,j) is the +% number (or rate) of events of group j at time lag t with +% respect to reference events from group i +% t time lag vector (units: seonds) +% +% SEE +% +% See also ShortTimeCCG. + +% Copyright (C) 2012 by Michaël Zugaro +% +% This program is free software; you can redistribute it and/or modify +% it under the terms of the GNU General Public License as published by +% the Free Software Foundation; either version 3 of the License, or +% (at your option) any later version. + +function [ccg,t] = CCG(times,groups,varargin) + +% Default values +duration = 2; +binSize = 0.01; +Fs = 1/20000; +normtype = 'counts'; + +% Option for spike times to be in {Ncells} array of spiketimes DL2017 +if iscell(times) && isempty(groups) + numcells = length(times); + for cc = 1:numcells + groups{cc}=cc.*ones(size(times{cc})); + end + times = cat(1,times{:}); groups = cat(1,groups{:}); +end + +%Sort +[times,sortidx] = sort(times); +groups = groups(sortidx); + +% Check parameters +if nargin < 2, + error('Incorrect number of parameters (type ''help CCG'' for details).'); +end +%if ~isdvector(times), +% error('Parameter ''times'' is not a real-valued vector (type ''help CCG'' for details).'); +%end +if ~isdscalar(groups) && ~isdvector(groups), + error('Parameter ''groups'' is not a real-valued scalar or vector (type ''help CCG'' for details).'); +end +if ~isdscalar(groups) && length(times) ~= length(groups), + error('Parameters ''times'' and ''groups'' have different lengths (type ''help CCG'' for details).'); +end +groups = groups(:); +times = times(:); + +% Parse parameter list +for i = 1:2:length(varargin), + if ~ischar(varargin{i}), + error(['Parameter ' num2str(i+2) ' is not a property (type ''help CCG'' for details).']); + end + switch(lower(varargin{i})), + case 'binsize', + binSize = varargin{i+1}; + %if ~isdscalar(binSize,'>0'), + % error('Incorrect value for property ''binSize'' (type ''help CCG'' for details).'); + % end + case 'duration', + duration = varargin{i+1}; + if ~isdscalar(duration,'>0'), + error('Incorrect value for property ''duration'' (type ''help CCG'' for details).'); + end + + case 'Fs', + Fs = varargin{i+1}; + if ~isdscalar(Fs,'>0'), + error('Incorrect value for property ''Fs'' (type ''help CCG'' for details).'); + end + case 'norm' + normtype = varargin{i+1}; + + end +end + + + +% Number of groups, number of bins, etc. +if length(groups) == 1, + groups = ones(length(times),1); + nGroups = 1; +else + nGroups = max(unique(groups)); +end + + +halfBins = round(duration/binSize/2); +nBins = 2*halfBins+1; +t = (-halfBins:halfBins)'*binSize; +times = round(times/Fs); +binSize_Fs = round(binSize/Fs); +if length(times) <= 1, + % ---- MODIFIED BY EWS, 1/2/2014 ---- + % *** Use unsigned integer format to save memory *** + ccg = uint16(zeros(nBins,nGroups,nGroups)); + % ----------------------------------- + return +end + +% Compute CCGs +nEvents = length(times); +% + +counts = double(CCGHeart(times,uint32(groups),binSize_Fs,uint32(halfBins))); +% ----------------------------------- +% +% Reshape the results +n = max(groups); +counts = reshape(counts,[nBins n n]); + + +if n < nGroups, + counts(nBins,nGroups,nGroups) = 0; +end + +%Rate normalization: counts/numREFspikes/dt to put in units of spikes/s. DL +switch normtype + case 'rate' + for gg = 1:nGroups + numREFspikes = sum(groups==gg);%number of reference events for group + counts(:,gg,:) = counts(:,gg,:)./numREFspikes./binSize; + end +end + + +ccg = flipud(counts); \ No newline at end of file diff --git a/private/CCGHeart.c b/private/CCGHeart.c new file mode 100644 index 0000000..f959e24 --- /dev/null +++ b/private/CCGHeart.c @@ -0,0 +1,203 @@ +/* CCGEngine.c /* +/* This is a bare-bones C program whos purpose is to compute + Multi-unit cross correlograms quickly. Not intended for + use on its own. It is designed to be wrapped by a MATLAB + function. + + Usage - [CCG, PAIRS] = CCGEngine(TIMES, MARKS, BINSIZE, HALFBINS) + TIMES ( is the name of a binary file containing N doubles giving the spike times + MARKS is the name of a binary file containing N unsigned ints giving the spike markers. Don't use zero! + BINSIZE is a number giving the size of the ccg bins in TIMES units + HALFBINS is the number of bins to compute - so there are a total of nBins = 1+2*HALFBINS bins + + These should be: double, uint32,double,uint32 + + NB The spikes MUST be sorted. + + CCG contains unsigned ints containing the counts in each bin + It is like a 3d array, indexed by [nBins*nMarks*Mark1 + nBins*Mark2 + Bin] + + PAIRS contains the spike numbers of every pair of spikes included in the ccg. + + If you think this program is anal, you're right. Use the MATLAB wrapper instead. + */ + + +#include "mex.h" +#include "matrix.h" +#include +#include +#include + +#define CHECK +#define STRLEN 10000 +#define PAIRBLOCKSIZE 1000000000 + +unsigned int *Pairs; +unsigned int PairCnt, PairSz; +void AddPair(unsigned int n1, unsigned int n2) { + unsigned int *pui; + + if (PairSz==0) { +/* mexPrintf("Allocating pair memory\n"); */ + Pairs = mxMalloc(PAIRBLOCKSIZE*sizeof(unsigned int)); + PairSz = PAIRBLOCKSIZE; + if (!Pairs) mexErrMsgTxt("Could not allocate memory for pairs"); + } + /* check if array is full, if so add more memory*/ + if(PairCnt>=PairSz) { +/* mexPrintf("Reallocating pair memory ... "); + PairSz += PAIRBLOCKSIZE; + pui = mxRealloc(Pairs, PairSz); + mexPrintf("got %x\n", pui); + if (!pui) { + mxFree(Pairs); + + mexErrMsgTxt("Could not reallocate memory for pairs"); + } + Pairs = pui; +*/ + mexPrintf("\n Number of pairs %d\n",PairCnt); + mexErrMsgTxt("Too many pairs"); + + } + Pairs[PairCnt++] = n1; + Pairs[PairCnt++] = n2; +} + + + +void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) { + + unsigned int nSpikes, nMarks, HalfBins, nBins, i, CountArraySize, CountIndex; + double *Times; + double BinSize, FurthestEdge; + unsigned int *Marks, Mark, *Count; + unsigned int CenterSpike, Mark1, Mark2, Bin; + int SecondSpike; /* we want to let it go negative so we can stop it there... */ + double Time1, Time2; + char errstr[STRLEN]; + + /* global variables are not initialized on each call to mex fn!! */ + PairCnt = 0; PairSz = 0; + + if (nrhs!=4) { + mexErrMsgTxt("Must have 4 arguments\nBut listen: You don't want to use this program.\nUse the MATLAB wrapper function CCG instead.\n"); + } + if (mxGetClassID(prhs[0])!=mxDOUBLE_CLASS || mxGetClassID(prhs[1])!=mxUINT32_CLASS + || mxGetClassID(prhs[2])!=mxDOUBLE_CLASS || mxGetClassID(prhs[3])!=mxUINT32_CLASS ) { + mexErrMsgTxt("Arguments are wrong type\n"); + } + + + /* get arguments */ + Times = mxGetPr(prhs[0]); + Marks = (unsigned int *) mxGetPr(prhs[1]); + nSpikes = mxGetNumberOfElements(prhs[0]); + if (mxGetNumberOfElements(prhs[1])!=nSpikes) mexErrMsgTxt("Number of marks ~= number of spikes"); + BinSize = mxGetScalar(prhs[2]); + HalfBins = (unsigned int) mxGetScalar(prhs[3]); + + /* derive other constants */ + nBins = 1+2*HalfBins; + FurthestEdge = BinSize * (HalfBins + 0.5); + + + /* count nMarks */ + nMarks = 0; + for(i=0; inMarks) nMarks = Mark; + if (Mark==0) { + mexErrMsgTxt("CCGEngine: No zeros allowed in Marks"); + abort(); + } + } + + /* allocate output array */ + CountArraySize = nMarks * nMarks * nBins; + plhs[0] = mxCreateNumericMatrix(CountArraySize, 1, mxUINT32_CLASS, mxREAL); + Count = (unsigned int *) mxGetPr(plhs[0]); + + if (!Times || !Marks || !Count) { + mexErrMsgTxt("CCGEngine could not allocate memory!\n"); + } + + /* Now the main program .... */ + + + for(CenterSpike=0; CenterSpike=0; SecondSpike--) { + Time2 = Times[SecondSpike]; + + /* check if we have left the interesting region */ + if(fabs(Time1 - Time2) > FurthestEdge) break; + + /* calculate bin */ + Bin = HalfBins + (int)(floor(0.5+(Time2-Time1)/BinSize)); + + Mark2 = Marks[SecondSpike]; + CountIndex = nBins*nMarks*(Mark1-1) + nBins*(Mark2-1) + Bin; +#ifdef CHECK + if (CountIndex<0 || CountIndex >= CountArraySize) { + sprintf(errstr, "err a: t1 %f t2 %f m1 %d m2 %d Bin %d, index %d out of bounds", + Time1, Time2, Mark1, Mark2, Bin, CountIndex); + mexErrMsgTxt(errstr); + } +#endif + + /* increment count */ + Count[CountIndex]++; + if (nlhs>=2) AddPair(CenterSpike, SecondSpike); + + } + + /* Now do the same thing going forward... */ + for(SecondSpike=CenterSpike+1; SecondSpike= FurthestEdge) break; + + /* calculate bin */ + Bin = HalfBins + (unsigned int)(floor(0.5+(Time2-Time1)/BinSize)); + + Mark2 = Marks[SecondSpike]; + CountIndex = nBins*nMarks*(Mark1-1) + nBins*(Mark2-1) + Bin; + +#ifdef CHECK + if (CountIndex<0 || CountIndex >= CountArraySize) { + sprintf(errstr, "err b: t1 %f t2 %f m1 %d m2 %d Bin %d, index %d out of bounds", + Time1, Time2, Mark1, Mark2, Bin, CountIndex); + mexErrMsgTxt(errstr); + } +#endif + + /* increment count */ + Count[CountIndex]++; + if (nlhs>=2) AddPair(CenterSpike, SecondSpike); + + } + } + + if (nlhs>=2) { +/* if (PairCnt==0){ + sprintf(errstr, "No pairs for these spike trains"); + mexErrMsgTxt(errstr); + + }else{*/ + + plhs[1] = mxCreateNumericMatrix(PairCnt, 1, mxUINT32_CLASS, mxREAL); + memcpy(mxGetPr(plhs[1]), (void *)Pairs, PairCnt*sizeof(unsigned int)); + mxFree(Pairs); + + } + + + + /* sayonara */ +} From ba952bf0e88fdc9d9802920fd7c7a56a568815e0 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Mon, 22 Apr 2019 17:02:23 -0400 Subject: [PATCH 27/35] Getting waveforms as well --- Phy2Neurosuite.m | 2 +- loadClusteringData.m | 85 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/Phy2Neurosuite.m b/Phy2Neurosuite.m index 8a57fdc..101cb16 100644 --- a/Phy2Neurosuite.m +++ b/Phy2Neurosuite.m @@ -18,7 +18,7 @@ function Phy2Neurosuite(basepath,clustering_path) % petersen.peter@gmail.com t1 = tic; -% cd(clustering_path) +cd(clustering_path) rezpath = fullfile(clustering_path,'rez.mat'); if exist(rezpath) load(rezpath) diff --git a/loadClusteringData.m b/loadClusteringData.m index c26cb4a..ffda148 100644 --- a/loadClusteringData.m +++ b/loadClusteringData.m @@ -3,11 +3,14 @@ % Buzcode compatible output. Saves output to a basename.spikes.cellinfo.mat file % baseName: basename of the recording % clusteringMethod: clustering method to handle different pipelines: ['phy','klustakwik'/'neurosuite'] -% clusteringPath: Path to the clustered data +% clusteringPath: path to the clustered data % See description of varargin below % by Peter Petersen % petersen.peter@gmail.com +% +% Version history +% 3.2 waveforms for phy data extracted from the raw dat p = inputParser; addParameter(p,'shanks',nan,@isnumeric); % shanks: Loading only a subset of shanks (only applicable to Klustakwik) @@ -92,8 +95,8 @@ spikes.peakVoltage(unit_nb) = max(spikes.filtWaveform{unit_nb}) - min(spikes.filtWaveform{unit_nb}); end end - end - + end + clear cluster_index time_stamps case 'phy' @@ -109,6 +112,12 @@ cluster_ids = readNPY(fullfile(clusteringPath, 'cluster_ids.npy')); unit_shanks = readNPY(fullfile(clusteringPath, 'shanks.npy')); peak_channel = readNPY(fullfile(clusteringPath, 'peak_channel.npy'))+1; + if exist(fullfile(clusteringPath, 'rez.mat')) + load(fullfile(clusteringPath, 'rez.mat')) + temp = find(rez.connected); + peak_channel = temp(peak_channel); + clear rez temp + end end if exist(filename1) == 2 @@ -155,6 +164,74 @@ j = j+1; end end + + if getWaveforms % get waveforms + timerVal = tic; + nPull = 1000; % number of spikes to pull out + wfWin = 0.006; % Larger size of waveform windows for filterning + wfWinKeep = 0.001; + filtFreq = 500; + hpFilt = designfilt('highpassiir','FilterOrder',3, 'PassbandFrequency',filtFreq,'PassbandRipple',0.1, 'SampleRate',xml.SampleRate); + [b1, a1] = butter(3, filtFreq/xml.SampleRate*2, 'high'); + + f = waitbar(0,'Getting waveforms...'); + wfWin = round((wfWin * xml.SampleRate)/2); + + for ii = 1 : size(spikes.times,2) + waitbar(ii/size(spikes.times,2),f,['Pulling out waveforms (',num2str(ii),'/',num2str(size(spikes.times,2)),')']); + spkTmp = spikes.ts{ii}; + if length(spkTmp) > nPull + spkTmp = spkTmp(randperm(length(spkTmp))); + spkTmp = sort(spkTmp(1:nPull)); + end + + + % Determines the maximum waveform channel + wf = zeros((wfWin * 2)+1,xml.nChannels,100); + wfF = zeros((wfWin * 2)+1,xml.nChannels); + for jj = 1 : 100 + wf(:,:,jj) = double(LoadBinary([baseName '.dat'],'offset',spkTmp(jj) - wfWin,'nChannels',xml.nChannels,'precision','int16','frequency',xml.SampleRate,'samples',(wfWin * 2)+1)); + % wf(:,:,jj) = double(bz_LoadBinary([baseName '.dat'],'offset',spkTmp(jj) - (wfWin), 'samples',(wfWin * 2)+1,'frequency',xml.SampleRate,'nChannels',xml.nChannels)); + end + wf = 0.195 * mean(wf,3); + + for jj = 1 : size(wf,2) + wfF(:,jj) = filtfilt(b1, a1, wf(:,jj)); + end + [~, spikes.maxWaveformCh1(ii)] = max(abs(wfF(wfWin,:))); + spikes.maxWaveformCh(ii) = spikes.maxWaveformCh1(ii)-1; + + % Assigning shankID to the unit + for jj = 1:size(xml.AnatGrps,2) + if xml.AnatGrps(jj).Channels == spikes.maxWaveformCh(ii) + spikes.shankID(ii) = jj; + end + end + + % Pulls the waveforms from the dat + wf = zeros((wfWin * 2)+1,length(spkTmp)); + wfF = zeros((wfWin * 2)+1,length(spkTmp)); + for jj = 1 : length(spkTmp) + wf(:,jj) = double(LoadBinary([baseName '.dat'],'offset',spkTmp(jj) - wfWin,'nChannels',xml.nChannels,'channels',spikes.maxWaveformCh1(ii),'precision','int16','frequency',xml.SampleRate,'samples',(wfWin * 2)+1)); + wfF(:,jj) = filtfilt(b1, a1, wf(:,jj)); + end + wf2 = mean(0.195 * wf,2); + rawWaveform{ii} = detrend(wf2 - mean(wf2))'; + rawWaveform_std{ii} = std(0.195 * (wf-mean(wf))'); + filtWaveform{ii} = mean(0.195 * wfF,2)'; + filtWaveform_std{ii} = std(0.195 * wfF'); + + spikes.rawWaveform{ii} = rawWaveform{ii}(wfWin-(wfWinKeep*xml.SampleRate):wfWin+(wfWinKeep*xml.SampleRate)); % keep only +- 1ms of waveform + spikes.rawWaveform_std{ii} = rawWaveform_std{ii}(wfWin-(wfWinKeep*xml.SampleRate):wfWin+(wfWinKeep*xml.SampleRate)); + spikes.filtWaveform{ii} = filtWaveform{ii}(wfWin-(wfWinKeep*xml.SampleRate):wfWin+(wfWinKeep*xml.SampleRate)); + spikes.filtWaveform_std{ii} = filtWaveform_std{ii}(wfWin-(wfWinKeep*xml.SampleRate):wfWin+(wfWinKeep*xml.SampleRate)); + spikes.timeWaveform{ii} = (-wfWinKeep:1/xml.SampleRate:wfWinKeep)*1000; + spikes.peakVoltage(ii) = max(spikes.filtWaveform{ii})-min(spikes.filtWaveform{ii}); + end + close(f) + toc(timerVal) + end + case 'klustaViewa' disp('Loading KlustaViewa clustered data') units_to_exclude = []; @@ -177,7 +254,7 @@ % Attaching info about how the spikes structure was generated spikes.processinginfo.function = 'loadClusteringData'; - spikes.processinginfo.version = 3.1; + spikes.processinginfo.version = 3.2; spikes.processinginfo.date = now; spikes.processinginfo.params.forceReload = forceReload; spikes.processinginfo.params.shanks = shanks; From 77dfc457955e759ecf9f70968d167962efef3d95 Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Fri, 17 May 2019 09:57:57 -0400 Subject: [PATCH 28/35] Updates to loadClustering script --- loadClusteringData.m | 276 ++++++++++++++++++++++++++++++------------- 1 file changed, 193 insertions(+), 83 deletions(-) diff --git a/loadClusteringData.m b/loadClusteringData.m index ffda148..0dbcb53 100644 --- a/loadClusteringData.m +++ b/loadClusteringData.m @@ -1,23 +1,32 @@ -function spikes = loadClusteringData(baseName,clusteringMethod,clusteringPath,varargin) -% load clustered data from multiple pipelines [Phy, Klustakwik/Neurosuite] +function spikes = loadClusteringData(clusteringPath,clusteringMethod,varargin) +% Load clustered data from multiple pipelines [Current options: Phy, Klustakwik/Neurosuite] % Buzcode compatible output. Saves output to a basename.spikes.cellinfo.mat file -% baseName: basename of the recording -% clusteringMethod: clustering method to handle different pipelines: ['phy','klustakwik'/'neurosuite'] % clusteringPath: path to the clustered data +% clusteringMethod: clustering method to handle different pipelines: ['phy','klustakwik'/'neurosuite'] + % See description of varargin below % by Peter Petersen % petersen.peter@gmail.com -% + % Version history % 3.2 waveforms for phy data extracted from the raw dat +% 3.3 waveforms extracted from raw dat using memmap function. Interval and bad channels bugs fixed as well +% 3.4 bug fix which gave misaligned waveform extraction from raw dat. Plot improvements of waveforms p = inputParser; addParameter(p,'shanks',nan,@isnumeric); % shanks: Loading only a subset of shanks (only applicable to Klustakwik) addParameter(p,'raw_clusters',false,@islogical); % raw_clusters: Load only a subset of clusters (might not work anymore as it has not been tested for a long time) addParameter(p,'forceReload',false,@islogical); % Reload spikes from original format and resave the .spikes.mat file? addParameter(p,'saveMat',true,@islogical); % Save spikes to mat file? -addParameter(p,'getWaveforms',true,@islogical); % Get average waveforms? Only in effect for neurosuite/klustakwik format +addParameter(p,'getWaveforms',true,@islogical); % Get average waveforms? +addParameter(p,'useNeurosuiteWaveforms',false,@islogical); % Use Waveform features from spk files or load waveforms from dat file +addParameter(p,'spikes',[],@isstruct); % Load existing spikes structure to append new spike info +addParameter(p,'basepath',pwd,@ischar); % path to dat file, used to extract the waveforms from the dat file +addParameter(p,'LSB',0.195,@isnumeric); % Least significant bit (LSB in uV) Intan = 0.195, Amplipex = 0.3815 +addParameter(p,'session',[],@isstruct); % A Buzsaki db session struct +addParameter(p,'basename','',@ischar); % The baseName file naming convention + parse(p,varargin{:}) shanks = p.Results.shanks; @@ -25,6 +34,17 @@ forceReload = p.Results.forceReload; saveMat = p.Results.saveMat; getWaveforms = p.Results.getWaveforms; +spikes = p.Results.spikes; +basepath = p.Results.basepath; +useNeurosuiteWaveforms = p.Results.useNeurosuiteWaveforms; +LSB = p.Results.LSB; +session = p.Results.session; +baseName = p.Results.basename; + +if isempty(baseName) & ~isempty(basepath) + [~,baseName,~] = fileparts(basepath); + disp(['Using basepath to determine the basename: ' baseName]) +end if exist(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat'])) & ~forceReload load(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat'])) @@ -32,7 +52,7 @@ forceReload = true; disp('spikes.mat structure not up to date. Reloading spikes.') else - disp('Loading existing spikes file') + disp('loadClusteringData: Loading existing spikes file') end else forceReload = true; @@ -41,7 +61,7 @@ if forceReload switch lower(clusteringMethod) case {'klustakwik', 'neurosuite'} - disp('Loading Klustakwik clustered data') + disp('loadClusteringData: Loading Klustakwik clustered data') unit_nb = 0; spikes = []; shanks_new = []; @@ -54,16 +74,16 @@ end shanks = sort(shanks_new); end + xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); for shank = shanks disp(['Loading shank #' num2str(shank) '/' num2str(length(shanks)) ]) if ~raw_clusters - xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); cluster_index = load(fullfile(clusteringPath, [baseName '.clu.' num2str(shank)])); time_stamps = load(fullfile(clusteringPath,[baseName '.res.' num2str(shank)])); if getWaveforms fname = fullfile(clusteringPath,[baseName '.spk.' num2str(shank)]); f = fopen(fname,'r'); - waveforms = 0.000195 * double(fread(f,'int16')); + waveforms = LSB * double(fread(f,'int16')); samples = size(waveforms,1)/size(time_stamps,1); electrodes = size(xml.ElecGp{shank},2); waveforms = reshape(waveforms, [electrodes,samples/electrodes,length(waveforms)/samples]); @@ -84,7 +104,7 @@ spikes.cluID(unit_nb) = nb_clusters2(i); spikes.cluster_index(unit_nb) = nb_clusters2(i); spikes.total(unit_nb) = length(spikes.ts{unit_nb}); - if getWaveforms + if getWaveforms & useNeurosuiteWaveforms spikes.filtWaveform_all{unit_nb} = mean(waveforms(:,:,cluster_index == nb_clusters2(i)),3); spikes.filtWaveform_all_std{unit_nb} = permute(std(permute(waveforms(:,:,cluster_index == nb_clusters2(i)),[3,1,2])),[2,3,1]); [~,index1] = max(max(spikes.filtWaveform_all{unit_nb}') - min(spikes.filtWaveform_all{unit_nb}')); @@ -95,12 +115,17 @@ spikes.peakVoltage(unit_nb) = max(spikes.filtWaveform{unit_nb}) - min(spikes.filtWaveform{unit_nb}); end end + if getWaveforms + spikes.processinginfo.params.WaveformsSource = 'spk files'; + end + end + if getWaveforms & ~useNeurosuiteWaveforms + spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); end - clear cluster_index time_stamps case 'phy' - disp('Loading Phy clustered data') + disp('loadClusteringData: Loading Phy/Kilosort clustered data') xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); spike_cluster_index = readNPY(fullfile(clusteringPath, 'spike_clusters.npy')); spike_times = readNPY(fullfile(clusteringPath, 'spike_times.npy')); @@ -133,7 +158,6 @@ fileID = fopen(filename,'r'); dataArray = textscan(fileID, formatSpec, 'Delimiter', delimiter, 'HeaderLines' ,startRow-1, 'ReturnOnError', false); fclose(fileID); - spikes = []; j = 1; for i = 1:length(dataArray{1}) if raw_clusters == 0 @@ -146,9 +170,15 @@ spikes.UID(j) = j; if exist('cluster_ids') cluster_id = find(cluster_ids == spikes.cluID(j)); - spikes.shankID(j) = double(unit_shanks(cluster_id)); spikes.maxWaveformCh1(j) = double(peak_channel(cluster_id)); % index 1; spikes.maxWaveformCh(j) = double(peak_channel(cluster_id))-1; % index 0; + + % Assigning shankID to the unit + for jj = 1:size(xml.AnatGrps,2) + if any(xml.AnatGrps(jj).Channels == spikes.maxWaveformCh(j)) + spikes.shankID(j) = jj; + end + end end spikes.total(j) = length(spikes.ts{j}); spikes.amplitudes{j} = double(spike_amplitudes(spikes.ids{j})); @@ -166,74 +196,11 @@ end if getWaveforms % get waveforms - timerVal = tic; - nPull = 1000; % number of spikes to pull out - wfWin = 0.006; % Larger size of waveform windows for filterning - wfWinKeep = 0.001; - filtFreq = 500; - hpFilt = designfilt('highpassiir','FilterOrder',3, 'PassbandFrequency',filtFreq,'PassbandRipple',0.1, 'SampleRate',xml.SampleRate); - [b1, a1] = butter(3, filtFreq/xml.SampleRate*2, 'high'); - - f = waitbar(0,'Getting waveforms...'); - wfWin = round((wfWin * xml.SampleRate)/2); - - for ii = 1 : size(spikes.times,2) - waitbar(ii/size(spikes.times,2),f,['Pulling out waveforms (',num2str(ii),'/',num2str(size(spikes.times,2)),')']); - spkTmp = spikes.ts{ii}; - if length(spkTmp) > nPull - spkTmp = spkTmp(randperm(length(spkTmp))); - spkTmp = sort(spkTmp(1:nPull)); - end - - - % Determines the maximum waveform channel - wf = zeros((wfWin * 2)+1,xml.nChannels,100); - wfF = zeros((wfWin * 2)+1,xml.nChannels); - for jj = 1 : 100 - wf(:,:,jj) = double(LoadBinary([baseName '.dat'],'offset',spkTmp(jj) - wfWin,'nChannels',xml.nChannels,'precision','int16','frequency',xml.SampleRate,'samples',(wfWin * 2)+1)); - % wf(:,:,jj) = double(bz_LoadBinary([baseName '.dat'],'offset',spkTmp(jj) - (wfWin), 'samples',(wfWin * 2)+1,'frequency',xml.SampleRate,'nChannels',xml.nChannels)); - end - wf = 0.195 * mean(wf,3); - - for jj = 1 : size(wf,2) - wfF(:,jj) = filtfilt(b1, a1, wf(:,jj)); - end - [~, spikes.maxWaveformCh1(ii)] = max(abs(wfF(wfWin,:))); - spikes.maxWaveformCh(ii) = spikes.maxWaveformCh1(ii)-1; - - % Assigning shankID to the unit - for jj = 1:size(xml.AnatGrps,2) - if xml.AnatGrps(jj).Channels == spikes.maxWaveformCh(ii) - spikes.shankID(ii) = jj; - end - end - - % Pulls the waveforms from the dat - wf = zeros((wfWin * 2)+1,length(spkTmp)); - wfF = zeros((wfWin * 2)+1,length(spkTmp)); - for jj = 1 : length(spkTmp) - wf(:,jj) = double(LoadBinary([baseName '.dat'],'offset',spkTmp(jj) - wfWin,'nChannels',xml.nChannels,'channels',spikes.maxWaveformCh1(ii),'precision','int16','frequency',xml.SampleRate,'samples',(wfWin * 2)+1)); - wfF(:,jj) = filtfilt(b1, a1, wf(:,jj)); - end - wf2 = mean(0.195 * wf,2); - rawWaveform{ii} = detrend(wf2 - mean(wf2))'; - rawWaveform_std{ii} = std(0.195 * (wf-mean(wf))'); - filtWaveform{ii} = mean(0.195 * wfF,2)'; - filtWaveform_std{ii} = std(0.195 * wfF'); - - spikes.rawWaveform{ii} = rawWaveform{ii}(wfWin-(wfWinKeep*xml.SampleRate):wfWin+(wfWinKeep*xml.SampleRate)); % keep only +- 1ms of waveform - spikes.rawWaveform_std{ii} = rawWaveform_std{ii}(wfWin-(wfWinKeep*xml.SampleRate):wfWin+(wfWinKeep*xml.SampleRate)); - spikes.filtWaveform{ii} = filtWaveform{ii}(wfWin-(wfWinKeep*xml.SampleRate):wfWin+(wfWinKeep*xml.SampleRate)); - spikes.filtWaveform_std{ii} = filtWaveform_std{ii}(wfWin-(wfWinKeep*xml.SampleRate):wfWin+(wfWinKeep*xml.SampleRate)); - spikes.timeWaveform{ii} = (-wfWinKeep:1/xml.SampleRate:wfWinKeep)*1000; - spikes.peakVoltage(ii) = max(spikes.filtWaveform{ii})-min(spikes.filtWaveform{ii}); - end - close(f) - toc(timerVal) + spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); end case 'klustaViewa' - disp('Loading KlustaViewa clustered data') + disp('loadClusteringData: Loading KlustaViewa clustered data') units_to_exclude = []; [spikes,~] = ImportKwikFile(baseName,clusteringPath,shanks,0,units_to_exclude); end @@ -247,22 +214,165 @@ end if spikes.numcells>0 - alltimes = cat(1,spikes.times{:}); groups = cat(1,groups{:}); %from cell to array - [alltimes,sortidx] = sort(alltimes); groups = groups(sortidx); %sort both + alltimes = cat(1,spikes.times{:}); groups = cat(1,groups{:}); % from cell to array + [alltimes,sortidx] = sort(alltimes); groups = groups(sortidx); % sort both spikes.spindices = [alltimes groups]; end % Attaching info about how the spikes structure was generated spikes.processinginfo.function = 'loadClusteringData'; - spikes.processinginfo.version = 3.2; + spikes.processinginfo.version = 3.4; spikes.processinginfo.date = now; spikes.processinginfo.params.forceReload = forceReload; spikes.processinginfo.params.shanks = shanks; spikes.processinginfo.params.raw_clusters = raw_clusters; spikes.processinginfo.params.getWaveforms = getWaveforms; + spikes.processinginfo.params.baseName = baseName; + spikes.processinginfo.params.clusteringMethod = clusteringMethod; + spikes.processinginfo.params.clusteringPath = clusteringPath; + spikes.processinginfo.params.basepath = basepath; + spikes.processinginfo.params.useNeurosuiteWaveforms = useNeurosuiteWaveforms; % Saving output to a buzcode compatible spikes file. if saveMat + disp('loadClusteringData: Saving spikes') save(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat']),'spikes') end end + +end + +function spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session) +% Requires a neurosuite xml structure. Bad channels must be removed from the spike groups beforehand +showWaveforms = true; +badChannels = []; +if ~isempty(session) + badChannels = session.channelTags.Bad.channels; + if ~isempty(session.channelTags.Bad.spikeGroups) + badChannels = [badChannels,session.extracellular.spikeGroups(session.channelTags.Bad.spikeGroups)+1]; + end + badChannels = unique(badChannels); +end + +badChannels = [badChannels,setdiff([xml.AnatGrps.Channels],[xml.SpkGrps.Channels])+1]; +goodChannels = setdiff(1:xml.nChannels,badChannels); +nGoodChannels = length(goodChannels); + +timerVal = tic; +nPull = 600; % number of spikes to pull out +wfWin_sec = 0.004; % Larger size of waveform windows for filterning. total width in ms +wfWinKeep = 0.0008; % half width in ms +filtFreq = [500,8000]; +[b1, a1] = butter(3, filtFreq/xml.SampleRate*2, 'bandpass'); + +f = waitbar(0,['Getting waveforms from dat file'],'Name',['Processing ' baseName]); +if showWaveforms + fig1 = figure('Name', ['Getting waveforms for ' baseName],'NumberTitle', 'off'); +end +wfWin = round((wfWin_sec * xml.SampleRate)/2); +t1 = toc(timerVal); +s = dir(fullfile(basepath,[baseName '.dat'])); +duration = s.bytes/(2*xml.nChannels*xml.SampleRate); +m = memmapfile(fullfile(basepath,[baseName '.dat']),'Format','int16','writable',false); +DATA = m.Data; + +for ii = 1 : size(spikes.times,2) + if ishandle(f) + waitbar(ii/size(spikes.times,2),f,['Waveforms: ',num2str(ii),'/',num2str(size(spikes.times,2)),'. ', num2str(round(toc(timerVal)-t1)),' sec/unit, ', num2str(round(toc(timerVal)/60)) ' minutes total']); + else + disp('Canceling waveform extraction...') + clear rawWaveform rawWaveform_std filtWaveform filtWaveform_std + clear DATA + clear m + error('Waveform extraction canceled by user') + end + t1 = toc(timerVal); + spkTmp = spikes.ts{ii}(find(spikes.times{ii} > wfWin_sec/1.8 & spikes.times{ii} < duration-wfWin_sec/1.8)); + + if length(spkTmp) > nPull + spkTmp = spkTmp(randperm(length(spkTmp))); + spkTmp = sort(spkTmp(1:nPull)); + end + + % Determines the maximum waveform channel + startIndicies = (spkTmp(1:min(100,length(spkTmp))) - wfWin)*xml.nChannels+1; + stopIndicies = (spkTmp(1:min(100,length(spkTmp))) + wfWin)*xml.nChannels; + X = cumsum(accumarray(cumsum([1;stopIndicies(:)-startIndicies(:)+1]),[startIndicies(:);0]-[0;stopIndicies(:)]-1)+1); +% temp1 = reshape(double(m.Data(X(1:end-1))),xml.nChannels,(wfWin*2),[]); + wf = LSB * mean(reshape(double(DATA(X(1:end-1))),xml.nChannels,(wfWin*2),[]),3); + wfF2 = zeros((wfWin * 2),nGoodChannels); + for jj = 1 : nGoodChannels + wfF2(:,jj) = filtfilt(b1, a1, wf(goodChannels(jj),:)); + end + [~, idx] = max(max(wfF2)-min(wfF2)); % max(abs(wfF(wfWin,:))); + spikes.maxWaveformCh1(ii) = goodChannels(idx); + spikes.maxWaveformCh(ii) = spikes.maxWaveformCh1(ii)-1; + + % Assigning shankID to the unit + for jj = 1:size(xml.AnatGrps,2) + if any(xml.AnatGrps(jj).Channels == spikes.maxWaveformCh(ii)) + spikes.shankID(ii) = jj; + end + end + + % Pulls the waveforms from the dat + startIndicies = (spkTmp - wfWin+1); + stopIndicies = (spkTmp + wfWin); + X = cumsum(accumarray(cumsum([1;stopIndicies(:)-startIndicies(:)+1]),[startIndicies(:);0]-[0;stopIndicies(:)]-1)+1); + X = X(1:end-1) * xml.nChannels+spikes.maxWaveformCh1(ii); + + wf = LSB * double(reshape(DATA(X),wfWin*2,length(spkTmp))); + wfF = zeros((wfWin * 2),length(spkTmp)); + for jj = 1 : length(spkTmp) + wfF(:,jj) = filtfilt(b1, a1, wf(:,jj)); + end + + wf2 = mean(wf,2); + rawWaveform = detrend(wf2 - mean(wf2))'; + rawWaveform_std = std((wf-mean(wf))'); + filtWaveform = mean(wfF,2)'; + filtWaveform_std = std(wfF'); + + window_interval = wfWin-(wfWinKeep*xml.SampleRate):wfWin-1+(wfWinKeep*xml.SampleRate); + spikes.rawWaveform{ii} = rawWaveform(window_interval); % keep only +- 1ms of waveform + spikes.rawWaveform_std{ii} = rawWaveform_std(window_interval); + spikes.filtWaveform{ii} = filtWaveform(window_interval); + spikes.filtWaveform_std{ii} = filtWaveform_std(window_interval); + spikes.timeWaveform{ii} = (-wfWinKeep+1/xml.SampleRate:1/xml.SampleRate:wfWinKeep)*1000; + spikes.peakVoltage(ii) = max(spikes.filtWaveform{ii})-min(spikes.filtWaveform{ii}); + + if ishandle(fig1) + figure(fig1) + subplot(2,2,1), hold off + plot(wfF2), hold on, plot(wfF2(:,idx),'k','linewidth',2), title('Filt waveform across channels'), xlabel('Samples'), hold off + + subplot(2,2,2), hold off, + plot(wfF), title('Peak channel waveforms'), xlabel('Samples') + + subplot(2,2,3), hold on, + plot(spikes.timeWaveform{ii},spikes.rawWaveform{ii}), title('Raw waveform'), xlabel('Time (ms)') + xlim([-0.8,0.8]) + subplot(2,2,4), hold on, + plot(spikes.timeWaveform{ii},spikes.filtWaveform{ii}), title('Filtered waveform'), xlabel('Time (ms)') + xlim([-0.8,0.8]) + end + clear wf wfF wf2 wfF2 +end +if ishandle(f) + spikes.processinginfo.params.WaveformsSource = 'dat file'; + spikes.processinginfo.params.WaveformsFiltFreq = filtFreq; + spikes.processinginfo.params.Waveforms_nPull = nPull; + spikes.processinginfo.params.WaveformsWin_sec = wfWin_sec; + spikes.processinginfo.params.WaveformsWinKeep = wfWinKeep; + spikes.processinginfo.params.WaveformsFilterType = 'butter'; + clear rawWaveform rawWaveform_std filtWaveform filtWaveform_std + clear DATA + clear m + waitbar(ii/size(spikes.times,2),f,['Waveform extraction complete ',num2str(ii),'/',num2str(size(spikes.times,2)),'. ', num2str(round(toc(timerVal)/60)) ' minutes total']); + disp(['Waveform extraction complete. Total duration: ' num2str(round(toc(timerVal)/60)),' minutes']) + if ishandle(fig1) + set(fig1,'Name',['Waveform extraction complete for ' baseName]) + end + % close(f) +end +end \ No newline at end of file From cda0f4de56400d3ad410fbf331afffed07bb7a5c Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Fri, 5 Jul 2019 11:14:06 -0400 Subject: [PATCH 29/35] New load spikes script --- KiloSortWrapper.m | 2 +- loadClusteringData.m => loadSpikes.m | 165 ++++++++++++++++++++------- 2 files changed, 124 insertions(+), 43 deletions(-) rename loadClusteringData.m => loadSpikes.m (73%) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index eca8427..7a648ae 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -151,7 +151,7 @@ basename = rez.ops.basename; rez.ops.fbinary = fullfile(pwd, [basename,'.dat']); Kilosort2Neurosuite(rez) - + writeNPY(rez.ops.kcoords, fullfile(clustering_path, 'channel_shanks.npy')); phy_export_units(clustering_path,basename); diff --git a/loadClusteringData.m b/loadSpikes.m similarity index 73% rename from loadClusteringData.m rename to loadSpikes.m index 0dbcb53..2e6d152 100644 --- a/loadClusteringData.m +++ b/loadSpikes.m @@ -1,45 +1,78 @@ -function spikes = loadClusteringData(clusteringPath,clusteringMethod,varargin) -% Load clustered data from multiple pipelines [Current options: Phy, Klustakwik/Neurosuite] +function spikes = loadSpikes(varargin) +% Load clustered data from multiple pipelines [Current options: Phy, Klustakwik/Neurosuite,klustaViewa] % Buzcode compatible output. Saves output to a basename.spikes.cellinfo.mat file -% clusteringPath: path to the clustered data -% clusteringMethod: clustering method to handle different pipelines: ['phy','klustakwik'/'neurosuite'] - +% +% INPUTS +% % See description of varargin below +% +% OUTPUT +% +% spikes: - Matlab struct following the buzcode standard (https://github.com/buzsakilab/buzcode) +% .sessionName - Name of recording file +% .UID - Unique identifier for each neuron in a recording +% .times - Cell array of timestamps (seconds) for each neuron +% .spindices - Sorted vector of [spiketime UID], useful as input to some functions and plotting rasters +% .region - Region ID for each neuron (especially important large scale, high density probes) +% .maxWaveformCh - Channel # with largest amplitude spike for each neuron (0-indexed) +% .maxWaveformCh1 - Channel # with largest amplitude spike for each neuron (1-indexed) +% .rawWaveform - Average waveform on maxWaveformCh (from raw .dat) +% .filtWaveform - Average filtered waveform on maxWaveformCh (from raw .dat) +% .rawWaveform_std - Average waveform on maxWaveformCh (from raw .dat) +% .filtWaveform_std - Average filtered waveform on maxWaveformCh (from raw .dat) +% .peakVoltage - Peak voltage (uV) +% .cluID - Cluster ID +% .shankID - shankID +% .processingInfo - Processing info +% +% DEPENDENCIES: +% +% LoadXml.m & xmltools.m (default) or bz_getSessionInfo.m -% by Peter Petersen +% By Peter Petersen % petersen.peter@gmail.com +% Last edited: 20-06-2019 % Version history % 3.2 waveforms for phy data extracted from the raw dat % 3.3 waveforms extracted from raw dat using memmap function. Interval and bad channels bugs fixed as well % 3.4 bug fix which gave misaligned waveform extraction from raw dat. Plot improvements of waveforms +% 3.5 new name and better handling of inputs p = inputParser; +addParameter(p,'basepath',pwd,@ischar); % basepath with dat file, used to extract the waveforms from the dat file +addParameter(p,'clusteringpath','',@ischar); % clustering path to spike data +addParameter(p,'clusteringformat','Phy',@ischar); % clustering format: [Current options: Phy, Klustakwik/Neurosuite,klustaViewa] +addParameter(p,'basename','',@ischar); % The baseName file naming convention addParameter(p,'shanks',nan,@isnumeric); % shanks: Loading only a subset of shanks (only applicable to Klustakwik) addParameter(p,'raw_clusters',false,@islogical); % raw_clusters: Load only a subset of clusters (might not work anymore as it has not been tested for a long time) -addParameter(p,'forceReload',false,@islogical); % Reload spikes from original format and resave the .spikes.mat file? addParameter(p,'saveMat',true,@islogical); % Save spikes to mat file? +addParameter(p,'forceReload',false,@islogical); % Reload spikes from original format (overwrites existing mat file if saveMat==true) addParameter(p,'getWaveforms',true,@islogical); % Get average waveforms? -addParameter(p,'useNeurosuiteWaveforms',false,@islogical); % Use Waveform features from spk files or load waveforms from dat file +addParameter(p,'useNeurosuiteWaveforms',false,@islogical); % Use Waveform features from spk files. Alternatively it loads waveforms from dat file (Klustakwik specific) addParameter(p,'spikes',[],@isstruct); % Load existing spikes structure to append new spike info -addParameter(p,'basepath',pwd,@ischar); % path to dat file, used to extract the waveforms from the dat file -addParameter(p,'LSB',0.195,@isnumeric); % Least significant bit (LSB in uV) Intan = 0.195, Amplipex = 0.3815 -addParameter(p,'session',[],@isstruct); % A Buzsaki db session struct -addParameter(p,'basename','',@ischar); % The baseName file naming convention +addParameter(p,'LSB',0.195,@isnumeric); % Least significant bit (LSB in uV) Intan = 0.195, Amplipex = 0.3815. (range/precision) +addParameter(p,'session',[],@isstruct); % A buzsaki lab db session struct +addParameter(p,'buzcode',false,@islogical); % If true, uses bz_getSessionInfo. Otherwise uses LoadXml + parse(p,varargin{:}) +basepath = p.Results.basepath; +clusteringPath = p.Results.clusteringpath; +clusteringFormat = p.Results.clusteringformat; +baseName = p.Results.basename; shanks = p.Results.shanks; raw_clusters = p.Results.raw_clusters; forceReload = p.Results.forceReload; saveMat = p.Results.saveMat; getWaveforms = p.Results.getWaveforms; spikes = p.Results.spikes; -basepath = p.Results.basepath; useNeurosuiteWaveforms = p.Results.useNeurosuiteWaveforms; LSB = p.Results.LSB; session = p.Results.session; -baseName = p.Results.basename; + +buzcode = p.Results.buzcode; if isempty(baseName) & ~isempty(basepath) [~,baseName,~] = fileparts(basepath); @@ -48,20 +81,30 @@ if exist(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat'])) & ~forceReload load(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat'])) - if isfield(spikes,'ts') && (~isfield(spikes,'processinginfo') || (isfield(spikes,'processinginfo') && spikes.processinginfo.version < 3 && strcmp(spikes.processinginfo.function,'loadClusteringData') )) + if isfield(spikes,'ts') && (~isfield(spikes,'processinginfo') || (isfield(spikes,'processinginfo') && spikes.processinginfo.version < 3 && strcmp(spikes.processinginfo.function,'loadSpikes') )) forceReload = true; disp('spikes.mat structure not up to date. Reloading spikes.') else - disp('loadClusteringData: Loading existing spikes file') + disp('loadSpikes: Loading existing spikes file') end else forceReload = true; end +% Loading spikes if forceReload - switch lower(clusteringMethod) + % Loading session info + if buzcode + xml = bz_getSessionInfo(basepath, 'noPrompts', true); + xml.SampleRate = xml.rates.wideband; + else + xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); + end + switch lower(clusteringFormat) + + % Loading klustakwik case {'klustakwik', 'neurosuite'} - disp('loadClusteringData: Loading Klustakwik clustered data') + disp('loadSpikes: Loading Klustakwik data') unit_nb = 0; spikes = []; shanks_new = []; @@ -74,7 +117,6 @@ end shanks = sort(shanks_new); end - xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); for shank = shanks disp(['Loading shank #' num2str(shank) '/' num2str(length(shanks)) ]) if ~raw_clusters @@ -124,16 +166,17 @@ end clear cluster_index time_stamps + % Loading phy case 'phy' - disp('loadClusteringData: Loading Phy/Kilosort clustered data') - xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); + disp('loadSpikes: Loading Phy/Kilosort data') + spike_cluster_index = readNPY(fullfile(clusteringPath, 'spike_clusters.npy')); spike_times = readNPY(fullfile(clusteringPath, 'spike_times.npy')); spike_amplitudes = readNPY(fullfile(clusteringPath, 'amplitudes.npy')); spike_clusters = unique(spike_cluster_index); filename1 = fullfile(clusteringPath,'cluster_group.tsv'); filename2 = fullfile(clusteringPath,'cluster_groups.csv'); - if exist(fullfile(clusteringPath, 'cluster_ids.npy')) + if exist(fullfile(clusteringPath, 'cluster_ids.npy')) && exist(fullfile(clusteringPath, 'shanks.npy')) && exist(fullfile(clusteringPath, 'peak_channel.npy')) cluster_ids = readNPY(fullfile(clusteringPath, 'cluster_ids.npy')); unit_shanks = readNPY(fullfile(clusteringPath, 'shanks.npy')); peak_channel = readNPY(fullfile(clusteringPath, 'peak_channel.npy'))+1; @@ -188,6 +231,7 @@ else spikes.ids{j} = find(spike_cluster_index == dataArray{1}(i)); spikes.ts{j} = double(spike_times(spikes.ids{j})); + spikes.times{j} = spikes.ts{j}/xml.SampleRate; spikes.cluID(j) = dataArray{1}(i); spikes.UID(j) = j; spikes.amplitudes{j} = double(spike_amplitudes(spikes.ids{j}))'; @@ -195,16 +239,41 @@ end end - if getWaveforms % get waveforms + if getWaveforms % gets waveforms from dat file spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); end - + + % Loading klustaViewa - Kwik format (Klustasuite 0.3.0.beta4) case 'klustaViewa' - disp('loadClusteringData: Loading KlustaViewa clustered data') - units_to_exclude = []; - [spikes,~] = ImportKwikFile(baseName,clusteringPath,shanks,0,units_to_exclude); + disp('loadSpikes: Loading KlustaViewa data') + shank_nb = 1; + for shank = 1:shanks + spike_times = double(hdf5read([folder, dataset, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/time_samples'])); + recording_nb = double(hdf5read([folder, dataset, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/recording'])); + cluster_index = double(hdf5read([folder, dataset, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/clusters/main'])); + waveforms = double(hdf5read([folder, dataset, '.kwx'], ['/channel_groups/' num2str(shank-1) '/waveforms_filtered'])); + clusters = unique(cluster_index); + for i = 1:length(clusters(:)) + cluster_type = double(hdf5read([folder, dataset, '.kwik'], ['/channel_groups/' num2str(shank-1) '/clusters/main/' num2str(clusters(i)),'/'],'cluster_group')); + if cluster_type == 2 + indexes{shank_nb} = shank_nb*ones(sum(cluster_index == clusters(i)),1); + spikes.UID(shank_nb) = shank_nb; + spikes.ts{shank_nb} = spike_times(cluster_index == clusters(i))+recording_nb(cluster_index == clusters(i))*40*40000; + spikes.times{shank_nb} = spikes.ts{j}/xml.SampleRate; + spikes.total(shank_nb) = sum(cluster_index == clusters(i)); + spikes.shankID(shank_nb) = shank-1; + spikes.cluID(shank_nb) = clusters(i); + spikes.filtWaveform_all{shank_nb} = mean(waveforms(:,:,cluster_index == clusters(i)),3); + spikes.filtWaveform_all_std{shank_nb} = permute(std(permute(waveforms(:,:,cluster_index == clusters(i)),[3,1,2])),[2,3,1]); + shank_nb = shank_nb+1; + end + end + end + + if getWaveforms % get waveforms + spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); + end end - spikes.sessionName = baseName; % Generate spindices matrics @@ -220,22 +289,22 @@ end % Attaching info about how the spikes structure was generated - spikes.processinginfo.function = 'loadClusteringData'; - spikes.processinginfo.version = 3.4; + spikes.processinginfo.function = 'loadSpikes'; + spikes.processinginfo.version = 3.5; spikes.processinginfo.date = now; spikes.processinginfo.params.forceReload = forceReload; spikes.processinginfo.params.shanks = shanks; spikes.processinginfo.params.raw_clusters = raw_clusters; spikes.processinginfo.params.getWaveforms = getWaveforms; spikes.processinginfo.params.baseName = baseName; - spikes.processinginfo.params.clusteringMethod = clusteringMethod; + spikes.processinginfo.params.clusteringFormat = clusteringFormat; spikes.processinginfo.params.clusteringPath = clusteringPath; spikes.processinginfo.params.basepath = basepath; spikes.processinginfo.params.useNeurosuiteWaveforms = useNeurosuiteWaveforms; % Saving output to a buzcode compatible spikes file. if saveMat - disp('loadClusteringData: Saving spikes') + disp('loadSpikes: Saving spikes') save(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat']),'spikes') end end @@ -243,26 +312,38 @@ end function spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session) -% Requires a neurosuite xml structure. Bad channels must be removed from the spike groups beforehand +% Requires a neurosuite xml structure. +% Bad channels must be deselected in the spike groups, or skipped beforehand +timerVal = tic; +nPull = 600; % number of spikes to pull out (default: 600) +wfWin_sec = 0.004; % Larger size of waveform windows for filterning. total width in ms +wfWinKeep = 0.0008; % half width in ms +filtFreq = [500,8000]; showWaveforms = true; + badChannels = []; + +% Removing channels marked as Bad in session struct if ~isempty(session) badChannels = session.channelTags.Bad.channels; if ~isempty(session.channelTags.Bad.spikeGroups) badChannels = [badChannels,session.extracellular.spikeGroups(session.channelTags.Bad.spikeGroups)+1]; end - badChannels = unique(badChannels); + badChannels = unique(badChannels); end +% Removing channels that does not exist in SpkGrps badChannels = [badChannels,setdiff([xml.AnatGrps.Channels],[xml.SpkGrps.Channels])+1]; + +% Removing channels with skip parameter from the xml +if isfield(xml.AnatGrps,'Skip') + channelOrder = [xml.AnatGrps.Channels]+1; + skip = find([xml.AnatGrps.Skip]); + badChannels = [badChannels, channelOrder(skip)]; +end goodChannels = setdiff(1:xml.nChannels,badChannels); nGoodChannels = length(goodChannels); -timerVal = tic; -nPull = 600; % number of spikes to pull out -wfWin_sec = 0.004; % Larger size of waveform windows for filterning. total width in ms -wfWinKeep = 0.0008; % half width in ms -filtFreq = [500,8000]; [b1, a1] = butter(3, filtFreq/xml.SampleRate*2, 'bandpass'); f = waitbar(0,['Getting waveforms from dat file'],'Name',['Processing ' baseName]); @@ -298,7 +379,7 @@ startIndicies = (spkTmp(1:min(100,length(spkTmp))) - wfWin)*xml.nChannels+1; stopIndicies = (spkTmp(1:min(100,length(spkTmp))) + wfWin)*xml.nChannels; X = cumsum(accumarray(cumsum([1;stopIndicies(:)-startIndicies(:)+1]),[startIndicies(:);0]-[0;stopIndicies(:)]-1)+1); -% temp1 = reshape(double(m.Data(X(1:end-1))),xml.nChannels,(wfWin*2),[]); + % temp1 = reshape(double(m.Data(X(1:end-1))),xml.nChannels,(wfWin*2),[]); wf = LSB * mean(reshape(double(DATA(X(1:end-1))),xml.nChannels,(wfWin*2),[]),3); wfF2 = zeros((wfWin * 2),nGoodChannels); for jj = 1 : nGoodChannels @@ -334,7 +415,7 @@ filtWaveform_std = std(wfF'); window_interval = wfWin-(wfWinKeep*xml.SampleRate):wfWin-1+(wfWinKeep*xml.SampleRate); - spikes.rawWaveform{ii} = rawWaveform(window_interval); % keep only +- 1ms of waveform + spikes.rawWaveform{ii} = rawWaveform(window_interval); % keep only +- 0.8 ms of waveform spikes.rawWaveform_std{ii} = rawWaveform_std(window_interval); spikes.filtWaveform{ii} = filtWaveform(window_interval); spikes.filtWaveform_std{ii} = filtWaveform_std(window_interval); @@ -375,4 +456,4 @@ end % close(f) end -end \ No newline at end of file +end From 84b11c2d21c480c2613fc5d010808ae3c7aeeca9 Mon Sep 17 00:00:00 2001 From: Brendon Watson Date: Thu, 11 Jul 2019 22:11:52 -0400 Subject: [PATCH 30/35] Added a new probe/array type --- createChannelMapFile_KSW.m | 62 ++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/createChannelMapFile_KSW.m b/createChannelMapFile_KSW.m index dc7c1e7..75273bd 100644 --- a/createChannelMapFile_KSW.m +++ b/createChannelMapFile_KSW.m @@ -23,17 +23,22 @@ function createChannelMapFile_Local(basepath,basename,electrode_type) electrode_type = 'poly3'; case 'poly5' electrode_type = 'poly5'; + case 'twohundred' + electrode_type = 'twohundred'; end + +%%Default if ~exist('electrode_type') electrode_type = 'staggered'; end -xcoords = []; + +%% +xcoords = [];%eventual output arrays ycoords = []; -t = par.AnatGrps; ngroups = length(par.AnatGrps); for g = 1:ngroups - tgroups{g} = par.AnatGrps(g).Channels; + groups{g} = par.AnatGrps(g).Channels; end switch(electrode_type) @@ -41,7 +46,7 @@ function createChannelMapFile_Local(basepath,basename,electrode_type) for a= 1:ngroups %being super lazy and making this map with loops x = []; y = []; - tchannels = tgroups{a}; + tchannels = groups{a}; for i =1:length(tchannels) x(i) = 20;%length(tchannels)-i; y(i) = -i*20; @@ -56,7 +61,7 @@ function createChannelMapFile_Local(basepath,basename,electrode_type) case 'poly3' disp('poly3 probe layout') for a= 1:ngroups %being super lazy and making this map with loops - tchannels = tgroups{a}; + tchannels = groups{a}; x = nan(1,length(tchannels)); y = nan(1,length(tchannels)); extrachannels = mod(length(tchannels),3); @@ -75,7 +80,7 @@ function createChannelMapFile_Local(basepath,basename,electrode_type) case 'poly5' disp('poly5 probe layout') for a= 1:ngroups %being super lazy and making this map with loops - tchannels = tgroups{a}; + tchannels = groups{a}; x = nan(1,length(tchannels)); y = nan(1,length(tchannels)); extrachannels = mod(length(tchannels),5); @@ -101,7 +106,7 @@ function createChannelMapFile_Local(basepath,basename,electrode_type) for a= 1:ngroups %being super lazy and making this map with loops x = []; y = []; - tchannels = tgroups{a}; + tchannels = groups{a}; for i =1:length(tchannels) x(i) = length(tchannels)-i; y(i) = -i*30; @@ -110,36 +115,53 @@ function createChannelMapFile_Local(basepath,basename,electrode_type) xcoords = cat(1,xcoords,x(:)); ycoords = cat(1,ycoords,y(:)); end + case 'twohundred' + for a= 1:ngroups + x = []; + y = []; + tchannels = groups{a}; + for i =1:length(tchannels) + x(i) = 0;%length(tchannels)-i; + if mod(i,2) + y(i) = 0;%odds + else + y(i) = 200;%evens + end + end + x = x+(a-1)*200; + xcoords = cat(1,xcoords,x(:)); + ycoords = cat(1,ycoords,y(:)); + end end Nchannels = length(xcoords); kcoords = zeros(1,Nchannels); switch(electrode_type) - case {'staggered','poly3','poly5'} + case {'staggered','poly3','poly5','twohundred'} for a= 1:ngroups - kcoords(tgroups{a}+1) = a; + kcoords(groups{a}+1) = a; end case 'neurogrid' for a= 1:ngroups - kcoords(tgroups{a}+1) = floor((a-1)/4)+1; + kcoords(groups{a}+1) = floor((a-1)/4)+1; end end connected = true(Nchannels, 1); -% Removing dead channels by the skip parameter in the xml +% just use AnatGrps +% % Removing dead channels by the skip parameter in the xml +% % order = [par.AnatGrps.Channels]; +% % skip = find([par.AnatGrps.Skip]); +% % connected(order(skip)+1) = false; % order = [par.AnatGrps.Channels]; -% skip = find([par.AnatGrps.Skip]); -% connected(order(skip)+1) = false; - -order = [par.AnatGrps.Channels]; -if isfield(par,'SpkGrps') - skip2 = find(~ismember([par.AnatGrps.Channels], [par.SpkGrps.Channels])); % finds the indices of the channels that are not part of SpkGrps - connected(order(skip2)+1) = false; -end +% if isfield(par,'SpkGrps') +% skip2 = find(~ismember([par.AnatGrps.Channels], [par.SpkGrps.Channels])); % finds the indices of the channels that are not part of SpkGrps +% connected(order(skip2)+1) = false; +% end chanMap = 1:Nchannels; chanMap0ind = chanMap - 1; -[~,I] = sort(horzcat(tgroups{:})); +[~,I] = sort(horzcat(groups{:})); xcoords = xcoords(I)'; ycoords = ycoords(I)'; From 85b98052d301906484bfb9f1044572171ae7ea7a Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Wed, 31 Jul 2019 16:40:28 -0400 Subject: [PATCH 31/35] Update loadSpikes.m --- loadSpikes.m | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/loadSpikes.m b/loadSpikes.m index 2e6d152..2e1ee18 100644 --- a/loadSpikes.m +++ b/loadSpikes.m @@ -31,7 +31,7 @@ % By Peter Petersen % petersen.peter@gmail.com -% Last edited: 20-06-2019 +% Last edited: 31-07-2019 % Version history % 3.2 waveforms for phy data extracted from the raw dat @@ -89,6 +89,7 @@ end else forceReload = true; + spikes = []; end % Loading spikes @@ -106,7 +107,6 @@ case {'klustakwik', 'neurosuite'} disp('loadSpikes: Loading Klustakwik data') unit_nb = 0; - spikes = []; shanks_new = []; if isnan(shanks) fileList = dir(fullfile(clusteringPath,[baseName,'.res.*'])); @@ -122,7 +122,7 @@ if ~raw_clusters cluster_index = load(fullfile(clusteringPath, [baseName '.clu.' num2str(shank)])); time_stamps = load(fullfile(clusteringPath,[baseName '.res.' num2str(shank)])); - if getWaveforms + if getWaveforms & useNeurosuiteWaveforms fname = fullfile(clusteringPath,[baseName '.spk.' num2str(shank)]); f = fopen(fname,'r'); waveforms = LSB * double(fread(f,'int16')); @@ -242,7 +242,7 @@ if getWaveforms % gets waveforms from dat file spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); end - + % Loading klustaViewa - Kwik format (Klustasuite 0.3.0.beta4) case 'klustaViewa' disp('loadSpikes: Loading KlustaViewa data') From 8a9cf2a178642439875d952e277841864d6ab11d Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Wed, 31 Jul 2019 16:40:37 -0400 Subject: [PATCH 32/35] Update PhyAutoClustering.m --- PhyAutoClustering.m | 3 +++ 1 file changed, 3 insertions(+) diff --git a/PhyAutoClustering.m b/PhyAutoClustering.m index 4397e34..bf82148 100644 --- a/PhyAutoClustering.m +++ b/PhyAutoClustering.m @@ -159,6 +159,9 @@ function PhyAutoClustering(clusteringpath,varargin) meanR = [meanR; meanR_cur]; maxPwRatio_cur = max(abs(wav(11,:)))/mean(abs(wav(11,:))); + if isempty(maxPwRatio_cur) + maxPwRatio = [maxPwRatio; 0]; + end maxPwRatio = [maxPwRatio; maxPwRatio_cur]; [ccgR,t] = CCG(spktime,ones(size(spktime)),'binsize',.0005,'duration',.06); From 160f08e544b3a6311e25d94a5dd38c07772902ff Mon Sep 17 00:00:00 2001 From: Peter Petersen Date: Sun, 13 Oct 2019 12:37:11 -0400 Subject: [PATCH 33/35] Update loadSpikes.m --- loadSpikes.m | 50 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/loadSpikes.m b/loadSpikes.m index 2e1ee18..e41d6de 100644 --- a/loadSpikes.m +++ b/loadSpikes.m @@ -94,17 +94,19 @@ % Loading spikes if forceReload - % Loading session info - if buzcode - xml = bz_getSessionInfo(basepath, 'noPrompts', true); - xml.SampleRate = xml.rates.wideband; - else - xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); - end switch lower(clusteringFormat) % Loading klustakwik case {'klustakwik', 'neurosuite'} + % Loading session info + if buzcode + xml = bz_getSessionInfo(basepath, 'noPrompts', true); + xml.SampleRate = xml.rates.wideband; + elseif exist(fullfile(clusteringPath,[baseName, '.xml']),'file') + xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); + else + error(['xml file does not exist: ', fullfile(clusteringPath,[baseName, '.xml'])]) + end disp('loadSpikes: Loading Klustakwik data') unit_nb = 0; shanks_new = []; @@ -166,8 +168,17 @@ end clear cluster_index time_stamps - % Loading phy + % Loading phy case 'phy' + % Loading session info + if buzcode + xml = bz_getSessionInfo(basepath, 'noPrompts', true); + xml.SampleRate = xml.rates.wideband; + elseif exist(fullfile(clusteringPath,[baseName, '.xml']),'file') + xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); + else + error(['xml file does not exist: ', fullfile(clusteringPath,[baseName, '.xml'])]) + end disp('loadSpikes: Loading Phy/Kilosort data') spike_cluster_index = readNPY(fullfile(clusteringPath, 'spike_clusters.npy')); @@ -244,22 +255,25 @@ end % Loading klustaViewa - Kwik format (Klustasuite 0.3.0.beta4) - case 'klustaViewa' + case {'klustaViewa','kwik'} disp('loadSpikes: Loading KlustaViewa data') + if isnan(shanks) + error('Please provide the number of shanks for the session') + end shank_nb = 1; for shank = 1:shanks - spike_times = double(hdf5read([folder, dataset, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/time_samples'])); - recording_nb = double(hdf5read([folder, dataset, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/recording'])); - cluster_index = double(hdf5read([folder, dataset, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/clusters/main'])); - waveforms = double(hdf5read([folder, dataset, '.kwx'], ['/channel_groups/' num2str(shank-1) '/waveforms_filtered'])); + spike_times = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwik']), ['/channel_groups/' num2str(shank-1) '/spikes/time_samples'])); + recording_nb = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwik']), ['/channel_groups/' num2str(shank-1) '/spikes/recording'])); + cluster_index = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwik']), ['/channel_groups/' num2str(shank-1) '/spikes/clusters/main'])); + waveforms = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwx']), ['/channel_groups/' num2str(shank-1) '/waveforms_filtered'])); clusters = unique(cluster_index); for i = 1:length(clusters(:)) - cluster_type = double(hdf5read([folder, dataset, '.kwik'], ['/channel_groups/' num2str(shank-1) '/clusters/main/' num2str(clusters(i)),'/'],'cluster_group')); + cluster_type = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwik']), ['/channel_groups/' num2str(shank-1) '/clusters/main/' num2str(clusters(i)),'/'],'cluster_group')); if cluster_type == 2 indexes{shank_nb} = shank_nb*ones(sum(cluster_index == clusters(i)),1); spikes.UID(shank_nb) = shank_nb; spikes.ts{shank_nb} = spike_times(cluster_index == clusters(i))+recording_nb(cluster_index == clusters(i))*40*40000; - spikes.times{shank_nb} = spikes.ts{j}/xml.SampleRate; + spikes.times{shank_nb} = spikes.ts{shank_nb}/40000; spikes.total(shank_nb) = sum(cluster_index == clusters(i)); spikes.shankID(shank_nb) = shank-1; spikes.cluID(shank_nb) = clusters(i); @@ -270,9 +284,9 @@ end end - if getWaveforms % get waveforms - spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); - end +% if getWaveforms % get waveforms +% spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); +% end end spikes.sessionName = baseName; From 8d0c0d234d05c786359d73e164990ca6afbf2837 Mon Sep 17 00:00:00 2001 From: Peter Petersen Date: Sun, 13 Oct 2019 12:38:51 -0400 Subject: [PATCH 34/35] Update loadSpikes.m --- loadSpikes.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/loadSpikes.m b/loadSpikes.m index e41d6de..65d06c1 100644 --- a/loadSpikes.m +++ b/loadSpikes.m @@ -31,7 +31,7 @@ % By Peter Petersen % petersen.peter@gmail.com -% Last edited: 31-07-2019 +% Last edited: 13-10-2019 % Version history % 3.2 waveforms for phy data extracted from the raw dat From afa7c00fb17606ccf3d362c78e5e0d3fe82d494d Mon Sep 17 00:00:00 2001 From: petersenpeter Date: Sun, 24 Nov 2019 09:22:23 -0500 Subject: [PATCH 35/35] Updated loadSpikes --- KiloSortWrapper.m | 4 +- Phy2Neurosuite.m | 13 +-- createChannelMapFile_KSW.m | 4 +- loadSpikes.m | 164 +++++++++++++++++++------------------ 4 files changed, 96 insertions(+), 89 deletions(-) diff --git a/KiloSortWrapper.m b/KiloSortWrapper.m index 7a648ae..f0dd50b 100755 --- a/KiloSortWrapper.m +++ b/KiloSortWrapper.m @@ -13,7 +13,7 @@ % % Dependencies: KiloSort (https://github.com/cortex-lab/KiloSort) % -% The AutoClustering requires the CCGHeart to be compile. +% The AutoClustering requires CCGHeart to be compiled. % Go to the private folder of the wrapper and type: % mex -O CCGHeart.c % @@ -36,7 +36,7 @@ addParameter(p,'GPU_id',1,@isnumeric) % Specify the GPU_id addParameter(p,'SSD_path','K:\Kilosort',@ischar) % Path to SSD disk. Make it empty to disable SSD addParameter(p,'CreateSubdirectory',1,@isnumeric) % Puts the Kilosort output into a subfolder -addParameter(p,'performAutoCluster',0,@isnumeric) % Performs PhyAutoCluster once Kilosort is complete when exporting to Phy. +addParameter(p,'performAutoCluster',1,@isnumeric) % Performs PhyAutoCluster once Kilosort is complete when exporting to Phy. addParameter(p,'config','',@ischar) % Specify a configuration file to use from the ConfigurationFiles folder. e.g. 'Omid' parse(p,varargin{:}) diff --git a/Phy2Neurosuite.m b/Phy2Neurosuite.m index 101cb16..0a503fd 100644 --- a/Phy2Neurosuite.m +++ b/Phy2Neurosuite.m @@ -1,4 +1,4 @@ -function Phy2Neurosuite(basepath,clustering_path) +function Phy2Neurosuite(basepath,clustering_path,output_format) % Converts Phy output (NPY files) to Neurosuite files: fet, res, clu, spk files. % Based on the GPU enable filter from Kilosort and fractions from Brendon % Watson's code for saving Neurosuite files. @@ -208,11 +208,12 @@ function Phy2Neurosuite(basepath,clustering_path) DATA =zeros(NT, NchanTOT,Nbatch_buff,'int16'); - if isfield(ops,'fslow')&&ops.fslow0 + LSB = session.extracellular.leastSignificantBit; + end +elseif isempty(basename) + [~,basename,~] = fileparts(basepath); + disp(['Using basepath to determine the basename: ' basename]) + temp = dir('Kilosort_*'); + if ~isempty(temp) + clusteringpath = temp.name; % clusteringpath assumed from Kilosort + end end -if exist(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat'])) & ~forceReload - load(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat'])) +clusteringpath_full = fullfile(basepath,clusteringpath); + +if exist(fullfile(clusteringpath_full,[basename,'.spikes.cellinfo.mat'])) & ~forceReload + load(fullfile(clusteringpath_full,[basename,'.spikes.cellinfo.mat'])) if isfield(spikes,'ts') && (~isfield(spikes,'processinginfo') || (isfield(spikes,'processinginfo') && spikes.processinginfo.version < 3 && strcmp(spikes.processinginfo.function,'loadSpikes') )) forceReload = true; disp('spikes.mat structure not up to date. Reloading spikes.') @@ -94,24 +107,26 @@ % Loading spikes if forceReload + % Loading session info + if buzcode + xml = bz_getSessionInfo(basepath, 'noPrompts', true); + xml.SampleRate = xml.rates.wideband; + else + if ~exist('LoadXml.m','file') || ~exist('xmltools.m','file') + error('''LoadXml.m'' and ''xmltools.m'' is not in your path and is required to load the xml file. If you have buzcode installed, please set ''buzcode'' to true in the input parameters.') + else + xml = LoadXml(fullfile(clusteringpath_full,[basename, '.xml'])); + end + end switch lower(clusteringFormat) % Loading klustakwik case {'klustakwik', 'neurosuite'} - % Loading session info - if buzcode - xml = bz_getSessionInfo(basepath, 'noPrompts', true); - xml.SampleRate = xml.rates.wideband; - elseif exist(fullfile(clusteringPath,[baseName, '.xml']),'file') - xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); - else - error(['xml file does not exist: ', fullfile(clusteringPath,[baseName, '.xml'])]) - end disp('loadSpikes: Loading Klustakwik data') unit_nb = 0; shanks_new = []; if isnan(shanks) - fileList = dir(fullfile(clusteringPath,[baseName,'.res.*'])); + fileList = dir(fullfile(clusteringpath_full,[basename,'.res.*'])); fileList = {fileList.name}; for i = 1:length(fileList) temp = strsplit(fileList{i},'.res.'); @@ -122,10 +137,10 @@ for shank = shanks disp(['Loading shank #' num2str(shank) '/' num2str(length(shanks)) ]) if ~raw_clusters - cluster_index = load(fullfile(clusteringPath, [baseName '.clu.' num2str(shank)])); - time_stamps = load(fullfile(clusteringPath,[baseName '.res.' num2str(shank)])); + cluster_index = load(fullfile(clusteringpath_full, [basename '.clu.' num2str(shank)])); + time_stamps = load(fullfile(clusteringpath_full,[basename '.res.' num2str(shank)])); if getWaveforms & useNeurosuiteWaveforms - fname = fullfile(clusteringPath,[baseName '.spk.' num2str(shank)]); + fname = fullfile(clusteringpath_full,[basename '.spk.' num2str(shank)]); f = fopen(fname,'r'); waveforms = LSB * double(fread(f,'int16')); samples = size(waveforms,1)/size(time_stamps,1); @@ -133,8 +148,8 @@ waveforms = reshape(waveforms, [electrodes,samples/electrodes,length(waveforms)/samples]); end else - cluster_index = load(fullfile(clusteringPath, 'OriginalClus', [baseName '.clu.' num2str(shank)])); - time_stamps = load(fullfile(clusteringPath, 'OriginalClus', [baseName '.res.' num2str(shank)])); + cluster_index = load(fullfile(clusteringpath_full, 'OriginalClus', [basename '.clu.' num2str(shank)])); + time_stamps = load(fullfile(clusteringpath_full, 'OriginalClus', [basename '.res.' num2str(shank)])); end cluster_index = cluster_index(2:end); nb_clusters = unique(cluster_index); @@ -164,35 +179,25 @@ end end if getWaveforms & ~useNeurosuiteWaveforms - spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); + spikes = GetWaveformsFromDat(spikes,xml,basepath,basename,LSB,session); end clear cluster_index time_stamps - % Loading phy + % Loading phy case 'phy' - % Loading session info - if buzcode - xml = bz_getSessionInfo(basepath, 'noPrompts', true); - xml.SampleRate = xml.rates.wideband; - elseif exist(fullfile(clusteringPath,[baseName, '.xml']),'file') - xml = LoadXml(fullfile(clusteringPath,[baseName, '.xml'])); - else - error(['xml file does not exist: ', fullfile(clusteringPath,[baseName, '.xml'])]) - end disp('loadSpikes: Loading Phy/Kilosort data') - - spike_cluster_index = readNPY(fullfile(clusteringPath, 'spike_clusters.npy')); - spike_times = readNPY(fullfile(clusteringPath, 'spike_times.npy')); - spike_amplitudes = readNPY(fullfile(clusteringPath, 'amplitudes.npy')); + spike_cluster_index = readNPY(fullfile(clusteringpath_full, 'spike_clusters.npy')); + spike_times = readNPY(fullfile(clusteringpath_full, 'spike_times.npy')); + spike_amplitudes = readNPY(fullfile(clusteringpath_full, 'amplitudes.npy')); spike_clusters = unique(spike_cluster_index); - filename1 = fullfile(clusteringPath,'cluster_group.tsv'); - filename2 = fullfile(clusteringPath,'cluster_groups.csv'); - if exist(fullfile(clusteringPath, 'cluster_ids.npy')) && exist(fullfile(clusteringPath, 'shanks.npy')) && exist(fullfile(clusteringPath, 'peak_channel.npy')) - cluster_ids = readNPY(fullfile(clusteringPath, 'cluster_ids.npy')); - unit_shanks = readNPY(fullfile(clusteringPath, 'shanks.npy')); - peak_channel = readNPY(fullfile(clusteringPath, 'peak_channel.npy'))+1; - if exist(fullfile(clusteringPath, 'rez.mat')) - load(fullfile(clusteringPath, 'rez.mat')) + filename1 = fullfile(clusteringpath_full,'cluster_group.tsv'); + filename2 = fullfile(clusteringpath_full,'cluster_groups.csv'); + if exist(fullfile(clusteringpath_full, 'cluster_ids.npy')) && exist(fullfile(clusteringpath_full, 'shanks.npy')) && exist(fullfile(clusteringpath_full, 'peak_channel.npy')) + cluster_ids = readNPY(fullfile(clusteringpath_full, 'cluster_ids.npy')); + unit_shanks = readNPY(fullfile(clusteringpath_full, 'shanks.npy')); + peak_channel = readNPY(fullfile(clusteringpath_full, 'peak_channel.npy'))+1; + if exist(fullfile(clusteringpath_full, 'rez.mat')) + load(fullfile(clusteringpath_full, 'rez.mat')) temp = find(rez.connected); peak_channel = temp(peak_channel); clear rez temp @@ -251,29 +256,26 @@ end if getWaveforms % gets waveforms from dat file - spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); + spikes = GetWaveformsFromDat(spikes,xml,basepath,basename,LSB,session); end % Loading klustaViewa - Kwik format (Klustasuite 0.3.0.beta4) - case {'klustaViewa','kwik'} + case 'klustaViewa' disp('loadSpikes: Loading KlustaViewa data') - if isnan(shanks) - error('Please provide the number of shanks for the session') - end shank_nb = 1; for shank = 1:shanks - spike_times = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwik']), ['/channel_groups/' num2str(shank-1) '/spikes/time_samples'])); - recording_nb = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwik']), ['/channel_groups/' num2str(shank-1) '/spikes/recording'])); - cluster_index = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwik']), ['/channel_groups/' num2str(shank-1) '/spikes/clusters/main'])); - waveforms = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwx']), ['/channel_groups/' num2str(shank-1) '/waveforms_filtered'])); + spike_times = double(hdf5read([clusteringpath_full, basename, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/time_samples'])); + recording_nb = double(hdf5read([clusteringpath_full, basename, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/recording'])); + cluster_index = double(hdf5read([clusteringpath_full, basename, '.kwik'], ['/channel_groups/' num2str(shank-1) '/spikes/clusters/main'])); + waveforms = double(hdf5read([clusteringpath_full, basename, '.kwx'], ['/channel_groups/' num2str(shank-1) '/waveforms_filtered'])); clusters = unique(cluster_index); for i = 1:length(clusters(:)) - cluster_type = double(hdf5read(fullfile(clusteringPath, [baseName, '.kwik']), ['/channel_groups/' num2str(shank-1) '/clusters/main/' num2str(clusters(i)),'/'],'cluster_group')); + cluster_type = double(hdf5read([clusteringpath_full, basename, '.kwik'], ['/channel_groups/' num2str(shank-1) '/clusters/main/' num2str(clusters(i)),'/'],'cluster_group')); if cluster_type == 2 indexes{shank_nb} = shank_nb*ones(sum(cluster_index == clusters(i)),1); spikes.UID(shank_nb) = shank_nb; spikes.ts{shank_nb} = spike_times(cluster_index == clusters(i))+recording_nb(cluster_index == clusters(i))*40*40000; - spikes.times{shank_nb} = spikes.ts{shank_nb}/40000; + spikes.times{shank_nb} = spikes.ts{j}/xml.SampleRate; spikes.total(shank_nb) = sum(cluster_index == clusters(i)); spikes.shankID(shank_nb) = shank-1; spikes.cluID(shank_nb) = clusters(i); @@ -284,11 +286,11 @@ end end -% if getWaveforms % get waveforms -% spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session); -% end + if getWaveforms % get waveforms + spikes = GetWaveformsFromDat(spikes,xml,basepath,basename,LSB,session); + end end - spikes.sessionName = baseName; + spikes.sessionName = basename; % Generate spindices matrics spikes.numcells = length(spikes.UID); @@ -310,22 +312,28 @@ spikes.processinginfo.params.shanks = shanks; spikes.processinginfo.params.raw_clusters = raw_clusters; spikes.processinginfo.params.getWaveforms = getWaveforms; - spikes.processinginfo.params.baseName = baseName; + spikes.processinginfo.params.basename = basename; spikes.processinginfo.params.clusteringFormat = clusteringFormat; - spikes.processinginfo.params.clusteringPath = clusteringPath; + spikes.processinginfo.params.clusteringpath = clusteringpath; spikes.processinginfo.params.basepath = basepath; spikes.processinginfo.params.useNeurosuiteWaveforms = useNeurosuiteWaveforms; + try + spikes.processinginfo.username = char(java.lang.System.getProperty('user.name')); + spikes.processinginfo.hostname = char(java.net.InetAddress.getLocalHost.getHostName); + catch + disp('Failed to retrieve system info.') + end % Saving output to a buzcode compatible spikes file. if saveMat disp('loadSpikes: Saving spikes') - save(fullfile(clusteringPath,[baseName,'.spikes.cellinfo.mat']),'spikes') + save(fullfile(clusteringpath,[basename,'.spikes.cellinfo.mat']),'spikes') end end end -function spikes = GetWaveformsFromDat(spikes,xml,basepath,baseName,LSB,session) +function spikes = GetWaveformsFromDat(spikes,xml,basepath,basename,LSB,session) % Requires a neurosuite xml structure. % Bad channels must be deselected in the spike groups, or skipped beforehand timerVal = tic; @@ -360,15 +368,15 @@ [b1, a1] = butter(3, filtFreq/xml.SampleRate*2, 'bandpass'); -f = waitbar(0,['Getting waveforms from dat file'],'Name',['Processing ' baseName]); +f = waitbar(0,['Getting waveforms from dat file'],'Name',['Processing ' basename]); if showWaveforms - fig1 = figure('Name', ['Getting waveforms for ' baseName],'NumberTitle', 'off'); + fig1 = figure('Name', ['Getting waveforms for ' basename],'NumberTitle', 'off','position',[100,100,1000,800]); end wfWin = round((wfWin_sec * xml.SampleRate)/2); t1 = toc(timerVal); -s = dir(fullfile(basepath,[baseName '.dat'])); +s = dir(fullfile(basepath,[basename '.dat'])); duration = s.bytes/(2*xml.nChannels*xml.SampleRate); -m = memmapfile(fullfile(basepath,[baseName '.dat']),'Format','int16','writable',false); +m = memmapfile(fullfile(basepath,[basename '.dat']),'Format','int16','writable',false); DATA = m.Data; for ii = 1 : size(spikes.times,2) @@ -439,16 +447,14 @@ if ishandle(fig1) figure(fig1) subplot(2,2,1), hold off - plot(wfF2), hold on, plot(wfF2(:,idx),'k','linewidth',2), title('Filt waveform across channels'), xlabel('Samples'), hold off - + plot(wfF2), hold on, plot(wfF2(:,idx),'k','linewidth',2), title('Filtered waveforms across channels'), xlabel('Samples'), ylabel('uV'),hold off subplot(2,2,2), hold off, - plot(wfF), title('Peak channel waveforms'), xlabel('Samples') - + plot(wfF), title(['Peak channel waveforms (maxWaveformCh1=',num2str(spikes.maxWaveformCh1(ii)),')']), xlabel('Samples'), ylabel('uV') subplot(2,2,3), hold on, - plot(spikes.timeWaveform{ii},spikes.rawWaveform{ii}), title('Raw waveform'), xlabel('Time (ms)') + plot(spikes.timeWaveform{ii},spikes.rawWaveform{ii}), title(['Raw waveform (',num2str(ii),'/',num2str(size(spikes.times,2)),')']), xlabel('Time (ms)'), ylabel('uV') xlim([-0.8,0.8]) subplot(2,2,4), hold on, - plot(spikes.timeWaveform{ii},spikes.filtWaveform{ii}), title('Filtered waveform'), xlabel('Time (ms)') + plot(spikes.timeWaveform{ii},spikes.filtWaveform{ii}), title('Filtered waveform'), xlabel('Time (ms)'), ylabel('uV') xlim([-0.8,0.8]) end clear wf wfF wf2 wfF2 @@ -466,7 +472,7 @@ waitbar(ii/size(spikes.times,2),f,['Waveform extraction complete ',num2str(ii),'/',num2str(size(spikes.times,2)),'. ', num2str(round(toc(timerVal)/60)) ' minutes total']); disp(['Waveform extraction complete. Total duration: ' num2str(round(toc(timerVal)/60)),' minutes']) if ishandle(fig1) - set(fig1,'Name',['Waveform extraction complete for ' baseName]) + set(fig1,'Name',['Waveform extraction complete for ' basename]) end % close(f) end