# Prepare Dataset
1. Read Energy df
2. Create timeslice groups

In [1]:
def read_df(file):
    """
    Load a CSV into a pandas DataFrame and hand it back unchanged

    `file` is passed straight to pd.read_csv, so anything that function
    accepts (a path, a file-like object) works here.
    """
    return pd.read_csv(file)
    

In [4]:
def sort_mixed_groups(df):
    """
    Reorder the rows of a mixed (hits + noise) dataframe so that the groups
    with the most hits come first, and hits precede noise inside each group.

    Group order is decided from the per-(group, label) row counts, sorted by
    label descending and then count descending; a group takes the position of
    its first entry in that ordering. Ties keep ascending group order.
    Inside each group, rows are ordered by label descending; rows with equal
    labels keep their original relative order.

    Rows whose group is missing, or whose group has no non-null label, are
    not part of any counted group and are left out of the result.

    The input dataframe is not modified. The result has the 'group' column
    first, the remaining columns in their original order, and a fresh
    0..n-1 index.

    DataFrame --> DataFrame
    """

    # Number of rows for every (group, label) pair
    pair_counts = df.groupby(['group', 'label']).size().reset_index(name='count')

    # Explicit ordering: highest label first, then largest count first.
    # A multi-column sort is stable, so ties stay in ascending group order.
    ranked_pairs = pair_counts.sort_values(['label', 'count'],
                                           ascending=[False, False],
                                           kind='mergesort')

    # A group is placed where it first appears in the ranked pairs
    group_order = ranked_pairs['group'].drop_duplicates()
    rank_by_group = {group: rank for rank, group in enumerate(group_order)}

    # Sort keys live in a separate positional frame so that neither the
    # caller's dataframe nor its (possibly non-unique) index is touched
    keys = pd.DataFrame({'rank': df['group'].map(rank_by_group).values,
                         'label': df['label'].values})

    # Rows of groups that were never counted have no rank
    keys = keys.dropna(subset=['rank'])

    # Group rank ascending, label descending (missing labels last), stable
    positions = keys.sort_values(['rank', 'label'],
                                 ascending=[True, False],
                                 kind='mergesort').index.values

    # Keep the established output layout: 'group' first, then the rest
    columns = ['group'] + [col for col in df.columns if col != 'group']

    return df.iloc[positions][columns].reset_index(drop=True)

In [45]:
def save_df(df):
    """
    Write the first SIZE mixed groups (in sorted order) to disk, unsampled

    For every selected group two text files named "group_<id>.xyz" are
    written with np.savetxt:
      * under output_path: the pos_x, pos_y, time and energy columns
      * under output_energy_path: the distinct energy values of the group

    The rows written are taken from `df` itself; the sorted dataframe is
    only used to decide which groups are selected.
    """

    # Sorted dataframe only determines which groups make the cut
    ordered = sort_mixed_groups(df)
    selected_groups = ordered.group.unique()[:SIZE]

    for group_id in selected_groups:
        xyz_name = "group_" + str(group_id) + ".xyz"

        # All rows of this group, in their original order
        members = df[df.group == group_id]

        point_values = members[['pos_x', 'pos_y', 'time', 'energy']].values
        np.savetxt(output_path + xyz_name, point_values)

        distinct_energies = members[['energy']].drop_duplicates().energy.tolist()
        np.savetxt(output_energy_path + xyz_name, distinct_energies)

    print("All {0} files saved successfully in {1}!".format(len(selected_groups),
                                                           output_path))

In [48]:
# Script entry point: read the energy dataframe and export its groups
if __name__ == '__main__':
    # pd and np are bound here as module-level names; read_df,
    # sort_mixed_groups and save_df look them up as globals when called.
    import pandas as pd
    import numpy as np
    # NOTE(review): matplotlib and sys are not used by the code visible in
    # this section -- confirm whether other cells need them.
    import matplotlib 
    import sys
    
    # Maximum number of groups save_df writes out
    SIZE = 200

    # CSV that read_df loads
    input_file = "../../data/energy/df.csv"
    # Directory prefixes save_df prepends to each "group_<id>.xyz" file name
    # (plain string concatenation, so the trailing "/" is required)
    output_path = "../../data/energy/xyz/points/"
    output_energy_path = "../../data/energy/xyz/energy/"
    
    df = read_df(input_file)
    save_df(df)

All 200 files saved successfully in ../../data/energy/xyz/points/!
