In [57]:
import xarray as xr
import numpy as np
from pathlib import Path
from shared.utilities import pad_to_max_sample_length

In [62]:
# Input: per-sample stage labels and data; output: the labelled segments split out of them.
data_path = Path("data/sat1") / "stage_data.nc"
output_path = Path("data/sat1") / "split_stage_data.nc"

# xr.load_dataset reads the whole file into memory up front.
stage_data = xr.load_dataset(data_path)

In [59]:
# The label-change detection below compares two shifted slices positionally.
# xarray would instead align the DataArray operands along their shared dimensions
# (or reject the size mismatch), so the comparison is done on a plain NumPy array.
label_data = stage_data["labels"].to_numpy()

In [60]:
# Split every trial (participant, epoch) into one segment per contiguous run of a
# single label. A run is only taken when it lies between two label changes of the
# SAME trial. Runs that touch the first or last sample of an epoch are therefore
# never taken.
# NOTE(review): confirm stages never run up to an epoch edge.
# Each kept segment is the "data" variable with length-1 participant/epochs/labels
# dimensions and a samples coordinate renumbered from 0.
segments = []

# (participant, epoch, sample) indices where a label differs from the previous
# sample. The +1 makes the sample index point at the first sample of the new label.
changes = np.array(np.where(label_data[:, :, :-1] != label_data[:, :, 1:]))
changes[2] += 1

# np.where returns indices in row-major order, so the changes of one trial are
# consecutive and their sample indices ascend.
previous_trial = None  # (participant, epoch) of the previous change
previous_change = None  # sample index of the previous change
for participant, epoch, change in zip(changes[0], changes[1], changes[2]):
    # Only pair changes from the same trial. Comparing sample indices alone
    # (as before) let a "segment" start in one epoch and end in the next.
    if (participant, epoch) == previous_trial:
        segment = stage_data.isel(
            participant=[participant],  # List to retain dimension in segment
            epochs=[epoch],  # List to retain dimension in segment
            samples=slice(previous_change, change),
        )
        # Skip runs that consist only of empty-string (unlabelled) samples
        if np.any(segment.labels != ""):
            label = segment.labels[0, 0, 0].item()
            segment = (
                segment["data"]
                .expand_dims({"labels": 1}, axis=2)
                .assign_coords(labels=[label])
            )
            # Renumber samples so every segment starts at 0
            segment["samples"] = np.arange(0, len(segment["samples"]))
            segments.append(segment)
    previous_trial = (participant, epoch)
    previous_change = change

In [63]:
# Merge the single-segment DataArrays into one Dataset. Each segment carries
# length-1 participant/epochs/labels coordinates and its own 0-based samples
# coordinate, and combine_by_coords places the segments by those coordinate values.
# NOTE(review): two segments from the same trial with the same label would share
# coordinates. Confirm this cannot happen, or that the combine rejects it.
if not segments:
    # combine_by_coords([]) yields an empty Dataset, which would be written silently
    raise ValueError(f"No labelled segments found; nothing to write to {output_path}")
combined_segments = xr.combine_by_coords(segments)
combined_segments.to_netcdf(output_path)

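# Deprecated

Loop-based implementation of the same segment split, superseded by the cells above. It is not executed in this run (`In [None]`). It reuses the name `segments`, so running it after the cells above replaces that list.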

In [None]:
# %%time
# Deprecated loop-based split: one segment per run of identical labels in each trial.
# Rebinds `segments` (replacing the list built by the cells above) and fills `labels`.
# Reported result: 2488 labels/segments
dimensions = stage_data.labels.shape
segments = []
labels = []
for participant in range(dimensions[0]):
    for epoch in range(dimensions[1]):
        trial = stage_data.isel(participant=participant, epochs=epoch)
        # If trial contains anything other than empty strings, at least one processing stage has been identified and labelled
        if np.any(trial.labels != ""):
            np_labels = trial.labels.to_numpy()
            # Sample indices where the label differs from the previous sample, i.e. where a new stage starts
            changes = np.where(np_labels[:-1] != np_labels[1:])[0] + 1

            # Split the trial's EEG data along axis 1 (used here as the samples axis, since
            # `changes` are sample indices) and the 1-D labels at the same points, so
            # splits[i] pairs with label_splits[i]
            splits = np.split(trial.data, changes, 1)
            label_splits = np.split(trial.labels, changes, 0)
            for split, label_split in zip(splits, label_splits):
                # Keep only splits that carry a non-empty stage label
                if np.any(label_split != ""):
                    # Sanity check: a labelled split is expected to contain no NaNs in its data
                    # NOTE(review): NaNs presumably mark padding past the end of the trial; confirm
                    if np.any(np.isnan(split)):
                        print(f"NAN FOUND participant {participant}, epoch {epoch}")
                        print(f"Label: {label_split[0].item()}")
                        print(split.to_numpy()[0, :])
                        raise ValueError(
                            "Split found that is fully labeled but still contains NaNs in data"
                        )
                    # Add split and label to lists (labels within a split are identical, so the first suffices)
                    # NOTE(review): `.data` is the raw array only if `split` is a DataArray;
                    # confirm what np.split returns for DataArray input
                    segments.append(split.data)
                    labels.append(label_split[0].item())
        else:
            # No stage labelled anywhere in this trial: report it and skip it
            print(f"Trial unusable: participant {participant}, epoch {epoch}")

# Pad every segment to the longest one so they can be stacked into a single array;
# max_sample_length can differ per data set or subset.
# Note: max() raises ValueError if no segments were collected.
# NOTE(review): pad_to_max_sample_length presumably pads along axis 1; confirm
max_sample_length = max(segment.shape[1] for segment in segments)
segments = np.stack(
    [pad_to_max_sample_length(segment, max_sample_length) for segment in segments]
)