diff --git a/src/uproot/models/RNTuple.py b/src/uproot/models/RNTuple.py index d6a9ce2ed..6efe454fa 100644 --- a/src/uproot/models/RNTuple.py +++ b/src/uproot/models/RNTuple.py @@ -195,11 +195,15 @@ def page_list_envelopes(self): def base_col_form(self, cr, col_id, parameters=None): form_key = f"column-{col_id}" - if cr.type == uproot.const.rntuple_role_union: + dtype_byte = cr.type + if dtype_byte == uproot.const.rntuple_role_union: return form_key - elif cr.type > uproot.const.rntuple_role_struct: + elif dtype_byte > uproot.const.rntuple_role_struct: + dt_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte] + if dt_str == "bit": + dt_str = "bool" return ak._v2.forms.NumpyForm( - uproot.const.rntuple_col_num_to_dtype_dict[cr.type], + dt_str, form_key=form_key, parameters=parameters, ) @@ -236,12 +240,12 @@ def field_form(self, this_id, seen): field_records = self.header.field_records this_record = field_records[this_id] seen.append(this_id) - sr = this_record.struct_role - if sr == uproot.const.rntuple_role_leaf: - # base case of recursive - # n.b. the split may be in column + structural_role = this_record.struct_role + if structural_role == uproot.const.rntuple_role_leaf: + # base case of recursion + # n.b. the split may happen in column return self.col_form(this_id, this_record.type_name) - elif sr == uproot.const.rntuple_role_vector: + elif structural_role == uproot.const.rntuple_role_vector: keyname = self.col_form(this_id) child_id = next( filter( @@ -251,7 +255,7 @@ def field_form(self, this_id, seen): ) inner = self.field_form(child_id, seen) return ak._v2.forms.ListOffsetForm("u32", inner, form_key=keyname) - elif sr == uproot.const.rntuple_role_struct: + elif structural_role == uproot.const.rntuple_role_struct: newids = [] for i, fr in enumerate(field_records): if i not in seen and fr.parent_field_id == this_id: @@ -260,7 +264,7 @@ def field_form(self, this_id, seen): recordlist = [self.field_form(i, seen) for i in newids] namelist = [field_records[i].field_name for i in newids] return ak._v2.forms.RecordForm(recordlist, namelist, form_key="whatever") - elif sr == uproot.const.rntuple_role_union: + elif structural_role == uproot.const.rntuple_role_union: keyname = self.col_form(this_id) newids = [] for i, fr in enumerate(field_records): @@ -269,7 +273,7 @@ def field_form(self, this_id, seen): recordlist = [self.field_form(i, seen) for i in newids] return ak._v2.forms.UnionForm("i8", "i64", recordlist, form_key=keyname) else: - # everything should recursive above this branch + # everything should recurse above this branch raise AssertionError("this should be unreachable") def to_akform(self): @@ -289,39 +293,60 @@ def pagelist(self, listdesc): pages = listdesc.reader.read(listdesc.chunk, local_cursor, listdesc.context) return pages - def read_pagedesc(self, destination, start, stop, desc, dtype): - num_elements = desc.num_elements + def read_pagedesc(self, destination, desc, dtype_str, dtype): loc = desc.locator cursor = uproot.source.cursor.Cursor(loc.offset) context = {} - uncomp_size = num_elements * dtype.itemsize + # bool in RNTuple is always stored as bits + isbit = dtype_str == "bit" + len_divider = 8 if isbit else 1 + num_elements = len(destination) + num_elements_toread = int(numpy.ceil(num_elements / len_divider)) + uncomp_size = num_elements_toread * dtype.itemsize decomp_chunk = self.read_locator(loc, uncomp_size, cursor, context) - destination[start:stop] = cursor.array( - decomp_chunk, num_elements, dtype, context, move=False + content = cursor.array( + decomp_chunk, num_elements_toread, dtype, context, move=False + ) + if isbit: + content = ( + numpy.unpackbits(content.view(dtype=numpy.uint8)) + .reshape(-1, 8)[:, ::-1] + .reshape(-1) + ) + + # needed to chop off extra bits incase we used `unpackbits` + destination[:] = content[:num_elements] + + def read_col_pages(self, ncol, cluster_range): + return numpy.concatenate( + [self.read_col_page(ncol, i) for i in cluster_range], axis=0 ) - def read_col_page(self, ncol, entry_start): - ngroup = self.which_colgroup(ncol) - linklist = self.page_list_envelopes.pagelinklist[ngroup] + def read_col_page(self, ncol, cluster_i): + linklist = self.page_list_envelopes.pagelinklist[cluster_i] link = linklist[ncol] pagelist = self.pagelist(link) dtype_byte = self.column_records[ncol].type - dt_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte] - T = numpy.dtype(dt_str) + dtype_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte] + if dtype_str == "bit": + dtype = numpy.dtype("bool") + else: + dtype = numpy.dtype(dtype_str) # FIXME vector read # n.b. it's possible pagelist is empty if not pagelist: - return numpy.empty(0, T) + return numpy.empty(0, dtype) total_len = numpy.sum([desc.num_elements for desc in pagelist]) - res = numpy.empty(total_len, T) + res = numpy.empty(total_len, dtype) tracker = 0 - for p in pagelist: - tracker_end = tracker + p.num_elements - self.read_pagedesc(res, tracker, tracker_end, p, T) + for page_desc in pagelist: + n_elements = page_desc.num_elements + tracker_end = tracker + n_elements + self.read_pagedesc(res[tracker:tracker_end], page_desc, dtype_str, dtype) tracker = tracker_end - if dtype_byte <= 2: + if dtype_byte <= uproot.const.rntuple_col_type_to_num_dict["index32"]: res = numpy.insert(res, 0, 0) # for offsets return res @@ -335,30 +360,44 @@ def arrays( array_cache=None, ): entry_stop = entry_stop or self._length + clusters = self.cluster_summaries - if len(clusters) != 1: - raise (RuntimeError("Not implemented")) - # FIXME we assume cluster starts at entry 0, i.e only one cluster - L = clusters[0].num_entries + cluster_starts = numpy.array([c.num_first_entry for c in clusters]) + + start_cluster_idx = ( + numpy.searchsorted(cluster_starts, entry_start, side="right") - 1 + ) + stop_cluster_idx = numpy.searchsorted(cluster_starts, entry_stop, side="right") + cluster_num_entries = numpy.sum( + [c.num_entries for c in clusters[start_cluster_idx:stop_cluster_idx]] + ) form = self.to_akform().select_columns(filter_names) # only read columns mentioned in the awkward form target_cols = [] - D = {} + container_dict = {} _recursive_find(form, target_cols) for i, cr in enumerate(self.column_records): key = f"column-{i}" + dtype_byte = cr.type if key in target_cols: - content = self.read_col_page(i, L) - if cr.type == uproot.const.rntuple_role_union: + content = self.read_col_pages( + i, range(start_cluster_idx, stop_cluster_idx) + ) + if dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]: kindex, tags = _split_switch_bits(content) - D[f"{key}-index"] = kindex - D[f"{key}-tags"] = tags + container_dict[f"{key}-index"] = kindex + container_dict[f"{key}-tags"] = tags else: # don't distinguish data and offsets - D[f"{key}-data"] = content - D[f"{key}-offsets"] = content - return ak._v2.from_buffers(form, L, Container(D))[entry_start:entry_stop] + container_dict[f"{key}-data"] = content + container_dict[f"{key}-offsets"] = content + cluster_offset = cluster_starts[start_cluster_idx] + entry_start -= cluster_offset + entry_stop -= cluster_offset + return ak._v2.from_buffers(form, cluster_num_entries, container_dict)[ + entry_start:entry_stop + ] # Supporting function and classes @@ -379,14 +418,6 @@ def _recursive_find(form, res): _recursive_find(form.content, res) -class Container: - def __init__(self, D): - self._dict = D - - def __getitem__(self, name): - return self._dict[name] - - # https://github.com/jblomer/root/blob/ntuple-binary-format-v1/tree/ntuple/v7/doc/specifications.md#page-list-envelope class PageDescription: def read(self, chunk, cursor, context):