Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions root_pandas/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from numpy.lib.recfunctions import append_fields
from pandas import DataFrame, RangeIndex
import pandas as pd
from root_numpy import root2array, list_trees
import fnmatch
from root_numpy import list_branches
Expand Down Expand Up @@ -312,6 +313,15 @@ def convert_to_dataframe(array, start_index=None):
assert len(columns) == len(df.columns), (columns, df.columns)
df = df.reindex_axis(columns, axis=1, copy=False)

# Convert categorical columns back to categories.
# to_root stores a categorical column as its integer codes under a name of
# the form ``__rpCaT*<column name>*<ordered>*<category>*<category>...``;
# decode that name here and rebuild the Categorical from the codes.
# Iterate over a snapshot because the columns are renamed inside the loop.
for c in list(df.columns):
    match = re.match(r'^__rpCaT\*([^\*]+)\*(True|False)\*', c)
    if not match:
        continue
    real_name, ordered = match.groups()
    # Components 0-2 are the marker, the column name and the ordered flag;
    # everything after them is the list of categories.
    categories = c.split('*')[3:]
    df[c] = pd.Categorical.from_codes(df[c], categories, ordered=(ordered == 'True'))
    # Rename the column only. Passing ``index=str`` to rename would map
    # every index label through ``str`` and turn the index into strings.
    df.rename(columns={c: real_name}, inplace=True)

return df


Expand Down Expand Up @@ -353,12 +363,25 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg
from root_numpy import array2root
# We don't want to modify the user's DataFrame here, so we make a shallow copy
df_ = df.copy(deep=False)

if store_index:
name = df_.index.name
if name is None:
# Handle the case where the index has no name
name = ''
df_['__index__' + name] = df_.index

# Convert categorical columns into something root_numpy can serialise.
# Each categorical column is replaced by its integer codes and renamed to
# ``__rpCaT*<column name>*<ordered>*<category>*<category>...`` so that the
# categories can be recovered from the column name when reading back.
sep = '*'
for col in df_.select_dtypes(['category']).columns:
    name_components = ['__rpCaT', col, str(df_[col].cat.ordered)]
    name_components.extend(df_[col].cat.categories)
    # Every component must be free of the separator, otherwise the encoded
    # name cannot be split unambiguously. (Testing the truthiness of a
    # list comprehension here would always pass, so check with any().)
    if any(sep in c for c in name_components):
        raise ValueError('Unable to find suitable separator for columns')
    df_[col] = df_[col].cat.codes
    # Only the column label changes; the index does not need remapping.
    df_.rename(columns={col: sep.join(name_components)}, inplace=True)

arr = df_.to_records(index=False)
array2root(arr, path, key, mode=mode, *args, **kwargs)

Expand Down