This repository has been archived by the owner on Jan 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 35
/
readwrite.py
317 lines (260 loc) · 11.1 KB
/
readwrite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
"""
A module that extends pandas to support the ROOT data format.
"""
import numpy as np
from numpy.lib.recfunctions import append_fields
from pandas import DataFrame, RangeIndex
from root_numpy import root2array, list_trees
from fnmatch import fnmatch
from root_numpy import list_branches
from root_numpy.extern.six import string_types
import itertools
from math import ceil
import re
import ROOT
import warnings
from .utils import stretch
__all__ = [
'read_root',
'to_root',
]
NOEXPAND_PREFIX = 'noexpand:'
def expand_braces(orig):
r = r'.*?(\{.+[^\\]\})'
p = re.compile(r)
s = orig[:]
res = list()
m = p.search(s)
if m is not None:
sub = m.group(1)
open_brace = s.find(sub)
close_brace = open_brace + len(sub) - 1
if sub.find(',') != -1:
for pat in sub[1:-1].split(','):
res.extend(expand_braces(s[:open_brace] + pat + s[close_brace+1:]))
else:
res.extend(expand_braces(s[:open_brace] + sub.replace('}', '\\}') + s[close_brace+1:]))
else:
res.append(s.replace('\\}', '}'))
return list(set(res))
def get_nonscalar_columns(array):
    """Return the names of columns of *array* whose values are not scalars.

    Non-scalar detection is based on the first row: a column whose first
    entry has ``ndim != 0`` (a fixed-size subarray field, or an object field
    holding a per-entry array) is considered non-scalar. This has to look at
    a sample row rather than the dtype, because jagged columns are stored as
    object ('O') fields whose dtype alone does not reveal the array-ness.

    Parameters
    ----------
    array: numpy structured/record array

    Returns
    -------
    list of str column names, in dtype order.
    """
    if len(array) == 0:
        # No sample row to inspect; without one, jagged object columns are
        # undetectable, so treat every column as scalar instead of crashing.
        return []
    first_row = array[0]
    return [name for name, value in zip(array.dtype.names, first_row)
            if value.ndim != 0]
def get_matching_variables(branches, patterns, fail=True):
    """Return the branches that match any of the given shell-style patterns.

    Parameters
    ----------
    branches: sequence of str
        Candidate branch names.
    patterns: sequence of str
        Shell-style patterns (``fnmatch`` syntax: *, ?, [...]).
    fail: bool
        If True, raise ``ValueError`` for a pattern that matches nothing.

    Returns
    -------
    list of matching branch names, ordered by first match, without duplicates.
    """
    selected = []
    for pattern in patterns:
        found = False
        for branch in branches:
            # Evaluate the (relatively expensive) fnmatch only once per pair.
            if fnmatch(branch, pattern):
                found = True
                if branch not in selected:
                    selected.append(branch)
        if not found and fail:
            raise ValueError("Pattern '{}' didn't match any branch".format(pattern))
    return selected
def filter_noexpand_columns(columns):
    """Split *columns* on the presence of the noexpand prefix.

    Parameters
    ----------
    columns: sequence of str
        Column specifications, some of which may carry ``NOEXPAND_PREFIX``.

    Returns
    -------
    A pair of lists: first the entries without the prefix (unchanged), then
    the prefixed entries with the prefix stripped off.
    """
    plain = []
    formulas = []
    for column in columns:
        if column.startswith(NOEXPAND_PREFIX):
            formulas.append(column[len(NOEXPAND_PREFIX):])
        else:
            plain.append(column)
    return plain, formulas
def do_flatten(arr, flatten):
    """Stretch array-valued columns of *arr* into one entry per element.

    *flatten* is either ``True`` (deprecated: flatten every non-scalar
    column) or a sequence of column names to flatten. The position of each
    element inside its original array is appended as the ``__array_index``
    column of the returned recarray.
    """
    if flatten is True:
        warnings.warn(" The option flatten=True is deprecated. Please specify the branches you would like "
                      "to flatten in a list: flatten=['foo', 'bar']", FutureWarning)
        stretched, element_index = stretch(arr, return_indices=True)
    else:
        nonscalar = get_nonscalar_columns(arr)
        # Keep all scalar columns plus the non-scalar ones selected for
        # flattening; other non-scalar columns are dropped by stretch().
        fields = [name for name in arr.dtype.names
                  if name not in nonscalar or name in flatten]
        # Validate the request: every column asked for must exist and must
        # actually be array-valued.
        for col in flatten:
            if col in nonscalar:
                continue
            if col in fields:
                raise ValueError("Requested to flatten {col} but it has a scalar type"
                                 .format(col=col))
            raise ValueError("Requested to flatten {col} but it wasn't loaded from the input file"
                             .format(col=col))
        stretched, element_index = stretch(arr, fields=fields, return_indices=True)
    return append_fields(stretched, '__array_index', element_index,
                         usemask=False, asrecarray=True)
def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=None, flatten=False, *args, **kwargs):
    """
    Read a ROOT file, or list of ROOT files, into a pandas DataFrame.

    Further *args and **kwargs are passed to root_numpy's root2array.
    If the root file contains a branch matching __index__*, it will become
    the DataFrame's index.

    Parameters
    ----------
    paths: string or list
        The path(s) to the root file(s)
    key: string
        The key of the tree to load. If not given, the tree is inferred,
        which only works when the first file contains exactly one tree.
    columns: str or sequence of str
        A sequence of shell-patterns (can contain *, ?, [] or {}). Matching
        columns are read. Columns beginning with `noexpand:` are not
        interpreted as shell-patterns, allowing formula columns such as
        `noexpand:2*x`. The column in the returned DataFrame will not have
        the `noexpand:` prefix.
    ignore: str or sequence of str
        A sequence of shell-patterns (can contain *, ?, [] or {}). All
        matching columns are ignored (overriding the columns argument).
    chunksize: int
        If this parameter is specified, an iterator is returned that yields
        DataFrames with `chunksize` rows.
    where: str
        Only rows that match the expression will be read.
    flatten: sequence of str
        A sequence of column names. Will use root_numpy.stretch to flatten
        arrays in the specified columns into individual entries. All arrays
        specified in the columns must have the same length for this to work.
        Be careful if you combine this with chunksize, as chunksize will
        refer to the number of unflattened entries, so you will be iterating
        over a number of entries that is potentially larger than chunksize.
        The index of each element within its former array will be saved in
        the __array_index column.

    Returns
    -------
    DataFrame created from matching data in the specified TTree, or an
    iterator of such DataFrames if `chunksize` is given.

    Raises
    ------
    ValueError
        If no key is given and the file has zero or multiple trees, if a
        column pattern matches nothing, or if an __index__* branch would be
        ignored.

    Notes
    -----
    >>> df = read_root('test.root', 'MyTree', columns=['A{B,C}*', 'D'], where='ABB > 100')
    """
    if not isinstance(paths, list):
        paths = [paths]
    # Use a single file to search for trees and branches; assumes all files
    # share the same tree/branch structure.
    seed_path = paths[0]
    if not key:
        trees = list_trees(seed_path)
        if len(trees) == 1:
            key = trees[0]
        elif len(trees) == 0:
            raise ValueError('No trees found in {}'.format(seed_path))
        else:
            raise ValueError('More than one tree found in {}'.format(seed_path))
    branches = list_branches(seed_path, key)
    if not columns:
        # No selection given: read every branch.
        all_vars = branches
    else:
        if isinstance(columns, string_types):
            columns = [columns]
        # __index__* is always loaded if it exists
        # XXX Figure out what should happen with multi-dimensional indices
        index_branches = list(filter(lambda x: x.startswith('__index__'), branches))
        if index_branches:
            # Copy before appending so the caller's list is not mutated.
            columns = columns[:]
            columns.append(index_branches[0])
        # Separate literal `noexpand:` formulas from pattern columns, then
        # expand {a,b} alternatives and match patterns against the branches.
        columns, noexpand = filter_noexpand_columns(columns)
        columns = list(itertools.chain.from_iterable(list(map(expand_braces, columns))))
        all_vars = get_matching_variables(branches, columns) + noexpand
    if ignore:
        if isinstance(ignore, string_types):
            ignore = [ignore]
        # Ignore patterns may match nothing, hence fail=False.
        ignored = get_matching_variables(branches, ignore, fail=False)
        ignored = list(itertools.chain.from_iterable(list(map(expand_braces, ignored))))
        if any(map(lambda x: x.startswith('__index__'), ignored)):
            raise ValueError('__index__* branch is being ignored!')
        for var in ignored:
            all_vars.remove(var)
    if chunksize:
        # Count the total number of entries across all files to know how
        # many chunks to produce.
        tchain = ROOT.TChain(key)
        for path in paths:
            tchain.Add(path)
        n_entries = tchain.GetEntries()
        # XXX could explicitly clean up the opened TFiles with TChain::Reset
        def genchunks():
            # current_index tracks the running row offset so each chunk's
            # DataFrame gets a contiguous RangeIndex.
            current_index = 0
            for chunk in range(int(ceil(float(n_entries) / chunksize))):
                arr = root2array(paths, key, all_vars, start=chunk * chunksize, stop=(chunk+1) * chunksize, selection=where, *args, **kwargs)
                if flatten:
                    arr = do_flatten(arr, flatten)
                yield convert_to_dataframe(arr, start_index=current_index)
                # NOTE: len(arr) is the post-flatten length, so indices stay
                # contiguous even when flattening grows the chunk.
                current_index += len(arr)
        return genchunks()
    arr = root2array(paths, key, all_vars, selection=where, *args, **kwargs)
    if flatten:
        arr = do_flatten(arr, flatten)
    return convert_to_dataframe(arr)
def convert_to_dataframe(array, start_index=None):
    """Convert a numpy structured/record array into a pandas DataFrame.

    Parameters
    ----------
    array: numpy structured/record array
        Input data. A single branch named ``__index__<name>`` (if present)
        becomes the DataFrame index, with ``<name>`` as its name.
    start_index: int, optional
        If given (and no ``__index__*`` branch exists), the DataFrame gets a
        RangeIndex starting at this value — used for consecutive chunks.

    Returns
    -------
    DataFrame with the columns of *array* in dtype order.

    Raises
    ------
    ValueError
        If more than one ``__index__*`` branch is found.
    """
    nonscalar_columns = get_nonscalar_columns(array)
    # Columns containing >=2-dim subarrays can't be loaded by from_records,
    # so convert them to 1-d object arrays of arrays first.
    reshaped_columns = {}
    for col in nonscalar_columns:
        if array[col].ndim >= 2:
            reshaped = np.zeros(len(array[col]), dtype='O')
            for i, row in enumerate(array[col]):
                reshaped[i] = row
            reshaped_columns[col] = reshaped
    indices = list(filter(lambda x: x.startswith('__index__'), array.dtype.names))
    if len(indices) == 0:
        index = None
        if start_index is not None:
            index = RangeIndex(start=start_index, stop=start_index + len(array))
        df = DataFrame.from_records(array, exclude=reshaped_columns, index=index)
    elif len(indices) == 1:
        # We store the index under the __index__* branch, where
        # * is the name of the index
        df = DataFrame.from_records(array, exclude=reshaped_columns, index=indices[0])
        index_name = indices[0][len('__index__'):]
        if not index_name:
            # None means the index has no name
            index_name = None
        df.index.name = index_name
    else:
        raise ValueError("More than one index found in file")
    # Manually add back the columns which were reshaped above.
    for col, values in reshaped_columns.items():
        df[col] = values
    # Reshaping can cause the order of columns to change, so restore the
    # original dtype order. (Column selection replaces the long-removed
    # DataFrame.reindex_axis call.)
    if reshaped_columns:
        # Filter to remove __index__ columns
        columns = [c for c in array.dtype.names if c in df.columns]
        assert len(columns) == len(df.columns), (columns, df.columns)
        df = df[columns]
    return df
def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwargs):
    """
    Write DataFrame to a ROOT file.

    Parameters
    ----------
    path: string
        File path to new ROOT file (will be overwritten)
    key: string
        Name of tree that the DataFrame will be saved as
    mode: string, {'w', 'a'}
        Mode that the file should be opened in (default: 'w')
    store_index: bool (optional, default: True)
        Whether the index of the DataFrame should be stored as
        an __index__* branch in the tree

    Notes
    -----
    Further *args and *kwargs are passed to root_numpy's array2root.

    >>> df = DataFrame({'x': [1,2,3], 'y': [4,5,6]})
    >>> df.to_root('test.root')

    The DataFrame index will be saved as a branch called '__index__*',
    where * is the name of the index in the original DataFrame
    """
    # Translate the pandas-style mode letter into ROOT's open mode.
    if mode == 'a':
        open_mode = 'update'
    elif mode == 'w':
        open_mode = 'recreate'
    else:
        raise ValueError('Unknown mode: {}. Must be "a" or "w".'.format(mode))

    from root_numpy import array2root

    # Work on a shallow copy so the caller's DataFrame is left untouched
    # when we attach the index column below.
    out = df.copy(deep=False)
    if store_index:
        index_name = out.index.name
        if index_name is None:
            # An unnamed index is stored as plain '__index__'.
            index_name = ''
        out['__index__' + index_name] = out.index

    records = out.to_records(index=False)
    array2root(records, path, key, mode=open_mode, *args, **kwargs)
# Monkey-patch pandas so every DataFrame gains a .to_root() method that
# delegates to the module-level to_root() above (df.to_root(path, ...)).
DataFrame.to_root = to_root