In [86]:
import numpy as np

In [None]:
def sub_nanpad(x):
    """Pack a list of 1-D sequences into a NaN-padded 2-D float array.

    Row ``i`` holds ``x[i]`` (flattened); rows shorter than the longest one are
    right-padded with NaN, so ragged per-cell tracks become one rectangular
    ``(len(x), max_len)`` array. A scalar element is treated as a length-1 row.

    A flat list of scalars (or an empty list) is returned as a 1-D float array,
    exactly as before.

    Parameters
    ----------
    x : list
        Either a list of scalars, or a list whose elements are sequences/arrays.

    Returns
    -------
    numpy.ndarray
        Float array: 2-D and NaN-padded for nested input, 1-D for flat input.
    """
    # Flat input has no rows to pad: keep the original 1-D conversion.
    if all(np.ndim(item) == 0 for item in x):
        return np.asarray(x, dtype=float).ravel()

    rows = [np.ravel(np.asarray(item, dtype=float)) for item in x]
    width = max(row.size for row in rows)
    padded = np.full((len(rows), width), np.nan)
    for i, row in enumerate(rows):
        padded[i, :row.size] = row
    return padded

In [1]:
def ct_trimdata(in_data, param=None):
    """Trim, gap-check and gap-fill per-cell time series.

    A sample counts as good only if it is non-NaN in every channel; a gap is a
    run of non-good samples inside a cell's data range. Gaps no longer than
    ``gapmax`` are bridged by linear interpolation (in every channel, which
    also overwrites values present in only some channels); longer gaps cut the
    range.

    Parameters
    ----------
    in_data : dict, ndarray, or sequence of channels
        Each channel is an array indexed ``[cell, time]`` in which NaN marks a
        missing sample (a 1-D channel is treated as a single cell). A dict
        contributes every value that is an ndarray or list, in iteration order;
        other values are ignored. List channels are packed with ``sub_nanpad``.
        All channels must have the same shape. The caller's arrays are never
        modified.
    param : dict, optional
        Option overrides; keys are matched case-insensitively:

        ``gapmax`` (10)
            Longest gap, in samples, that is bridged by interpolation.
        ``dlengthmin`` (2)
            Cells whose final range has fewer samples than this are dropped.
        ``startonly`` (True)
            Ranges start at t=0: a leading gap is back-filled with the first
            good sample (or the cell is dropped if that gap exceeds
            ``gapmax``), and the range is cut at the first over-long gap.
            When False, ranges start at the first good sample and only the
            longest run between over-long gaps is kept.
        ``packascell`` (False)
            Return each kept range as its own 1-D array instead of a
            NaN-blanked 2-D array.

    Returns
    -------
    dict
        ``"data"``: one entry per channel -- a 2-D array of the kept cells with
        samples outside the kept range set to NaN, or (``packascell``) a list
        of 1-D arrays, one per kept cell.
        ``"cellindex"``: row indices (into the input) of the kept cells.
        ``"time"`` (``packascell`` only): ``(n_kept, 2)`` array of inclusive
        ``[start, end]`` time indices.

    Raises
    ------
    ValueError
        If no channels are found or the channels' shapes differ.
    """
    channels = _ct_channels(in_data)
    p = _ct_options(param)

    # A sample is good only when it is non-NaN in every channel.
    goodi = np.all(np.stack([~np.isnan(ch) for ch in channels], axis=-1), axis=-1)
    drng = _ct_initial_range(goodi, p['startonly'])
    for s in range(goodi.shape[0]):
        _ct_trim_row(channels, goodi, drng, s, p)
    return _ct_pack(channels, drng, p['packascell'])


def _ct_channels(in_data):
    """Normalise ``in_data`` into a non-empty list of private 2-D float arrays."""
    if isinstance(in_data, dict):
        # Only array/list entries are data channels; other entries are ignored.
        in_data = [v for v in in_data.values() if isinstance(v, (np.ndarray, list))]
    elif isinstance(in_data, np.ndarray):
        in_data = [in_data]
    # np.array(..., dtype=float) copies, so the caller's arrays are never
    # written to, and NaN can be stored even when the input was integer.
    channels = [
        np.atleast_2d(np.array(sub_nanpad(ch) if isinstance(ch, list) else ch, dtype=float))
        for ch in in_data
    ]
    if not channels:
        raise ValueError("ct_trimdata: input contains no data channels")
    return channels


def _ct_options(param):
    """Return the defaults overridden by ``param`` (keys lower-cased)."""
    opts = {'gapmax': 10, 'dlengthmin': 2, 'startonly': True, 'packascell': False}
    if param:
        opts.update({k.lower(): v for k, v in param.items()})
    return opts


def _ct_initial_range(goodi, startonly):
    """Mark each cell's initial range: (t=0 or first good sample) .. last good sample."""
    drng = np.zeros_like(goodi, dtype=bool)
    for s, row in enumerate(goodi):
        good_t = np.flatnonzero(row)
        if good_t.size:
            first = 0 if startonly else good_t[0]
            drng[s, first:good_t[-1] + 1] = True
    return drng


def _ct_gap_runs(gaps):
    """Return (starts, ends) index arrays of the maximal True runs in a 1-D bool array."""
    prev = np.concatenate(([False], gaps[:-1]))
    nxt = np.concatenate((gaps[1:], [False]))
    return np.flatnonzero(gaps & ~prev), np.flatnonzero(gaps & ~nxt)


def _ct_trim_row(channels, goodi, drng, s, p):
    """Apply the gap rules to cell ``s``, editing ``drng[s]`` and ``channels`` in place."""
    gaps = ~goodi[s] & drng[s]
    gst = gnd = np.empty(0, dtype=np.intp)
    if gaps.any():
        gst, gnd = _ct_gap_runs(gaps)

        # A gap starting at t=0 can only occur with startonly (otherwise the
        # range starts on a good sample).
        if p['startonly'] and gst[0] == 0:
            first_good = gnd[0] + 1  # also the leading gap's length
            # Same length test as for interior gaps below.
            if first_good > p['gapmax']:
                drng[s] = False
                return
            for ch in channels:
                ch[s, :first_good] = ch[s, first_good]
            gst, gnd = gst[1:], gnd[1:]

        big = (gnd - gst + 1) > p['gapmax']
        if big.any():
            in_range = np.flatnonzero(drng[s])
            # Segment k lies strictly between seg_lo[k] and seg_hi[k]: segments
            # are separated by big gaps and bounded by the current range edges.
            seg_lo = np.concatenate(([in_range[0] - 1], gnd[big]))
            seg_hi = np.concatenate((gst[big], [in_range[-1] + 1]))
            if p['startonly']:
                drng[s, seg_hi[0]:] = False  # keep only data before the first big gap
            else:
                k = np.argmax(seg_hi - seg_lo)  # keep the longest segment
                drng[s] = False
                drng[s, seg_lo[k] + 1:seg_hi[k]] = True
            # Keep only the gaps that are still inside the trimmed range.
            gst = gst[drng[s, gst]]
            gnd = gnd[drng[s, gnd]]

    if np.count_nonzero(drng[s]) < p['dlengthmin']:
        drng[s] = False
        return

    # Bridge each remaining gap linearly between its two good neighbours.
    for a, b in zip(gst, gnd):
        for ch in channels:
            ch[s, a - 1:b + 2] = np.linspace(ch[s, a - 1], ch[s, b + 1], b - a + 3)


def _ct_pack(channels, drng, packascell):
    """Build the ct_trimdata output dict from the processed channels and ranges."""
    if packascell:
        pad = np.zeros((drng.shape[0], 1), dtype=bool)
        # Range starts: t in range and t-1 not; range ends: t in range and t+1 not.
        st_row, st_t = np.nonzero(drng & ~np.hstack((pad, drng[:, :-1])))
        nd_row, nd_t = np.nonzero(drng & ~np.hstack((drng[:, 1:], pad)))
        # Order by cell (row); a stable sort keeps time order within a row.
        sti = np.argsort(st_row, kind="stable")
        ndi = np.argsort(nd_row, kind="stable")
        hasdata = st_row[sti]
        timei = np.column_stack((st_t[sti], nd_t[ndi]))
        data = [[ch[r, t0:t1 + 1] for r, (t0, t1) in zip(hasdata, timei)] for ch in channels]
        return {"data": data, "cellindex": hasdata, "time": timei}

    hasdata = drng.any(axis=1)
    data = []
    for ch in channels:
        ch[~drng] = np.nan  # blank everything outside the kept range
        data.append(ch[hasdata, :])
    return {"data": data, "cellindex": np.flatnonzero(hasdata)}

# Usage: trim / gap-fill the example data with the given options.
# NOTE(review): `pd` (input data) and `p` (option dict) are not defined in any
# visible cell, so this relies on hidden kernel state and fails on a fresh
# kernel. `pd` also shadows the conventional pandas alias -- consider renaming.
out = ct_trimdata(in_data=pd, param=p)
out

# Raw: ct_trimdata's body unrolled at top level so that its intermediates
# (goodi, drng, gst, gnd, ...) can be inspected after the cell runs.
# NOTE(review): this is a hand-maintained duplicate of ct_trimdata -- keep the
# two in sync, or delete this cell once the function is trusted. It needs `pd`
# (input data) and `p` (options) from an earlier cell, and rebinds `p` to the
# merged option dict.
in_data = pd
param = p
# Manage input type: dict or numpy array
if isinstance(in_data, dict):
    # Non-array entries are collected here but not used further.
    inpass = {k: v for k, v in in_data.items() if not (isinstance(v, (np.ndarray, list)))}
    in_data = [v for k, v in in_data.items() if isinstance(v, (np.ndarray, list))]
if isinstance(in_data, np.ndarray):
    in_data = [in_data]
nchan = len(in_data)

# Manage operational parameters (keys of `param` are matched case-insensitively)
p = {
    'gapmax': 10,
    'dlengthmin': 2,
    'startonly': True,
    'packascell': False
}
if param:
    for k, v in param.items():
        p[k.lower()] = v

# Reconstitute as NaN padded 2-D float arrays. np.array(..., dtype=float)
# copies, so the gap filling and NaN blanking below no longer write into `pd`
# itself (which made re-running this cell or the Usage cell give new results).
in_data = [
    np.atleast_2d(np.array(sub_nanpad(ch) if isinstance(ch, list) else ch, dtype=float))
    for ch in in_data
]

# Define indices of 'good' data (non-NaN)
goodi = [~np.isnan(ch) for ch in in_data]

# Require good indices in all channels
goodi = np.all(np.stack(goodi, axis=-1), axis=-1)
nCells, nTime = goodi.shape

# Indicate data range: dst / dnd mark the first / last sample of each good run
drng = np.zeros((nCells, nTime), dtype=bool)
dst = np.logical_and(goodi, ~np.hstack([np.zeros((nCells, 1), dtype=bool), goodi[:, :-1]]))
dnd = np.logical_and(goodi, ~np.hstack([goodi[:, 1:], np.zeros((nCells, 1), dtype=bool)]))

for s in range(nCells):
    dnd_indices = np.where(dnd[s, :])[0]
    dst_indices = np.where(dst[s, :])[0]

    if p['startonly']:
        if dnd_indices.size != 0:
            # Range runs from t=0 to the last 'end'
            drng[s, :dnd_indices[-1] + 1] = True
    else:
        if dnd_indices.size != 0 and dst_indices.size != 0:
            # Range runs from the first 'start' to the last 'end'
            drng[s, dst_indices[0]:dnd_indices[-1] + 1] = True
########################################
gst = [None] * nCells
gnd = [None] * nCells
for s in range(nCells):
    gaps = np.logical_and(~goodi[s, :], drng[s, :])
    if np.any(gaps):
        # Find indices marking beginning and end of each gap. A gap starting at
        # t=0 gets no entry in gst (the padding is False), so it shows up as an
        # extra entry in gnd.
        gst[s] = np.where(gaps & np.hstack((False, ~gaps[:-1])))[0]
        gnd[s] = np.where(gaps & np.hstack((~gaps[1:], False)))[0]
        # Check for starting gaps, if startonly
        if p['startonly'] and len(gnd[s]) > len(gst[s]):
            # The leading gap covers t=0..gnd[s][0], i.e. gnd[s][0] + 1
            # samples; compare its length like the interior gaps (gapsz below).
            if gnd[s][0] + 1 > p['gapmax']:
                drng[s, :] = False
                continue
            else:
                drng[s, :gnd[s][0] + 1] = True
                # Back-fill the leading gap with the first good sample
                for sc in range(nchan):
                    in_data[sc][s, :gnd[s][0] + 1] = in_data[sc][s, gnd[s][0] + 1]
                gnd[s] = gnd[s][1:]

        # Proceed with gap evaluation
        gapsz = gnd[s] - gst[s] + 1
        if np.any(gapsz > p['gapmax']):
            biggap = gapsz > p['gapmax']
            # Segments between big gaps run from bgnd[i] + 1 to bgst[i] - 1
            bgnd = np.hstack((np.where(drng[s, :])[0][0] - 1, gnd[s][biggap]))
            bgst = np.hstack((gst[s][biggap], np.where(drng[s, :])[0][-1] + 1))

            if p['startonly']:
                # Keep only the data before the first big gap
                drng[s, bgst[0]:] = False
            else:
                # Keep only the longest segment
                mx_idx = np.argmax(bgst - bgnd)
                drng[s, :] = False
                drng[s, bgnd[mx_idx] + 1:bgst[mx_idx]] = True

            gst[s] = [g for g in gst[s] if drng[s, g]]
            gnd[s] = [g for g in gnd[s] if drng[s, g]]

        # Enforce minimum data range
        if np.count_nonzero(drng[s, :]) < p['dlengthmin']:
            drng[s, :] = False
            continue

        # Fill gaps with linear interpolation between the neighbouring good samples
        for ss in range(min(len(gnd[s]), len(gst[s]))):
            for sc in range(nchan):
                in_data[sc][s, gst[s][ss] - 1:gnd[s][ss] + 2] = \
                    np.linspace(in_data[sc][s, gst[s][ss] - 1],
                                in_data[sc][s, gnd[s][ss] + 1],
                                gnd[s][ss] - gst[s][ss] + 3)
    elif np.count_nonzero(drng[s, :]) < p['dlengthmin']:
        drng[s, :] = False
        continue
##########################################
out = {"data": None, "cellindex": None}

if p["packascell"]:
    # Find start points in each track (n is TRUE & n-1 is FALSE)
    stord, dst = np.where(drng & ~np.hstack((np.zeros((nCells, 1), dtype=bool), drng[:, :-1])))
    # Find end points in each track (n is TRUE & n+1 is FALSE)
    ndord, dnd = np.where(drng & ~np.hstack((drng[:, 1:], np.zeros((nCells, 1), dtype=bool))))

    # Sort tracks by original row number. np.sort returns only the sorted
    # values, so take the permutation from a stable argsort and apply it.
    sti = np.argsort(stord, kind="stable")
    ndi = np.argsort(ndord, kind="stable")
    hasdata = stord[sti]

    # Apply sorting to time starts and ends
    timei = np.array([dst[sti], dnd[ndi]]).T

    # Extract target data (excluding data to be trimmed); index each row
    # directly instead of re-copying all kept rows for every track.
    in_data = [
        [x[hasdata[i], timei[i, 0]:timei[i, 1] + 1] for i in range(len(hasdata))]
        for x in in_data
    ]

    out["time"] = timei
    # hasdata already holds the kept row indices (it is not a mask here)
    out["cellindex"] = hasdata
else:
    hasdata = np.any(drng, axis=1)
    out["cellindex"] = np.where(hasdata)[0]

    for sc in range(len(in_data)):
        in_data[sc][~drng] = np.nan
        in_data[sc] = in_data[sc][hasdata, :]

out["data"] = in_data
