Some helper functions I wrote for defining custom spike-sorting SortGroups.

In [None]:
from spyglass.utils.nwb_helper_fn import get_nwb_file
from spyglass.common import Nwbfile
from spyglass.utils import logger
import spyglass.spikesorting.v1 as sgs
from spyglass.spikesorting.v1 import SortGroup


def get_sort_groups_from_column(
    nwb_file_name: str,
    column: str,
    groups: list[list],
    remove_bad_channels: bool = True,
    omit_unitrode: bool = True,
):
    """
    Build custom SortGroup keys from a Berke Lab NWB file's electrode table, grouped by a chosen column.

    Potential columns include 'intan_channel_number' (0-based, with values such as 191, 190, etc.)
    and 'electrode_name' (with values such as 'S01E01', 'S01E02', etc.).

    To match on the electrode_id (the electrode table's index) directly, pass
    "index", "id", "idx", or "electrode_id" as ``column``.

    Parameters
    ----------
    nwb_file_name : str
        Name of the NWB file (looked up via ``Nwbfile.get_abs_path``).
    column : str
        Column in the electrode table to group by (e.g., "intan_channel_number"),
        or one of the index aliases listed above.
    groups : list of lists
        Each sublist specifies values in `column` to include in one sort group.
        The position of a sublist in `groups` becomes its sort_group_id.
    remove_bad_channels : bool
        If True, electrodes with bad_channel != 0 are removed.
    omit_unitrode : bool
        If True, groups left with only one electrode are skipped.

    Returns
    -------
    sg_keys : list of dict
        One dict per sort group (contains nwb_file_name, sort_group_id, sort_reference_electrode_id).
        Skipped groups are not renumbered, so sort_group_id may have gaps.
    sge_keys : list of dict
        One dict per electrode assignment to a sort group.

    Raises
    ------
    KeyError
        If `column` is neither an index alias nor a column of the electrode table,
        or if `remove_bad_channels` is True and the table has no "bad_channel" column.
    """

    # Get electrode table from nwbfile
    nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
    nwbf = get_nwb_file(nwb_file_abspath)
    electrodes_df = nwbf.electrodes.to_dataframe()

    # Option to select directly based on index (electrode_id)
    use_index = column in ("index", "id", "idx", "electrode_id")

    # Fail early with an informative message instead of a bare KeyError inside the loop.
    if not use_index and column not in electrodes_df.columns:
        raise KeyError(
            f"Column {column!r} not found in the electrode table of {nwb_file_name}. "
            f"Available columns: {electrodes_df.columns.tolist()}"
        )
    if remove_bad_channels and "bad_channel" not in electrodes_df.columns:
        raise KeyError(
            f"remove_bad_channels=True but the electrode table of {nwb_file_name} "
            "has no 'bad_channel' column."
        )

    # Values each group is matched against: the index itself or the chosen column.
    match_values = electrodes_df.index if use_index else electrodes_df[column]

    def _labels(df):
        # Identifiers for log messages, expressed in the same terms the caller used.
        return df.index.tolist() if use_index else df[column].tolist()

    sg_keys, sge_keys = [], []

    for i, group_vals in enumerate(groups):
        in_group = match_values.isin(group_vals)
        subset = electrodes_df[in_group]

        # Surface requested values that exist nowhere in the table (e.g. typos).
        found = set(match_values[in_group].tolist())
        missing = [v for v in group_vals if v not in found]
        if missing:
            logger.warning(f"Group {i}: no electrodes found with {column} in {missing}.")

        # Optionally remove bad channels. Any non-zero flag counts as bad, and the
        # same mask drives both the log message and the removal.
        if remove_bad_channels:
            is_bad = subset["bad_channel"] != 0
            if is_bad.any():
                logger.info(
                    f"Removing bad channels from group {i}: "
                    f"{_labels(subset[is_bad])}"
                )
            subset = subset[~is_bad]

        if subset.empty:
            logger.warning(f"Omitting group {i} (all bad channels or no matches).")
            continue

        # Optionally skip unitrodes
        if omit_unitrode and len(subset) == 1:
            logger.warning(f"Omitting group {i} (unitrode).")
            continue

        # Log which electrodes are in this sort group
        logger.info(
            f"Adding group {i}: electrode_ids={subset.index.tolist()}"
            + ("" if use_index else f", {column}={subset[column].tolist()}")
        )

        # Build sort group key; sort_group_id is the group's position in `groups`.
        sg_keys.append(
            dict(
                nwb_file_name=nwb_file_name,
                sort_group_id=i,
                sort_reference_electrode_id=-1,  # we always use -1 for reference electrode
            )
        )

        # One electrode entry per row, using the electrode df index as electrode_id
        sge_keys.extend(
            dict(
                nwb_file_name=nwb_file_name,
                sort_group_id=i,
                electrode_id=eid,
                electrode_group_name=group_name,
            )
            for eid, group_name in zip(subset.index, subset["group_name"])
        )

    return sg_keys, sge_keys


def set_sort_group_by_column(
    nwb_file_name: str,
    column: str,
    groups: list[list],
    remove_bad_channels: bool = True,
    omit_unitrode: bool = True,
):
    """Replace an NWB file's SortGroups with groups chosen from a column of its electrodes table.

    The new keys are computed before anything is deleted. An invalid `column`,
    or any other failure while reading the electrode table, therefore leaves the
    existing SortGroup entries for this file untouched.

    Parameters
    ----------
    nwb_file_name : str
        Name of the NWB file.
    column : str
        Column in the electrode table to group by (e.g., "intan_channel_number").
    groups : list of lists
        Each sublist specifies values in `column` to include in one sort group.
    remove_bad_channels : bool
        If True, electrodes with bad_channel != 0 are removed.
    omit_unitrode : bool
        If True, groups with only one electrode are skipped.
    """
    # Build the new keys first so a failure here cannot destroy the existing groups.
    sg_keys, sge_keys = get_sort_groups_from_column(
        nwb_file_name=nwb_file_name,
        column=column,
        groups=groups,
        remove_bad_channels=remove_bad_channels,
        omit_unitrode=omit_unitrode,
    )
    if not sg_keys:
        logger.warning(
            f"No valid sort groups produced for {nwb_file_name}; "
            "existing SortGroup entries were left unchanged."
        )
        return

    existing_entries = SortGroup & {"nwb_file_name": nwb_file_name}
    if existing_entries:
        # NOTE: in DataJoint, delete() also removes dependent downstream entries
        # and, with safemode on, asks for interactive confirmation.
        existing_entries.delete()
        # The restriction is re-evaluated here. If the delete was declined, stop:
        # with skip_duplicates=True the inserts below would otherwise merge the
        # new electrodes into the old groups.
        if existing_entries:
            logger.warning(
                f"Existing SortGroup entries for {nwb_file_name} were not deleted; "
                "new sort groups were not inserted."
            )
            return

    # Insert master and part rows atomically so a failure cannot leave sort
    # groups without their electrodes.
    with SortGroup().connection.transaction:
        SortGroup.insert(sg_keys, skip_duplicates=True)
        SortGroup.SortGroupElectrode().insert(sge_keys, skip_duplicates=True)


# Session whose sort groups are being (re)defined.
nwb_file_name = "IM-1875_darling_20250720_.nwb"

# One inner list per sort group, given as Intan channel numbers.
# Under the default omit_unitrode=True, the single-channel group [10] is skipped.
channel_groups = [
    [70, 69, 68, 67, 66, 65],
    [82, 81],
    [94, 93, 92, 91, 90, 89, 88, 87, 86],
    [136, 135, 134],
    [10],
]

set_sort_group_by_column(
    nwb_file_name=nwb_file_name,
    column="intan_channel_number",
    groups=channel_groups,
)