In [1]:
import numpy as np
import visionloader as vl
import os

In [2]:
def compute_duplicates(vstim_data, allowed_types, MIN_CORR=0.95, verbose=True):
    """Find duplicate (and unusable) cells in a Vision analysis.

    Two cells are treated as duplicates when both match the same entry of
    ``allowed_types`` and the Pearson correlation of their per-electrode EI
    power is at least ``MIN_CORR``. Of each duplicate pair, the cell with
    fewer spikes is flagged; on a tie, the cell from the outer loop is
    flagged. Once a cell has been flagged it is no longer used to flag other
    cells.

    Afterwards, any remaining cell whose EI error is exactly zero on some
    channel with non-zero ``channel_noise`` is also added to the returned
    ``duplicates`` set.

    Args:
        vstim_data: Loaded Vision data object providing ``get_cell_ids``,
            ``get_ei_for_cell`` (with ``.ei`` / ``.ei_error``),
            ``get_cell_type_for_cell``, ``get_spike_times_for_cell`` and a
            ``channel_noise`` array.
        allowed_types: Cell-type substrings (case-insensitive) eligible for
            duplicate detection. Each cell is assigned to the first entry that
            occurs in its lowercased type name; cells matching no entry are
            never compared. Because this is substring matching, e.g.
            'parasol' matches both 'on parasol' and 'off parasol', and such
            cells are compared with each other.
        MIN_CORR: Minimum EI-power correlation for two cells to count as
            duplicates.
        verbose: If True (default), print every eligible cell and every
            candidate pair considered. Flagged duplicates are always printed.

    Returns:
        (duplicates, nonduplicates): two disjoint sets of cell ids whose union
        is every id returned by ``vstim_data.get_cell_ids()``.
    """
    cellids = list(vstim_data.get_cell_ids())

    # Resolve each cell's type once: cell -> (lowercased type, matching allowed entry).
    eligible = {}
    for cell in cellids:
        celltype = vstim_data.get_cell_type_for_cell(cell).lower()
        group = next((t for t in allowed_types if t.lower() in celltype), None)
        if group is not None:
            eligible[cell] = (celltype, group)

    # Per-electrode EI power (sum of squares over the EI's second axis), once per cell
    # instead of once per compared pair.
    ei_power = {
        cell: np.sum(vstim_data.get_ei_for_cell(cell).ei ** 2, axis=1)
        for cell in eligible
    }

    duplicates = set()
    for cell in cellids:
        if cell in duplicates or cell not in eligible:
            continue
        celltype, group = eligible[cell]
        if verbose:
            print(cell, celltype, group)

        for other_cell in cellids:
            if other_cell == cell or other_cell in duplicates:
                continue
            if other_cell not in eligible or eligible[other_cell][1] != group:
                continue
            other_celltype = eligible[other_cell][0]
            if verbose:
                print(cell, celltype, other_cell, other_celltype)

            corr = np.corrcoef(ei_power[cell], ei_power[other_cell])[0, 1]
            # Kept as `>=` so a NaN correlation (constant EI power) never counts as a match.
            if corr >= MIN_CORR:
                # Keep whichever cell of the pair has more spikes.
                n_spikes_cell = vstim_data.get_spike_times_for_cell(cell).shape[0]
                n_spikes_other_cell = vstim_data.get_spike_times_for_cell(other_cell).shape[0]
                if n_spikes_cell > n_spikes_other_cell:
                    print(f'DUPLICATE FOUND: {cell} and {other_cell} with corr {corr}, choosing {other_cell} as duplicate')
                    duplicates.add(other_cell)
                else:
                    print(f'DUPLICATE FOUND: {cell} and {other_cell} with corr {corr}, choosing {cell} as duplicate')
                    duplicates.add(cell)
                    # `cell` is now a duplicate itself; stop so it cannot disqualify other cells.
                    break

    # Also flag surviving cells whose EI error is exactly zero on any channel with
    # non-zero noise (zero-noise channels are excluded; presumably dead electrodes — TODO confirm).
    live_channels = vstim_data.channel_noise != 0
    for cell in set(cellids) - duplicates:
        ei_error = vstim_data.get_ei_for_cell(cell).ei_error[live_channels]
        if np.any(ei_error == 0):
            duplicates.add(cell)

    nonduplicates = set(cellids) - duplicates
    return duplicates, nonduplicates

In [3]:
# TODO: update path, piece number, datarun

# Root of the analysis tree. Set the WNOISE_ANALYSIS_BASE environment variable to point
# at a different mount without editing this cell; falls back to the acquisition volume.
WNOISE_ANALYSIS_BASE = os.environ.get("WNOISE_ANALYSIS_BASE", "/Volumes/Acquisition/Analysis")
dataset = "2023-10-30-0"   # piece
wnoise = "data000"         # datarun, relative to the piece directory

wnoise_path = os.path.join(WNOISE_ANALYSIS_BASE, dataset, wnoise)
# normpath strips a trailing separator (e.g. "data000/") so basename still yields the run name.
vcd = vl.load_vision_data(wnoise_path,
                          os.path.basename(os.path.normpath(wnoise)),
                          include_ei=True,
                          include_neurons=True,
                          include_params=True,
                          include_sta=True,
                          include_noise=True)

In [4]:
# Case-insensitive substrings of the Vision cell types to de-duplicate.
# For monkey data: ['parasol', 'midget', 'crap'].
# NOTE(review): the saved output below reports matches against 'off', so it appears to come
# from a run with a different allowed_types list — re-run this cell to refresh it.
allowed_types = ['parasol', 'midget', 'crap']
duplicates, nonduplicates = compute_duplicates(vcd, allowed_types=allowed_types)

2311 off nc2 off
2311 off nc2 2314 off nc1
2311 off nc2 2371 off nc2
2311 off nc2 227 off nc2
2311 off nc2 241 off nc2
2311 off nc2 2581 off nc2
2311 off nc2 304 off nc1
2311 off nc2 347 off nc1
2311 off nc2 4398 off nc2
2311 off nc2 4441 off nc1
2311 off nc2 4606 off nc2
2311 off nc2 587 off nc1
2311 off nc2 602 off nc2
2311 off nc2 6631 off nc1
2311 off nc2 4757 off nc2
2311 off nc2 4831 off nc2
2311 off nc2 811 off nc2
2311 off nc2 5152 off nc1
2311 off nc2 907 off nc1
2311 off nc2 991 off nc2
2311 off nc2 5236 off nc1
2311 off nc2 1021 off nc2
2311 off nc2 2986 off nc2
2311 off nc2 5268 off nc2
2311 off nc2 3016 off nc1
2311 off nc2 5341 off nc1
2311 off nc2 3122 off nc1
2311 off nc2 1128 off nc1
2311 off nc2 5645 off nc1
2311 off nc2 5656 off nc2
2311 off nc2 5701 off nc1
2311 off nc2 3392 off nc2
2311 off nc2 5776 off nc2
2311 off nc2 3395 off nc1
2311 off nc2 6936 off nc1
2311 off nc2 3451 off nc2
2311 off nc2 6961 off nc1
2311 off nc2 5795 off nc1
2311 off nc2 1306 off nc2
2311

In [5]:
# Number of cells flagged by compute_duplicates.
n_duplicates = len(duplicates)
n_duplicates

0

In [6]:
# Write the de-duplicated classification text file, one "<id>  All/<type>/" line per cell:
# flagged cells go under "All/duplicates/", every other cell keeps its current type.
# Ids are written in ascending order so the file is deterministic and diffable across runs
# (iterating a set does not guarantee sorted order).
classification_path = os.path.join(WNOISE_ANALYSIS_BASE, dataset, wnoise, 'classification_deduped.txt')
with open(classification_path, 'w') as f:
    for cell in sorted(duplicates):
        f.write(f'{cell}  All/duplicates/\n')
    for cell in sorted(nonduplicates):
        f.write(f'{cell}  All/{vcd.get_cell_type_for_cell(cell)}/\n')