In [None]:
import os
import h5py
import pandas as pd
import numpy as np
from scipy.interpolate import interp1d

# Folder scanned for ``.h5`` files to convert.
input_dir = r"S:\Lab_Member\Tobi\Experiments\Exp9_Social-Stress\Raw Data\Behavior\B6\SocP\SLEAP\geom"
# Folder that receives the generated ``*_locs.csv`` files (currently the same
# folder as the inputs).
output_dir = r"S:\Lab_Member\Tobi\Experiments\Exp9_Social-Stress\Raw Data\Behavior\B6\SocP\SLEAP\geom"

# Node index in the tracks array -> prefix used for that node's CSV columns.
# Insertion order (0..5) is the order in which the columns are written.
geom_indices = dict(enumerate(["tl", "tr", "br", "bl", "socl", "socr"]))

def fill_missing(Y, kind="linear"):
    """Fill NaNs independently along axis 0 for every slice after the first dim.

    The array is viewed as ``(Y.shape[0], -1)`` and each column (one 1-D
    series along axis 0) is processed on its own:

    * a NaN in the first or last position is replaced with the median of that
      series' non-NaN values, so every remaining gap is bracketed by known
      samples;
    * the remaining NaNs are filled with ``scipy.interpolate.interp1d`` of the
      given ``kind`` built from the known samples.

    Series that are entirely NaN are left unchanged.

    Args:
        Y: NumPy array whose first axis is the axis to interpolate along.
        kind: Interpolation kind forwarded to ``interp1d`` (default
            ``"linear"``).

    Returns:
        A new array with the same shape and dtype as ``Y`` with NaNs filled
        where possible. ``Y`` itself is not modified. If ``len(Y) == 0`` the
        input is returned unchanged.
    """
    if len(Y) == 0:
        return Y

    initial_shape = Y.shape
    # np.array copies, so the reshape is a view of our private copy. A bare
    # Y.reshape can return a view of the caller's array, which the writes
    # below would then silently modify.
    flat = np.array(Y).reshape((initial_shape[0], -1))

    for i in range(flat.shape[1]):
        y = flat[:, i]  # view into `flat`: assignments to y write through
        known = ~np.isnan(y)
        if not known.any():
            # Nothing to anchor on; nanmedian would only warn and return NaN.
            continue

        # Anchor both ends with the series median so every interior NaN lies
        # between two known points (the interpolator is not asked to
        # extrapolate).
        if not known[0] or not known[-1]:
            median = np.nanmedian(y)
            if not known[0]:
                y[0] = median
            if not known[-1]:
                y[-1] = median

        known_idx = np.flatnonzero(~np.isnan(y))
        if known_idx.size > 1:
            f = interp1d(
                known_idx,
                y[known_idx],
                kind=kind,
                fill_value=np.nan,
                bounds_error=False,
            )
            missing_idx = np.flatnonzero(np.isnan(y))
            y[missing_idx] = f(missing_idx)

    return flat.reshape(initial_shape)

# Convert every ``.h5`` file in ``input_dir`` into ``<name>_locs.csv`` in
# ``output_dir``, with one ``<geom>_x`` / ``<geom>_y`` column pair per entry of
# ``geom_indices`` (in that dict's order).
os.makedirs(output_dir, exist_ok=True)  # no-op when the folder already exists

# sorted() makes the processing order deterministic across runs/filesystems.
for filename in sorted(os.listdir(input_dir)):
    if not filename.endswith(".h5"):
        continue

    input_path = os.path.join(input_dir, filename)
    output_filename = os.path.splitext(filename)[0] + "_locs.csv"
    output_path = os.path.join(output_dir, output_filename)

    # Load the full "tracks" dataset and transpose it. The indexing below
    # requires a 4-D result; NOTE(review): it presumably is laid out as
    # (frame, node, coord, track) -- confirm against the exporter.
    with h5py.File(input_path, "r") as f:
        locations = f["tracks"][:].T

    # Interpolate gaps along axis 0, independently for every other index.
    locations = fill_missing(locations)

    # Collect all columns first, then build the DataFrame in a single call
    # rather than inserting columns one by one into an empty frame.
    columns = {}
    for index, geom in geom_indices.items():
        node = locations[:, index, :, :]
        # ravel flattens in C order: if the last axis has length > 1
        # (presumably several tracks) its values are interleaved row by row.
        columns[f"{geom}_x"] = node[:, 0].ravel()
        columns[f"{geom}_y"] = node[:, 1].ravel()

    pd.DataFrame(columns).to_csv(output_path, index=False)

    print(f"Processed file {input_path}")