In [1]:
import numpy as np
import visionloader as vl
import os

In [2]:
def compute_duplicates(vstim_data, allowed_types, MIN_CORR=0.95):
    duplicates = set()
    cellids = vstim_data.get_cell_ids()
    for cell in cellids:
        if cell in duplicates:
            continue

        cell_ei = vstim_data.get_ei_for_cell(cell).ei
        cell_ei_error = vstim_data.get_ei_for_cell(cell).ei_error
        cell_ei_power = np.sum(cell_ei**2,axis=1)
        celltype = vstim_data.get_cell_type_for_cell(cell).lower()

        allowed_cell = False
        for allowed_type in allowed_types:
            if allowed_type.lower() in celltype:
                allowed_cell = True
                break
                # after breaking, allowed_type will be set to the type that was found

        if allowed_cell:
            print(cell, celltype, allowed_type)
            for other_cell in cellids:
                if cell == other_cell or other_cell in duplicates:
                    continue
                other_celltype = vstim_data.get_cell_type_for_cell(other_cell).lower()
                allowed_othercell = False
                for allowed_othertype in allowed_types:
                    if allowed_othertype.lower() in other_celltype:
                        allowed_othercell = True
                        break
                
                if allowed_type == allowed_othertype and allowed_othercell:
                    print(cell, celltype, other_cell, other_celltype)   
                    other_cell_ei = vstim_data.get_ei_for_cell(other_cell).ei
                    other_cell_ei_power = np.sum(other_cell_ei**2,axis=1)
                    # Compute the correlation and figure out if we have duplicates: take the larger number of spikes.
                    corr = np.corrcoef(cell_ei_power,other_cell_ei_power)[0,1]
                    if corr >= MIN_CORR:
                        
                        n_spikes_cell = vstim_data.get_spike_times_for_cell(cell).shape[0]
                        n_spikes_other_cell = vstim_data.get_spike_times_for_cell(other_cell).shape[0]
                        # Take the larger number of spikes, unless the one with fewer is a light responsive type.
                        if n_spikes_cell > n_spikes_other_cell:
                            print(f'DUPLICATE FOUND: {cell} and {other_cell} with corr {corr}, choosing {other_cell} as duplicate')
                            duplicates.add(other_cell)
                        else:
                            print(f'DUPLICATE FOUND: {cell} and {other_cell} with corr {corr}, choosing {cell} as duplicate')
                            duplicates.add(cell)

    for cell in set(cellids).difference(duplicates):
        cell_ei_error = vstim_data.get_ei_for_cell(cell).ei_error[vstim_data.channel_noise != 0]
        
        if np.any(cell_ei_error == 0):
            duplicates.add(cell)     

    return duplicates, set(cellids).difference(duplicates)

In [3]:
WNOISE_ANALYSIS_BASE = "/Volumes/Analysis"
dataset = "2023-09-19-0"
wnoise = "kilosort_data000/data000"

vcd = vl.load_vision_data(os.path.join(WNOISE_ANALYSIS_BASE, dataset, wnoise),
                          os.path.basename(wnoise),
                          include_ei=True,
                          include_neurons=True,
                          include_params=True,
                          include_sta=True,
                          include_noise=True)

In [4]:
allowed_types = ['on', 'off', 'weak']
duplicates, nonduplicates = compute_duplicates(vcd, allowed_types)

199 off off
199 off 548 off
199 off 379 off
199 off 549 off
199 off 551 off
199 off 201 off
199 off 555 off
199 off 4 off
199 off 203 off
199 off 557 off
199 off 558 off
199 off 205 off
199 off 390 off
199 off 391 off
199 off 392 off
199 off 561 off
199 off 562 off
199 off 209 off
199 off 210 off
199 off 15 off
199 off 564 off
199 off 398 off
199 off 566 off
199 off 401 off
199 off 402 off
199 off 20 off
199 off 571 off
199 off 217 off
199 off 22 off
199 off 576 off
199 off 577 off
199 off 220 off
199 off 578 off
199 off 29 off
199 off 33 off
199 off 37 off
199 off 412 off
199 off 582 off
199 off 40 off
199 off 415 off
199 off 49 off
199 off 585 off
199 off 586 off
199 off 588 off
199 off 53 off
199 off 54 off
199 off 590 off
199 off 591 off
199 off 593 off
199 off 426 off
199 off 232 off
199 off 427 off
199 off 429 off
199 off 56 off
199 off 431 off
199 off 598 off
199 off 600 off
199 off 57 off
199 off 58 off
199 off 434 off
199 off 61 off
199 off 63 off
199 off 242 off
199 off 66 of

In [5]:
with open(os.path.join(WNOISE_ANALYSIS_BASE, dataset, wnoise, f'{os.path.basename(wnoise)}.classification_deduped.txt'), 'w') as f:
    for cell in duplicates:
        f.write(f'{cell}  All/duplicates/\n')
    for cell in nonduplicates:
        f.write(f'{cell}  All/{vcd.get_cell_type_for_cell(cell)}/\n')