Skip to content

Commit

Permalink
Multiple clusters support for RNTuple (#682)
Browse files Browse the repository at this point in the history
* fix multiple cluster

* rename variables for more consistency


* fix bit reading

* clean up

* don't assume numpy has bitorder='little'
  • Loading branch information
Moelf committed Aug 21, 2022
1 parent 613ee54 commit e7e8be1
Showing 1 changed file with 78 additions and 47 deletions.
125 changes: 78 additions & 47 deletions src/uproot/models/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,15 @@ def page_list_envelopes(self):

def base_col_form(self, cr, col_id, parameters=None):
form_key = f"column-{col_id}"
if cr.type == uproot.const.rntuple_role_union:
dtype_byte = cr.type
if dtype_byte == uproot.const.rntuple_role_union:
return form_key
elif cr.type > uproot.const.rntuple_role_struct:
elif dtype_byte > uproot.const.rntuple_role_struct:
dt_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte]
if dt_str == "bit":
dt_str = "bool"
return ak._v2.forms.NumpyForm(
uproot.const.rntuple_col_num_to_dtype_dict[cr.type],
dt_str,
form_key=form_key,
parameters=parameters,
)
Expand Down Expand Up @@ -236,12 +240,12 @@ def field_form(self, this_id, seen):
field_records = self.header.field_records
this_record = field_records[this_id]
seen.append(this_id)
sr = this_record.struct_role
if sr == uproot.const.rntuple_role_leaf:
# base case of recursive
# n.b. the split may be in column
structural_role = this_record.struct_role
if structural_role == uproot.const.rntuple_role_leaf:
# base case of recursion
# n.b. the split may happen in column
return self.col_form(this_id, this_record.type_name)
elif sr == uproot.const.rntuple_role_vector:
elif structural_role == uproot.const.rntuple_role_vector:
keyname = self.col_form(this_id)
child_id = next(
filter(
Expand All @@ -251,7 +255,7 @@ def field_form(self, this_id, seen):
)
inner = self.field_form(child_id, seen)
return ak._v2.forms.ListOffsetForm("u32", inner, form_key=keyname)
elif sr == uproot.const.rntuple_role_struct:
elif structural_role == uproot.const.rntuple_role_struct:
newids = []
for i, fr in enumerate(field_records):
if i not in seen and fr.parent_field_id == this_id:
Expand All @@ -260,7 +264,7 @@ def field_form(self, this_id, seen):
recordlist = [self.field_form(i, seen) for i in newids]
namelist = [field_records[i].field_name for i in newids]
return ak._v2.forms.RecordForm(recordlist, namelist, form_key="whatever")
elif sr == uproot.const.rntuple_role_union:
elif structural_role == uproot.const.rntuple_role_union:
keyname = self.col_form(this_id)
newids = []
for i, fr in enumerate(field_records):
Expand All @@ -269,7 +273,7 @@ def field_form(self, this_id, seen):
recordlist = [self.field_form(i, seen) for i in newids]
return ak._v2.forms.UnionForm("i8", "i64", recordlist, form_key=keyname)
else:
# everything should recursive above this branch
# everything should recurse above this branch
raise AssertionError("this should be unreachable")

def to_akform(self):
Expand All @@ -289,39 +293,60 @@ def pagelist(self, listdesc):
pages = listdesc.reader.read(listdesc.chunk, local_cursor, listdesc.context)
return pages

def read_pagedesc(self, destination, start, stop, desc, dtype):
num_elements = desc.num_elements
def read_pagedesc(self, destination, desc, dtype_str, dtype):
loc = desc.locator
cursor = uproot.source.cursor.Cursor(loc.offset)
context = {}
uncomp_size = num_elements * dtype.itemsize
# bool in RNTuple is always stored as bits
isbit = dtype_str == "bit"
len_divider = 8 if isbit else 1
num_elements = len(destination)
num_elements_toread = int(numpy.ceil(num_elements / len_divider))
uncomp_size = num_elements_toread * dtype.itemsize
decomp_chunk = self.read_locator(loc, uncomp_size, cursor, context)
destination[start:stop] = cursor.array(
decomp_chunk, num_elements, dtype, context, move=False
content = cursor.array(
decomp_chunk, num_elements_toread, dtype, context, move=False
)
if isbit:
content = (
numpy.unpackbits(content.view(dtype=numpy.uint8))
.reshape(-1, 8)[:, ::-1]
.reshape(-1)
)

# needed to chop off extra bits incase we used `unpackbits`
destination[:] = content[:num_elements]

def read_col_pages(self, ncol, cluster_range):
return numpy.concatenate(
[self.read_col_page(ncol, i) for i in cluster_range], axis=0
)

def read_col_page(self, ncol, entry_start):
ngroup = self.which_colgroup(ncol)
linklist = self.page_list_envelopes.pagelinklist[ngroup]
def read_col_page(self, ncol, cluster_i):
linklist = self.page_list_envelopes.pagelinklist[cluster_i]
link = linklist[ncol]
pagelist = self.pagelist(link)
dtype_byte = self.column_records[ncol].type
dt_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte]
T = numpy.dtype(dt_str)
dtype_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte]
if dtype_str == "bit":
dtype = numpy.dtype("bool")
else:
dtype = numpy.dtype(dtype_str)

# FIXME vector read
# n.b. it's possible pagelist is empty
if not pagelist:
return numpy.empty(0, T)
return numpy.empty(0, dtype)
total_len = numpy.sum([desc.num_elements for desc in pagelist])
res = numpy.empty(total_len, T)
res = numpy.empty(total_len, dtype)
tracker = 0
for p in pagelist:
tracker_end = tracker + p.num_elements
self.read_pagedesc(res, tracker, tracker_end, p, T)
for page_desc in pagelist:
n_elements = page_desc.num_elements
tracker_end = tracker + n_elements
self.read_pagedesc(res[tracker:tracker_end], page_desc, dtype_str, dtype)
tracker = tracker_end

if dtype_byte <= 2:
if dtype_byte <= uproot.const.rntuple_col_type_to_num_dict["index32"]:
res = numpy.insert(res, 0, 0) # for offsets
return res

Expand All @@ -335,30 +360,44 @@ def arrays(
array_cache=None,
):
entry_stop = entry_stop or self._length

clusters = self.cluster_summaries
if len(clusters) != 1:
raise (RuntimeError("Not implemented"))
# FIXME we assume cluster starts at entry 0, i.e only one cluster
L = clusters[0].num_entries
cluster_starts = numpy.array([c.num_first_entry for c in clusters])

start_cluster_idx = (
numpy.searchsorted(cluster_starts, entry_start, side="right") - 1
)
stop_cluster_idx = numpy.searchsorted(cluster_starts, entry_stop, side="right")
cluster_num_entries = numpy.sum(
[c.num_entries for c in clusters[start_cluster_idx:stop_cluster_idx]]
)

form = self.to_akform().select_columns(filter_names)
# only read columns mentioned in the awkward form
target_cols = []
D = {}
container_dict = {}
_recursive_find(form, target_cols)
for i, cr in enumerate(self.column_records):
key = f"column-{i}"
dtype_byte = cr.type
if key in target_cols:
content = self.read_col_page(i, L)
if cr.type == uproot.const.rntuple_role_union:
content = self.read_col_pages(
i, range(start_cluster_idx, stop_cluster_idx)
)
if dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]:
kindex, tags = _split_switch_bits(content)
D[f"{key}-index"] = kindex
D[f"{key}-tags"] = tags
container_dict[f"{key}-index"] = kindex
container_dict[f"{key}-tags"] = tags
else:
# don't distinguish data and offsets
D[f"{key}-data"] = content
D[f"{key}-offsets"] = content
return ak._v2.from_buffers(form, L, Container(D))[entry_start:entry_stop]
container_dict[f"{key}-data"] = content
container_dict[f"{key}-offsets"] = content
cluster_offset = cluster_starts[start_cluster_idx]
entry_start -= cluster_offset
entry_stop -= cluster_offset
return ak._v2.from_buffers(form, cluster_num_entries, container_dict)[
entry_start:entry_stop
]


# Supporting function and classes
Expand All @@ -379,14 +418,6 @@ def _recursive_find(form, res):
_recursive_find(form.content, res)


class Container:
def __init__(self, D):
self._dict = D

def __getitem__(self, name):
return self._dict[name]


# https://github.com/jblomer/root/blob/ntuple-binary-format-v1/tree/ntuple/v7/doc/specifications.md#page-list-envelope
class PageDescription:
def read(self, chunk, cursor, context):
Expand Down

0 comments on commit e7e8be1

Please sign in to comment.