Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Missing axis_wrap_if_negative Record method in both v1 and v2 #1565

Merged
merged 3 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/awkward/array/Record.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ namespace awkward {
const ContentPtr
shallow_simplify() const override;

const int64_t
axis_wrap_if_negative(int64_t axis) const;

const ContentPtr
num(int64_t axis, int64_t depth) const override;

Expand Down
7 changes: 7 additions & 0 deletions src/awkward/_v2/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def branch_depth(self):
branch, depth = self._array.branch_depth
return branch, depth - 1

def axis_wrap_if_negative(self, axis):
if axis == 0:
raise ak._v2._util.error(
np.AxisError("Record type at axis=0 is a scalar, not an array")
)
return self._array.axis_wrap_if_negative(axis)

def __getitem__(self, where):
with ak._v2._util.SlicingErrorContext(self, where):
return self._getitem(where)
Expand Down
9 changes: 9 additions & 0 deletions src/libawkward/array/Record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,15 @@ namespace awkward {
return shallow_copy();
}

const int64_t
Record::axis_wrap_if_negative(int64_t axis) const {
if (axis == 0) {
throw std::invalid_argument(
std::string("Record at axis=0 is a scalar, not an array") + FILENAME(__LINE__));
}
return array_.get()->axis_wrap_if_negative(axis);
}

const ContentPtr
Record::num(int64_t axis, int64_t depth) const {
int64_t posaxis = axis_wrap_if_negative(axis);
Expand Down
3 changes: 3 additions & 0 deletions src/python/content.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3110,6 +3110,9 @@ make_Record(const py::handle& m, const std::string& name) {
.def("simplify", [](const ak::Record& self) {
return box(self.shallow_simplify());
})
.def("axis_wrap_if_negative",
&ak::Record::axis_wrap_if_negative,
py::arg("axis"))
.def("copy_to",
[](const ak::Record& self, const std::string& ptr_lib) -> py::object {
if (ptr_lib == "cpu") {
Expand Down
68 changes: 68 additions & 0 deletions tests/v2/test_1565-axis_wrap_if_negative_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import pytest # noqa: F401
import awkward as ak # noqa: F401
import numpy as np # noqa: F401

to_list = ak._v2.operations.to_list


def test_axis_wrap_if_negative_record_v2():
dict_cell_chain_field = {
"cell1": [
{
"locus": "TRA",
"v_call": "TRAV1",
"cdr3_length": 15,
}, # <-- represents one chain
{"locus": "TRB", "v_call": "TRBV1", "cdr3_length": 12},
],
"cell2": [{"locus": "TRA", "v_call": "TRAV1", "cdr3_length": 13}],
"cell3": [],
}

r = ak._v2.Record(dict_cell_chain_field)

with pytest.raises(np.AxisError):
r = ak._v2.operations.to_regular(r, 0)
r = ak._v2.operations.to_regular(r, 2)

list_cell_chain_field = [
[["TRA", "TRAV1", 15], ["TRB", "TRBV1", 12]],
[["TRA", "TRAV1", 13]],
[],
]

a = ak._v2.Array(list_cell_chain_field)
a = ak._v2.operations.to_regular(a, 0)
a = ak._v2.operations.to_regular(a, 2)


def test_axis_wrap_if_negative_record_v1():
dict_cell_chain_field = {
"cell1": [
{
"locus": "TRA",
"v_call": "TRAV1",
"cdr3_length": 15,
}, # <-- represents one chain
{"locus": "TRB", "v_call": "TRBV1", "cdr3_length": 12},
],
"cell2": [{"locus": "TRA", "v_call": "TRAV1", "cdr3_length": 13}],
"cell3": [],
}

r = ak.Record(dict_cell_chain_field)

r = ak.to_regular(r, 0)
r = ak.to_regular(r, 2)

list_cell_chain_field = [
[["TRA", "TRAV1", 15], ["TRB", "TRBV1", 12]],
[["TRA", "TRAV1", 13]],
[],
]

a = ak.Array(list_cell_chain_field)
a = ak.to_regular(a, 0)
a = ak.to_regular(a, 2)