diff --git a/root_pandas/readwrite.py b/root_pandas/readwrite.py index c54604b..bdd0966 100644 --- a/root_pandas/readwrite.py +++ b/root_pandas/readwrite.py @@ -6,6 +6,7 @@ import numpy as np from numpy.lib.recfunctions import append_fields from pandas import DataFrame, RangeIndex +import pandas as pd from root_numpy import root2array, list_trees import fnmatch from root_numpy import list_branches @@ -312,6 +313,15 @@ def convert_to_dataframe(array, start_index=None): assert len(columns) == len(df.columns), (columns, df.columns) df = df.reindex_axis(columns, axis=1, copy=False) + # Convert categorical columns back to categories + for c in df.columns: + match = re.match(r'^__rpCaT\*([^\*]+)\*(True|False)\*', c) + if match: + real_name, ordered = match.groups() + categories = c.split('*')[3:] + df[c] = pd.Categorical.from_codes(df[c], categories, ordered={'True': True, 'False': False}[ordered]) + df.rename(index=str, columns={c: real_name}, inplace=True) + return df @@ -353,12 +363,25 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg from root_numpy import array2root # We don't want to modify the user's DataFrame here, so we make a shallow copy df_ = df.copy(deep=False) + if store_index: name = df_.index.name if name is None: # Handle the case where the index has no name name = '' df_['__index__' + name] = df_.index + + # Convert categorical columns into something root_numpy can serialise + for col in df_.select_dtypes(['category']).columns: + name_components = ['__rpCaT', col, str(df_[col].cat.ordered)] + name_components.extend(df_[col].cat.categories) + if ['*' not in c for c in name_components]: + sep = '*' + else: + raise ValueError('Unable to find suitable separator for columns') + df_[col] = df_[col].cat.codes + df_.rename(index=str, columns={col: sep.join(name_components)}, inplace=True) + arr = df_.to_records(index=False) array2root(arr, path, key, mode=mode, *args, **kwargs)