In [1]:
from config import cfg
import os
import numpy as np
import pandas as pd
import mne
# from util import read_emg, read_manus

In [2]:
def read_emg(path):
    '''
    read emg signal from edf file
    Input: path to edf file
    Output: emg dataframe
    '''
    raw = mne.io.read_raw_edf(path, preload=True)

    # to dataframe
    emg_df = raw.to_data_frame()

    # change time to datetime
    emg_df['time'] = pd.to_datetime(emg_df['time']*1000, unit='ms', origin=ExpTimes.emg_start_time)

    # filterout the time column for only ms resolution
    emg_df['time'] = emg_df['time'].dt.floor('ms')

    # resample emg data to 150 Hz
    emg_df = emg_df.set_index('time')
    emg_df = emg_df.resample('8ms', origin='end').mean()

    return emg_df

def read_manus(path):
    '''
    Read data from manus glove
    Input: path to csv file
    Output: manus dataframe
    '''

    fingers = ['Thumb', 'Index', 'Middle', 'Ring', 'Pinky']
    key_points = ['MCP', 'DIP', 'PIP', 'CMC']
    movement = ['Spread', 'Flex']

    # Additional columns for manus
    pinch_columns = ['Pinch_ThumbToIndex', 'Pinch_ThumbToMiddle', 'Pinch_ThumbToRing', 'Pinch_ThumbToPinky']
    time_column = ['time']

    valid_columns = key_points + pinch_columns + time_column
    
    
    manus_df = pd.read_csv(path)

    #rename Elapsed_Time_In_Milliseconds to time
    manus_df.rename(columns={'Elapsed_Time_In_Milliseconds': 'time'}, inplace=True)

    # Convert time to datetime and drop values l
    manus_df['time'] = pd.to_datetime(manus_df['time'], unit='ms', origin=ExpTimes.manus_start_time)


    # remove acceleration and velocity columns
    acc_vel_col = [item for item in manus_df.columns if 'Acceleration' in item or 'Velocity' in item]
    manus_df.drop(columns=acc_vel_col, inplace=True)

    #drop unused columns
    unused_columns = ['Time', 'Frame'] + manus_df.filter(regex='_[X/Y/Z]', axis=1).columns.tolist()
    manus_df.drop(columns=unused_columns, inplace=True)

    # assert sorted(list(manus_df.columns) )== sorted(valid_columns), 'Columns are not valid'

    # set time as index
    manus_df = manus_df.set_index('time')

    return manus_df 

In [3]:
from datetime import datetime
class ExpTimes:
    refernce_time = datetime.strptime('2023-10-02 14:59:55.627000', '%Y-%m-%d %H:%M:%S.%f')
    manus_start_time = datetime.strptime('2023-10-02 14:59:20.799000', '%Y-%m-%d %H:%M:%S.%f')
    emg_start_time = datetime.strptime('2023-10-02 14:59:55.627000', '%Y-%m-%d %H:%M:%S.%f')
    video_Start_time = datetime.strptime('2023-10-02 14:59:55.628000', '%Y-%m-%d %H:%M:%S.%f')

    manus_columns = ['Pinch_ThumbToIndex','Pinch_ThumbToMiddle', 'Pinch_ThumbToRing',
                     'Pinch_ThumbToPinky', 'Thumb_CMC_Spread', 'Thumb_CMC_Flex', 'Thumb_PIP_Flex', 'Thumb_DIP_Flex',
                     'Index_MCP_Spread', 'Index_MCP_Flex', 'Index_PIP_Flex', 'Index_DIP_Flex', 'Middle_MCP_Spread',
                     'Middle_MCP_Flex', 'Middle_PIP_Flex', 'Middle_DIP_Flex', 'Ring_MCP_Spread', 'Ring_MCP_Flex',
                     'Ring_PIP_Flex', 'Ring_DIP_Flex', 'Pinky_MCP_Spread', 'Pinky_MCP_Flex', 'Pinky_PIP_Flex',
                     'Pinky_DIP_Flex','time']


In [4]:
# edf_path = os.path.join(cfg.DATA.PATH, 'test 2023-10-02 14-59-55-627.edf')
# manus_path = os.path.join(cfg.DATA.PATH, 'Untitled_2023-10-02_15-24-12_YH_lab_R.csv')

# #read data
# df = read_data([manus_path, edf_path])
# df.head()

In [7]:
# make a tesor dataset class
from torch.utils.data import Dataset
import torch
DATA_SOURCES = {
    'manus': read_manus,
    'emg': read_emg,
}
class EMGDataset(Dataset):
    def __init__(self, data_path, label_path, transform=None, data_source='emg', label_source='manus'):

        self.data_path = data_path
        self.label_path = label_path

        self.transform = transform

        self.data_source = data_source # emg or imu
        self.label_source = label_source   # manus, video, or ultraleap 

        self.emg_columns = ['channel {}'.format(i) for i in range(16)]
        self.mauns_columns = ['Pinch_ThumbToIndex','Pinch_ThumbToMiddle', 'Pinch_ThumbToRing',
                        'Pinch_ThumbToPinky', 'Thumb_CMC_Spread', 'Thumb_CMC_Flex', 'Thumb_PIP_Flex', 'Thumb_DIP_Flex',
                        'Index_MCP_Spread', 'Index_MCP_Flex', 'Index_PIP_Flex', 'Index_DIP_Flex', 'Middle_MCP_Spread',
                        'Middle_MCP_Flex', 'Middle_PIP_Flex', 'Middle_DIP_Flex', 'Ring_MCP_Spread', 'Ring_MCP_Flex',
                        'Ring_PIP_Flex', 'Ring_DIP_Flex', 'Pinky_MCP_Spread', 'Pinky_MCP_Flex', 'Pinky_PIP_Flex',
                        'Pinky_DIP_Flex','time']
        
        self.prepare_data()
        
    def prepare_data(self):
        data =  DATA_SOURCES[self.data_source](self.data_path)
        label = DATA_SOURCES[self.label_source](self.label_path)
        
        # set the start and end of experiment
        start_time = max(min(data.index), min(label.index))
        end_time = min(max(data.index), max(label.index))

        # select only the data between start and end time
        data = data.loc[start_time:end_time]
        label = label.loc[start_time:end_time]

        # make sure the dataframes are of the same length for the merge
        df = pd.merge_asof(data, label, on='time', direction='nearest')

        print(f'df shape: {df.shape}\n data shape: {data.shape}\n label shape: {label.shape}')
        assert df.shape[0] == data.shape[0] & df.shape[0] == label.shape[0], 'Dataframes are not of the same length'
        del df

        #reset index to numeric values
        data.reset_index(drop=True, inplace=True)
        label.reset_index(drop=True, inplace=True)

        # convert to tensor
        self.data = torch.tensor(data.values)
        self.label = torch.tensor(label.values)


    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        data = self.data[idx]
        label = self.label[idx]
        return data, label
    # read emg signal
    
    
    @staticmethod
    def read_manus(path):
        fingers = ['Thumb', 'Index', 'Middle', 'Ring', 'Pinky']
        key_points = ['MCP', 'DIP', 'PIP', 'CMC']
        movement = ['Spread', 'Flex']

        # Additional columns for manus
        pinch_columns = ['Pinch_ThumbToIndex', 'Pinch_ThumbToMiddle', 'Pinch_ThumbToRing', 'Pinch_ThumbToPinky']
        time_column = ['time']

        valid_columns = key_points + pinch_columns + time_column
        
        
        manus_df = pd.read_csv(path)

        #rename Elapsed_Time_In_Milliseconds to time
        manus_df.rename(columns={'Elapsed_Time_In_Milliseconds': 'time'}, inplace=True)

        # Convert time to datetime and drop values l
        manus_df['time'] = pd.to_datetime(manus_df['time'], unit='ms', origin=ExpTimes.manus_start_time)


        # remove acceleration and velocity columns
        acc_vel_col = [item for item in manus_df.columns if 'Acceleration' in item or 'Velocity' in item]
        manus_df.drop(columns=acc_vel_col, inplace=True)

        #drop unused columns
        unused_columns = ['Time', 'Frame'] + manus_df.filter(regex='_[X/Y/Z]', axis=1).columns.tolist()
        manus_df.drop(columns=unused_columns, inplace=True)

        # assert sorted(list(manus_df.columns) )== sorted(valid_columns), 'Columns are not valid'

        # set time as index
        manus_df = manus_df.set_index('time')

        return manus_df


In [8]:
manus_path = os.path.join(cfg.DATA.PATH, 'Untitled_2023-10-02_15-24-12_YH_lab_R.csv')
edf_path = os.path.join(cfg.DATA.PATH, 'test 2023-10-02 14-59-55-627.edf')
dataset = EMGDataset(edf_path, manus_path)

Extracting EDF parameters from /Users/rufaelmarew/Documents/tau/finger_pose_estimation/dataset/test 2023-10-02 14-59-55-627.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 345249  =      0.000 ...  1380.996 secs...
df shape: (164650, 47)
 data shape: (164650, 22)
 label shape: (164650, 24)
                     time  Channel 0  Channel 1  Channel 2  Channel 3  \
0 2023-10-02 14:59:55.631  25.733771   3.077092  22.661679   1.541046   
1 2023-10-02 14:59:55.639  -5.755173 -18.811564  -4.987150 -17.275518   
2 2023-10-02 14:59:55.647 -17.659529  42.246266  19.589587 -16.123483   
3 2023-10-02 14:59:55.655  -4.603138  98.311945  21.509644 -21.115633   
4 2023-10-02 14:59:55.663  -8.443253  56.838703  42.246266 -10.747322   

    Channel 4  Channel 5  Channel 6   Channel 7   Channel 8  ...  \
0  102.536072  28.037840  25.733771  196.234879  122.504670  ...   
1   79.495382  -2.683081  33.414001  254.604628  122.888682  ...   
2   87.559

In [9]:
dataset[0][1].shape

torch.Size([24])