-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add structures.root_graph_errors structure and ROOTGraphErrors elemen…
…t (documented and tested). graph: parsed errors (private) now contain error indices.
- Loading branch information
1 parent
b2cb0c7
commit 71d169a
Showing
6 changed files
with
232 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import array | ||
|
||
import lena.context | ||
import lena.flow | ||
|
||
|
||
def _list_to_array(coords, type_code): | ||
return array.array(type_code, (coord for coord in coords)) | ||
|
||
|
||
class root_graph_errors(): | ||
"""2-dimensional ROOT graph with errors. | ||
This is an adapter for | ||
`TGraphErrors <https://root.cern.ch/doc/master/classTGraphErrors.html>`_ | ||
and contains that graph as a field *root_graph*. | ||
""" | ||
|
||
def __init__(self, graph, type_code='d'): | ||
"""*graph* is a Lena :class:`.graph`. | ||
*type_code* is the basic numeric type of array values | ||
(by default double). 'f' means floating values. | ||
See Python module | ||
`array <https://docs.python.org/3/library/array.html>`_ | ||
for more options. | ||
""" | ||
|
||
import ROOT | ||
|
||
if graph.dim != 2: | ||
raise lena.core.LenaValueError( | ||
"graph dimension must be 2" | ||
) | ||
|
||
errors = graph._parsed_error_names | ||
# this is not possible, because we forbid suffixes | ||
# if len(errors) > 2: | ||
# raise lena.core.LenaValueError( | ||
# "graph contains too many error fields (maximum is 2)" | ||
# ) | ||
|
||
x_coord = graph.field_names[0] | ||
y_coord = graph.field_names[1] | ||
|
||
x_error = ROOT.nullptr | ||
y_error = ROOT.nullptr | ||
|
||
error_x_ind = 0 | ||
error_y_ind = 0 | ||
for err in errors: | ||
if err[2]: | ||
# errors for unknown coordinates | ||
# are forbidden in graph itself. | ||
raise lena.core.LenaValueError( | ||
"error suffixes are not allowed" | ||
) | ||
error_ind = err[3] | ||
if err[1] == x_coord: | ||
x_error = graph.coords[error_ind] | ||
error_x_ind = error_ind | ||
elif err[1] == y_coord: | ||
y_error = graph.coords[error_ind] | ||
error_y_ind = error_ind | ||
|
||
self._error_x_ind = error_x_ind | ||
self._error_y_ind = error_y_ind | ||
|
||
n_points = len(graph.coords[0]) | ||
|
||
xs = _list_to_array(graph.coords[0], type_code) | ||
ys = _list_to_array(graph.coords[1], type_code) | ||
exs = ROOT.nullptr | ||
eys = ROOT.nullptr | ||
if x_error: | ||
exs = _list_to_array(x_error, type_code) | ||
if y_error: | ||
eys = _list_to_array(y_error, type_code) | ||
|
||
self.root_graph = ROOT.TGraphErrors(n_points, xs, ys, exs, eys) | ||
|
||
def _arrays(self): | ||
import ROOT | ||
# not a class field, because it can't be pickled | ||
rg = self.root_graph | ||
arrays = [ | ||
# all these values are pointers, | ||
# so they can't be pickled. | ||
rg.GetX(), | ||
rg.GetY(), | ||
] | ||
if self._error_x_ind: | ||
arrays.append(rg.GetEX()) | ||
if self._error_y_ind: | ||
arrays.append(rg.GetEY()) | ||
return arrays | ||
|
||
def __eq__(self, other): | ||
if not isinstance(other, root_graph_errors): | ||
return False | ||
# looks they can't be compared directly | ||
# return self.root_graph == other.root_graph | ||
# error indices are the same | ||
if (self._error_x_ind != other._error_x_ind | ||
or self._error_y_ind != other._error_y_ind): | ||
return False | ||
# pointwise comparison | ||
return list(self) == list(other) | ||
|
||
def __iter__(self): | ||
npoints = self.root_graph.GetN() | ||
for ind in range(npoints): | ||
res = tuple((arr[ind] for arr in self._arrays())) | ||
yield res | ||
|
||
def __len__(self): | ||
return self.root_graph.GetN() | ||
|
||
def _update_context(self, context): | ||
error_x_ind = self._error_x_ind | ||
error_y_ind = self._error_y_ind | ||
if error_x_ind: | ||
lena.context.update_recursively( | ||
context, "error.x.index", error_x_ind | ||
) | ||
if error_y_ind: | ||
lena.context.update_recursively( | ||
context, "error.y.index", error_y_ind | ||
) | ||
|
||
|
||
class ROOTGraphErrors(): | ||
"""Element to convert graphs to :class:`.root_graph_errors`.""" | ||
|
||
def __call__(self, value): | ||
"""Convert data part of the value | ||
(which must be a :class:`.graph`) | ||
to :class:`.root_graph_errors`. | ||
""" | ||
graph, context = lena.flow.get_data_context(value) | ||
return (root_graph_errors(graph), context) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import pytest | ||
|
||
import lena.core | ||
from lena.structures import graph, root_graph_errors, ROOTGraphErrors | ||
|
||
|
||
def test_root_graph_errors(): | ||
# no errors | ||
gr0 = graph([[0, 1, 2], [2, 3, 4]]) | ||
rgr0 = root_graph_errors(gr0) | ||
assert list(rgr0) == [(0, 2), (1, 3), (2, 4)] | ||
assert len(rgr0) == 3 | ||
context0 = {} | ||
rgr0._update_context(context0) | ||
assert context0 == {} | ||
|
||
# x error | ||
gr1 = graph([[0, 1], [2, 3], [0.1, 0.1]], field_names="x,y,error_x") | ||
rgr1 = root_graph_errors(gr1) | ||
assert list(rgr1) == [(0, 2, 0.1), (1, 3, 0.1)] | ||
assert len(rgr1) == 2 | ||
context1 = {} | ||
rgr1._update_context(context1) | ||
assert context1 == {"error": {"x": {"index": 2}}} | ||
|
||
# x and y errors | ||
gr2 = graph([[0, 1], [2, 3], [0.1, 0.1], [0.2, 0.2]], | ||
field_names="x,y,error_x,error_y") | ||
rgr2 = root_graph_errors(gr2) | ||
|
||
assert list(rgr2) == [(0, 2, 0.1, 0.2), (1, 3, 0.1, 0.2)] | ||
assert len(rgr2) == 2 | ||
context2 = {} | ||
rgr2._update_context(context2) | ||
assert context2 == { | ||
"error": { | ||
"x": {"index": 2}, | ||
"y": {"index": 3} | ||
} | ||
} | ||
|
||
# test comparison | ||
# different errors give different graphs | ||
coords = [[0, 1], [2, 3], [0.1, 0.1]] | ||
grx = graph(coords, field_names="x,y,error_x") | ||
gry = graph(coords, field_names="x,y,error_y") | ||
rgrx = root_graph_errors(grx) | ||
rgry = root_graph_errors(gry) | ||
# no, no idea why this is not covered in Python 2... | ||
assert rgrx != rgry | ||
# not root_graph_errors returns False | ||
assert rgrx != coords | ||
|
||
# error suffixes are not allowed | ||
gre = graph(coords, field_names="x,y,error_x_low") | ||
with pytest.raises(lena.core.LenaValueError): | ||
root_graph_errors(gre) | ||
|
||
# only 2-dimensional graphs are allowed | ||
gr3d = graph(coords, field_names="x,y,z") | ||
with pytest.raises(lena.core.LenaValueError): | ||
root_graph_errors(gr3d) | ||
|
||
|
||
def test_ROOTGraphErrors(): | ||
el = ROOTGraphErrors() | ||
gr = graph([[0, 1], [1, 2]]) | ||
rgr = root_graph_errors(gr) | ||
assert el(gr) == (rgr, {}) |