In [1]:
from sktime.datasets import load_from_tsfile_to_dataframe
import pandas as pd
import numpy as np

def interpolate_missing(y):
    """
    Fill the NaN entries of pd.Series `y` by linear interpolation.

    Interpolation is requested in both directions (`limit_direction='both'`).
    A series that contains no NaN is handed back untouched (the very same object).
    """
    has_gaps = y.isna().any()
    if not has_gaps:
        return y
    return y.interpolate(method='linear', limit_direction='both')

def subsample(y, limit=256, factor=2):
    """
    Thin out a Series that is longer than `limit` by keeping every `factor`-th element
    (starting with the first) and renumbering the index from 0.
    A Series of length `limit` or shorter is returned unchanged.
    """
    if len(y) <= limit:
        return y
    thinned = y[::factor]
    return thinned.reset_index(drop=True)

In [4]:
# Location of the Heartbeat training split (.ts file). The absolute path is kept exactly as in the
# original run; point it at your local copy of the dataset before re-running this cell.
HEARTBEAT_TRAIN_PATH = '/home/user/suzhao/BehaviorDL/dataset/Heartbeat/Heartbeat_TRAIN.ts'


def _elementwise(frame, func):
    """
    Apply `func` to every cell of DataFrame `frame` and return the resulting DataFrame.

    `DataFrame.applymap` was deprecated in favour of `DataFrame.map` in pandas 2.1, so the newer
    method is used when it exists and the older one otherwise.
    """
    apply_cells = getattr(frame, "map", None)
    if not callable(apply_cells):  # older pandas (or "map" resolved to a column rather than a method)
        apply_cells = frame.applymap
    return apply_cells(func)


# Load the data: `df` holds one series per cell (rows: samples, columns: dimensions),
# `labels` holds one class label per sample. Missing values are read in as NaN.
df, labels = load_from_tsfile_to_dataframe(HEARTBEAT_TRAIN_PATH,
                                           return_separate_X_and_y=True,
                                           replace_missing_vals_with='NaN')

# Encode the class labels as integer category codes; `class_names` keeps the code -> label mapping
labels = pd.Series(labels, dtype="category")
class_names = labels.cat.categories
# NOTE(review): the original comment here read "int8-32 gives an error when using nn.CrossEntropyLoss",
# yet the codes are stored as int8. Presumably they are cast to a 64-bit integer type before being
# passed to the loss -- confirm in the code that consumes `labels_df`.
labels_df = pd.DataFrame(labels.cat.codes,
                         dtype=np.int8)
lengths = _elementwise(df, len).values  # (num_samples, num_dimensions) array containing the length of each series

# Per-row absolute difference between each dimension's length and the length of the first dimension
horiz_diffs = np.abs(lengths - np.expand_dims(lengths[:, 0], -1))

if np.sum(horiz_diffs) > 0:  # if any row (sample) has varying length across dimensions
    df = _elementwise(df, subsample)


# Re-measure, since the series may have been subsampled above
lengths = _elementwise(df, len).values
# Per-column absolute difference between each sample's length and the length of the first sample
vert_diffs = np.abs(lengths - np.expand_dims(lengths[0, :], 0))
if np.sum(vert_diffs) > 0:  # if any column (dimension) has varying length across samples
    max_seq_len = int(np.max(lengths[:, 0]))
else:
    max_seq_len = int(lengths[0, 0])  # plain Python int, consistent with the branch above

# First create a (seq_len, feat_dim) dataframe for each sample, indexed by a single integer ("ID" of the sample)
# Then concatenate into a (num_samples * seq_len, feat_dim) dataframe, with multiple rows corresponding to the
# sample index (i.e. the same scheme as all datasets in this project)

df = pd.concat(
    (
        pd.DataFrame({col: df.loc[row, col] for col in df.columns})
        .reset_index(drop=True)
        # every row of this sample gets the sample ID `row` as its index value;
        # int(...) makes the list repetition explicit instead of relying on numpy-scalar * list
        .set_index(pd.Series(int(lengths[row, 0]) * [row]))
        for row in range(df.shape[0])
    ),
    axis=0,
)

# Replace NaN values, interpolating within each sample (rows sharing an index value) separately
grp = df.groupby(by=df.index)
df = grp.transform(interpolate_missing)

Unnamed: 0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,...,dim_51,dim_52,dim_53,dim_54,dim_55,dim_56,dim_57,dim_58,dim_59,dim_60
0,0.000949,0.001288,0.000529,0.000563,0.000980,0.001009,0.001464,0.002277,0.001552,0.000708,...,0.001950,0.020030,0.035233,0.041325,0.048396,0.039963,0.078962,0.078312,0.045608,0.121070
0,0.001488,0.001140,0.001635,0.001082,0.001680,0.001782,0.000935,0.001225,0.001347,0.000268,...,0.014218,0.012354,0.022914,0.025636,0.021956,0.036272,0.085931,0.145680,0.119800,0.133850
0,0.000314,0.000430,0.002146,0.000610,0.001279,0.002249,0.000397,0.001563,0.000330,0.000971,...,0.039978,0.036798,0.036099,0.041923,0.052983,0.039156,0.110610,0.238050,0.178130,0.040077
0,0.000995,0.000532,0.001744,0.000445,0.001592,0.000938,0.000240,0.001269,0.000169,0.000652,...,0.052157,0.053817,0.050959,0.065993,0.096772,0.083629,0.165070,0.265630,0.150210,0.061956
0,0.002099,0.001492,0.001424,0.000796,0.001521,0.000591,0.000554,0.001469,0.001343,0.000715,...,0.030683,0.041912,0.026441,0.049231,0.078987,0.149420,0.245120,0.186530,0.066619,0.070956
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
203,0.005072,0.002829,0.001588,0.006018,0.011885,0.002574,0.006266,0.006309,0.013291,0.017248,...,0.011872,0.033656,0.033050,0.033799,0.049153,0.071279,0.067396,0.081201,0.131550,0.031515
203,0.005831,0.002387,0.002514,0.004973,0.011910,0.006155,0.004932,0.002476,0.011635,0.012985,...,0.007311,0.036839,0.032488,0.020085,0.025653,0.038853,0.127560,0.183250,0.177930,0.066749
203,0.007455,0.006783,0.004723,0.006756,0.012059,0.008155,0.002254,0.002549,0.007925,0.005576,...,0.016262,0.032110,0.026859,0.022696,0.042071,0.048182,0.201120,0.280130,0.158690,0.101930
203,0.008124,0.009142,0.005436,0.005101,0.007947,0.007861,0.000964,0.001796,0.004098,0.006673,...,0.013833,0.014482,0.044743,0.050033,0.044775,0.102920,0.226500,0.270960,0.068071,0.062402
