In [1]:
import numpy as np
import visionloader as vl
import os

In [2]:
def compute_duplicates(vstim_data, allowed_types, MIN_CORR=0.95, verbose=True):
    """Flag likely duplicate cells by correlating their EI power profiles.

    Two cells are compared only when both match the *same* entry of
    ``allowed_types``. A cell matches an entry when the lower-cased entry is a
    substring of the cell's lower-cased type label; the first matching entry
    (in ``allowed_types`` order) wins. For every compared pair the EI of each
    cell is reduced to a power vector (sum of squared EI values along axis 1)
    and the correlation coefficient of the two vectors is computed with
    ``np.corrcoef``. When it is at least ``MIN_CORR`` the cell with fewer
    spikes is flagged as the duplicate; on a tie the cell of the outer loop
    is flagged.

    Parameters
    ----------
    vstim_data : object
        Data container providing ``get_cell_ids()``,
        ``get_cell_type_for_cell(id)``, ``get_ei_for_cell(id).ei`` and
        ``get_spike_times_for_cell(id)``.
        NOTE(review): presumably the object returned by
        ``visionloader.load_vision_data`` -- confirm against the caller.
    allowed_types : iterable of str
        Type-name fragments that make a cell eligible for comparison.
        NOTE(review): matching is by substring, so a short entry such as
        ``'on'`` also matches any label that merely contains those letters
        (e.g. ``'contaminated'``); verify against the labels actually used.
    MIN_CORR : float, optional
        Minimum correlation (inclusive) for two cells to count as duplicates.
    verbose : bool, optional
        When True (the default, matching the previous behaviour) progress and
        duplicate messages are printed. Pass False to silence all output.

    Returns
    -------
    tuple of (set, set)
        ``(duplicates, nonduplicates)``: the ids flagged as duplicates and
        every other id returned by ``get_cell_ids()``.
    """
    cellids = list(vstim_data.get_cell_ids())
    # Keep the original entry next to its lower-cased form: the entry itself is
    # what gets compared between cells and printed, the lower-cased form is
    # what gets searched for in the label.
    type_filters = [(entry, entry.lower()) for entry in allowed_types]

    def _match_type(label):
        """Return the first allowed entry contained in ``label``, else None."""
        for entry, needle in type_filters:
            if needle in label:
                return entry
        return None

    # The type label depends on a single cell only, so resolve it once per cell
    # instead of once per pair.
    labels = {}
    matched = {}
    for cell in cellids:
        labels[cell] = vstim_data.get_cell_type_for_cell(cell).lower()
        matched[cell] = _match_type(labels[cell])

    power_cache = {}

    def _ei_power(cell):
        """Return the (cached) sum of squared EI values along axis 1."""
        if cell not in power_cache:
            power_cache[cell] = np.sum(vstim_data.get_ei_for_cell(cell).ei ** 2, axis=1)
        return power_cache[cell]

    spike_count_cache = {}

    def _spike_count(cell):
        """Return the (cached) number of spike times stored for ``cell``."""
        if cell not in spike_count_cache:
            spike_count_cache[cell] = vstim_data.get_spike_times_for_cell(cell).shape[0]
        return spike_count_cache[cell]

    duplicates = set()
    for cell in cellids:
        if cell in duplicates:
            continue

        cell_match = matched[cell]
        if cell_match is None:
            # Cell type is not among the allowed types: never compared.
            continue

        if verbose:
            print(cell, labels[cell], cell_match)

        # Note: the outer cell keeps being compared against the remaining cells
        # even after it has itself been flagged inside this loop.
        for other_cell in cellids:
            if cell == other_cell or other_cell in duplicates:
                continue
            if matched[other_cell] != cell_match:
                # Either not an allowed type or matched to a different entry.
                continue

            if verbose:
                print(cell, labels[cell], other_cell, labels[other_cell])

            corr = np.corrcoef(_ei_power(cell), _ei_power(other_cell))[0, 1]
            # A NaN correlation (e.g. a constant power vector) compares False
            # here, so such pairs are never flagged.
            if corr >= MIN_CORR:
                # Keep the cell with more spikes; on a tie the outer cell is
                # the one flagged as duplicate.
                if _spike_count(cell) > _spike_count(other_cell):
                    if verbose:
                        print(f'DUPLICATE FOUND: {cell} and {other_cell} with corr {corr}, choosing {other_cell} as duplicate')
                    duplicates.add(other_cell)
                else:
                    if verbose:
                        print(f'DUPLICATE FOUND: {cell} and {other_cell} with corr {corr}, choosing {cell} as duplicate')
                    duplicates.add(cell)

    # NOTE: an earlier, disabled step that additionally flagged cells whose EI
    # error is zero on channels with non-zero ``channel_noise`` is not applied.

    return duplicates, set(cellids).difference(duplicates)

In [3]:
# Dataset selection: base analysis directory, dataset name, and the
# white-noise run inside that dataset.
WNOISE_ANALYSIS_BASE = "/Volumes/Acquisition/Analysis"
dataset = "2023-10-06-0"
wnoise = "data001"

# Load the run from <base>/<dataset>/<run>, using the run's base name as the
# dataset name argument, and request every optional component (parameters,
# neurons, EIs, STAs, noise).
vcd = vl.load_vision_data(
    os.path.join(WNOISE_ANALYSIS_BASE, dataset, wnoise),
    os.path.basename(wnoise),
    include_params=True,
    include_neurons=True,
    include_ei=True,
    include_sta=True,
    include_noise=True,
)

In [4]:
# Only cells whose lower-cased type label contains one of these substrings are
# considered when searching for duplicates.
allowed_types = ['on', 'off', 'weak']
# Returns (ids flagged as duplicates, all remaining ids); MIN_CORR keeps its
# default of 0.95.
duplicates, nonduplicates = compute_duplicates(vcd, allowed_types)

3786 weak collapsed weak
3786 weak collapsed 18 weak
3786 weak collapsed 1727 weak
3786 weak collapsed 3812 weak collapsed
3786 weak collapsed 1741 weak
3786 weak collapsed 6052 weak collapsed
3786 weak collapsed 31 weak
3786 weak collapsed 1756 weak
3786 weak collapsed 3826 weak collapsed
3786 weak collapsed 78 weak
3786 weak collapsed 1760 weak collapsed
3786 weak collapsed 1761 weak collapsed
3786 weak collapsed 3858 weak
3786 weak collapsed 6079 weak
3786 weak collapsed 1763 weak collapsed
3786 weak collapsed 136 weak
3786 weak collapsed 1773 weak collapsed
3786 weak collapsed 3862 weak collapsed
3786 weak collapsed 187 weak
3786 weak collapsed 6093 weak
3786 weak collapsed 1774 weak collapsed
3786 weak collapsed 241 weak
3786 weak collapsed 1786 weak
3786 weak collapsed 3931 weak collapsed
3786 weak collapsed 316 weak
3786 weak collapsed 3934 weak
3786 weak collapsed 6121 weak collapsed
3786 weak collapsed 1806 weak collapsed
3786 weak collapsed 6124 weak collapsed
3786 weak colla

In [7]:
# Number of cells flagged as duplicates.
len(duplicates)

71

In [None]:
# Write one "<cell id>  All/<class>/" line per cell next to the analysis
# files: flagged cells go to the "duplicates" class, every other cell keeps
# the type reported by the loaded dataset.
deduped_path = os.path.join(
    WNOISE_ANALYSIS_BASE,
    dataset,
    wnoise,
    f'{os.path.basename(wnoise)}.classification_deduped.txt',
)
with open(deduped_path, 'w') as f:
    f.writelines(f'{cell}  All/duplicates/\n' for cell in duplicates)
    f.writelines(
        f'{cell}  All/{vcd.get_cell_type_for_cell(cell)}/\n'
        for cell in nonduplicates
    )