Skip to content

Commit

Permalink
attempt to fix segfault (#36)
Browse files Browse the repository at this point in the history
GenCrossSection has an unsafe API which requires one to call certain functions in the right order. There is also no protection against accessing vectors out-of-bounds. The patches applied here fix the most pressing issues in the Python layer.
  • Loading branch information
HDembinski committed Sep 1, 2022
1 parent fcb8da5 commit 3483d62
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 47 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
name: Coverage

on:
pull_request:
paths-ignore:
- 'docs/**'
- '*.rst'
- '*.md'
push:
branch:
- 'main'
branch: [main]
paths-ignore:
- 'docs/**'
- '*.rst'
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ dist
*.so
*.egg-info
tests/__pycache__
.eggs
.cache
.compiler_support_cache
.flag_filter_cache
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
requires = [
"setuptools>=46.4",
"setuptools_scm[toml]>=6.2",
"cmake>=3.13"
"cmake>=3.13",
"wheel"
]
build-backend = "setuptools.build_meta"

Expand Down
2 changes: 1 addition & 1 deletion src/attribute_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ AttributePtr attribute_from_python(py::object obj) {
if (!result) {
using Types = mp_list<mp_list<BoolAttribute, py::bool_, bool>,
mp_list<IntAttribute, py::int_, int>,
mp_list<FloatAttribute, py::float_, double>,
mp_list<DoubleAttribute, py::float_, double>,
mp_list<StringAttribute, py::str, std::string>>;

mp_for_each<Types>([&](auto t) {
Expand Down
48 changes: 14 additions & 34 deletions src/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,25 +260,8 @@ inline std::ostream& repr(std::ostream& os, const HepMC3::GenEvent& x) {
return os;
}

inline int gencrosssection_validate_index(GenCrossSection& cs, py::object obj) {
auto idx = py::cast<int>(obj);
const auto size =
cs.event() ? (std::max)(cs.event()->weights().size(), static_cast<std::size_t>(1))
: 1u;
if (idx < 0) idx += size;
if (idx < 0 || idx >= size) throw py::index_error();
return idx;
}

inline std::string gencrosssection_validate_name(GenCrossSection& cs, py::object obj) {
auto name = py::cast<std::string>(obj);
if (cs.event() && cs.event()->run_info()) {
const auto& wnames = cs.event()->run_info()->weight_names();
if (std::find(wnames.begin(), wnames.end(), name) != wnames.end()) return name;
}
throw py::key_error(name);
return {};
}
int crosssection_safe_index(GenCrossSection& cs, py::object obj);
std::string crosssection_safe_name(GenCrossSection& cs, py::object obj);

void from_hepevt(GenEvent& event, int event_number, py::array_t<double> px,
py::array_t<double> py, py::array_t<double> pz, py::array_t<double> en,
Expand Down Expand Up @@ -502,22 +485,16 @@ PYBIND11_MODULE(_core, m) {

py::class_<GenCrossSection, GenCrossSectionPtr, Attribute>(m, "GenCrossSection",
DOC(GenCrossSection))
.def(py::init([](double cs, double cse, long acc, long att) {
auto p = std::make_shared<GenCrossSection>();
p->set_cross_section(cs, cse, acc, att);
return p;
}),
"cross_section"_a, "cross_section_error"_a, "accepted_events"_a,
"attempted_events"_a)
.def(py::init<>())
.def(
"xsec",
[](GenCrossSection& self, py::object obj) {
// need to do checks for invalid indices because they are missing in C++
if (py::isinstance<py::int_>(obj)) {
auto idx = gencrosssection_validate_index(self, obj);
auto idx = crosssection_safe_index(self, obj);
return self.xsec(idx);
} else if (py::isinstance<py::str>(obj)) {
auto name = gencrosssection_validate_name(self, obj);
auto name = crosssection_safe_name(self, obj);
return self.xsec(name);
} else
throw py::type_error("int or str required");
Expand All @@ -528,10 +505,10 @@ PYBIND11_MODULE(_core, m) {
"xsec_err",
[](GenCrossSection& self, py::object obj) {
if (py::isinstance<py::int_>(obj)) {
auto idx = gencrosssection_validate_index(self, obj);
auto idx = crosssection_safe_index(self, obj);
return self.xsec_err(idx);
} else if (py::isinstance<py::str>(obj)) {
auto name = gencrosssection_validate_name(self, obj);
auto name = crosssection_safe_name(self, obj);
return self.xsec_err(name);
} else
throw py::type_error("int or str required");
Expand All @@ -542,10 +519,10 @@ PYBIND11_MODULE(_core, m) {
"set_xsec",
[](GenCrossSection& self, py::object obj, double value) {
if (py::isinstance<py::int_>(obj)) {
auto idx = gencrosssection_validate_index(self, obj);
auto idx = crosssection_safe_index(self, obj);
self.set_xsec(idx, value);
} else if (py::isinstance<py::str>(obj)) {
auto name = gencrosssection_validate_name(self, obj);
auto name = crosssection_safe_name(self, obj);
self.set_xsec(name, value);
} else
throw py::type_error("int or str required");
Expand All @@ -555,15 +532,18 @@ PYBIND11_MODULE(_core, m) {
"set_xsec_err",
[](GenCrossSection& self, py::object obj, double value) {
if (py::isinstance<py::int_>(obj)) {
auto idx = gencrosssection_validate_index(self, obj);
auto idx = crosssection_safe_index(self, obj);
self.set_xsec_err(idx, value);
} else if (py::isinstance<py::str>(obj)) {
auto name = gencrosssection_validate_name(self, obj);
auto name = crosssection_safe_name(self, obj);
self.set_xsec_err(name, value);
} else
throw py::type_error("int or str required");
},
"index_or_name"_a, "value"_a, DOC(GenCrossSection.set_xsec_err))
.def("set_cross_section", &GenCrossSection::set_cross_section, "cross_section"_a,
"cross_section_error"_a, "accepted_events"_a = -1, "attempted_events"_a = -1,
DOC(GenCrossSection.set_cross_section))
// clang-format off
PROP2(accepted_events, GenCrossSection)
PROP2(attempted_events, GenCrossSection)
Expand Down
49 changes: 49 additions & 0 deletions src/crosssection_patches.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include "pybind.hpp"
#include <HepMC3/GenCrossSection.h>
#include <HepMC3/GenEvent.h>
#include <accessor/accessor.hpp>
#include <cassert>
#include <vector>

// To resize cross-section vectors, we use the legal crowbar
// to access the private attribute map of GenCrossSection
MEMBER_ACCESSOR(CS1, HepMC3::GenCrossSection, cross_sections, std::vector<double>)
MEMBER_ACCESSOR(CS2, HepMC3::GenCrossSection, cross_section_errors, std::vector<double>)

namespace HepMC3 {

void crosssection_maybe_increase_vector(GenCrossSection& cs, unsigned size) {
auto cs1 = accessor::accessMember<CS1>(cs);
auto cs2 = accessor::accessMember<CS2>(cs);
assert(cs1.get().size() == cs2.get().size());
if (size > cs1.get().size()) {
cs1.get().resize(size, 0);
cs2.get().resize(size, 0);
}
}

int crosssection_safe_index(GenCrossSection& cs, py::object obj) {
auto idx = py::cast<int>(obj);
const auto size =
cs.event() ? (std::max)(cs.event()->weights().size(), static_cast<std::size_t>(1))
: 1u;
if (idx < 0) idx += size;
if (idx < 0 || static_cast<unsigned>(idx) >= size) throw py::index_error();
crosssection_maybe_increase_vector(cs, size);
return idx;
}

std::string crosssection_safe_name(GenCrossSection& cs, py::object obj) {
auto name = py::cast<std::string>(obj);
if (cs.event() && cs.event()->run_info()) {
const auto size =
(std::max)(cs.event()->weights().size(), static_cast<std::size_t>(1));
crosssection_maybe_increase_vector(cs, size);
const auto& wnames = cs.event()->run_info()->weight_names();
if (std::find(wnames.begin(), wnames.end(), name) != wnames.end()) return name;
}
throw py::key_error(name);
return {};
}

} // namespace HepMC3
12 changes: 8 additions & 4 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def test_GenHeavyIon():


def test_GenCrossSection():
cs = hep.GenCrossSection(1.2, 0.2, 3, 10)
cs = hep.GenCrossSection()
cs.set_cross_section(1.2, 0.2, 3, 10)
assert cs.event is None
cs.set_xsec(0, 1.2)
with pytest.raises(KeyError):
Expand All @@ -119,9 +120,10 @@ def test_GenCrossSection():
with pytest.raises(IndexError):
assert cs.xsec(1)

ri = hep.GenRunInfo()
ri.weight_names = ("foo", "bar") # optional
evt = hep.GenEvent()
evt.run_info = hep.GenRunInfo()
evt.run_info.weight_names = ("foo", "bar") # optional
evt.run_info = ri
evt.weights = [1.0, 2.0]
evt.cross_section = cs
assert evt.cross_section.event is evt
Expand Down Expand Up @@ -219,14 +221,16 @@ def test_attributes_2(evt):
[1, 2],
["foo", "bar"],
[True, False],
hep.GenCrossSection(1.2, 0.2, 3, 10),
hep.GenCrossSection(),
hep.GenHeavyIon(),
hep.GenPdfInfo(),
hep.HEPRUPAttribute(),
hep.HEPEUPAttribute(),
],
)
def test_attributes_3(evt, value):
if isinstance(value, hep.GenCrossSection):
value.set_cross_section(1.2, 0.2, 3, 10)
p1 = evt.particles[0]
assert p1.id == 1
assert p1.attributes == {}
Expand Down

0 comments on commit 3483d62

Please sign in to comment.