Skip to content

Commit

Permalink
Add structures.root_graph_errors structure and ROOTGraphErrors elemen…
Browse files Browse the repository at this point in the history
…t (documented and tested). graph: parsed errors (private) now contain error indices.
  • Loading branch information
ynikitenko committed Apr 19, 2022
1 parent b2cb0c7 commit 71d169a
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 6 deletions.
13 changes: 13 additions & 0 deletions docs/source/structures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ Structures
graph
Graph

.. currentmodule:: lena.structures.root_graphs
.. autosummary::

root_graph_errors
ROOTGraphErrors

.. currentmodule:: lena.structures.elements
.. autosummary::

Expand Down Expand Up @@ -92,6 +98,13 @@ Graph
.. autoclass:: Graph
:members:

.. module:: lena.structures.root_graphs
.. autoclass:: root_graph_errors
:members:

.. autoclass:: ROOTGraphErrors
:members:

.. module:: lena.structures.elements
.. autoclass:: HistToGraph
:members:
Expand Down
3 changes: 3 additions & 0 deletions lena/structures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
unify_1_md
)
from .numpy_histogram import NumpyHistogram
from .root_graphs import root_graph_errors, ROOTGraphErrors
from .split_into_bins import (
IterateBins,
MapBins,
Expand All @@ -33,6 +34,8 @@
'Histogram',
'HistToGraph',
'NumpyHistogram',
'root_graph_errors',
'ROOTGraphErrors',
# hist functions
'check_edges_increasing',
'cell_to_string',
Expand Down
10 changes: 5 additions & 5 deletions lena/structures/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _parse_error_names(self, field_names):
for ind, field in enumerate(field_names):
if field.startswith("error_"):
in_error_fields = True
errors.append(field)
errors.append((field, ind))
else:
last_coord_ind = ind
if in_error_fields:
Expand All @@ -291,7 +291,7 @@ def _parse_error_names(self, field_names):
coords = set(field_names[:last_coord_ind+1])
parsed_errors = []

for err in errors:
for err, ind in errors:
err_coords = []
for coord in coords:
err_main = err[6:] # all after "error_"
Expand All @@ -308,7 +308,7 @@ def _parse_error_names(self, field_names):
" corresponding to several coordinates given"
)
# "error" may be redundant, but it is explicit.
parsed_errors.append(("error", err_coords[0], err_tail))
parsed_errors.append(("error", err_coords[0], err_tail, ind))

return parsed_errors

Expand Down Expand Up @@ -341,9 +341,9 @@ def _update_context(self, context):

xyz_coord_names = self._coord_names[:3]
for name, coord_name in zip(["x", "y", "z"], xyz_coord_names):
for ind, err in enumerate(self._parsed_error_names):
for err in self._parsed_error_names:
if err[1] == coord_name:
error_ind = dim + ind
error_ind = err[3]
if err[2]:
# add error suffix
error_name = name + "_" + err[2]
Expand Down
141 changes: 141 additions & 0 deletions lena/structures/root_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import array

import lena.context
import lena.flow


def _list_to_array(coords, type_code):
return array.array(type_code, (coord for coord in coords))


class root_graph_errors():
"""2-dimensional ROOT graph with errors.
This is an adapter for
`TGraphErrors <https://root.cern.ch/doc/master/classTGraphErrors.html>`_
and contains that graph as a field *root_graph*.
"""

def __init__(self, graph, type_code='d'):
"""*graph* is a Lena :class:`.graph`.
*type_code* is the basic numeric type of array values
(by default double). 'f' means floating values.
See Python module
`array <https://docs.python.org/3/library/array.html>`_
for more options.
"""

import ROOT

if graph.dim != 2:
raise lena.core.LenaValueError(
"graph dimension must be 2"
)

errors = graph._parsed_error_names
# this is not possible, because we forbid suffixes
# if len(errors) > 2:
# raise lena.core.LenaValueError(
# "graph contains too many error fields (maximum is 2)"
# )

x_coord = graph.field_names[0]
y_coord = graph.field_names[1]

x_error = ROOT.nullptr
y_error = ROOT.nullptr

error_x_ind = 0
error_y_ind = 0
for err in errors:
if err[2]:
# errors for unknown coordinates
# are forbidden in graph itself.
raise lena.core.LenaValueError(
"error suffixes are not allowed"
)
error_ind = err[3]
if err[1] == x_coord:
x_error = graph.coords[error_ind]
error_x_ind = error_ind
elif err[1] == y_coord:
y_error = graph.coords[error_ind]
error_y_ind = error_ind

self._error_x_ind = error_x_ind
self._error_y_ind = error_y_ind

n_points = len(graph.coords[0])

xs = _list_to_array(graph.coords[0], type_code)
ys = _list_to_array(graph.coords[1], type_code)
exs = ROOT.nullptr
eys = ROOT.nullptr
if x_error:
exs = _list_to_array(x_error, type_code)
if y_error:
eys = _list_to_array(y_error, type_code)

self.root_graph = ROOT.TGraphErrors(n_points, xs, ys, exs, eys)

def _arrays(self):
import ROOT
# not a class field, because it can't be pickled
rg = self.root_graph
arrays = [
# all these values are pointers,
# so they can't be pickled.
rg.GetX(),
rg.GetY(),
]
if self._error_x_ind:
arrays.append(rg.GetEX())
if self._error_y_ind:
arrays.append(rg.GetEY())
return arrays

def __eq__(self, other):
if not isinstance(other, root_graph_errors):
return False
# looks they can't be compared directly
# return self.root_graph == other.root_graph
# error indices are the same
if (self._error_x_ind != other._error_x_ind
or self._error_y_ind != other._error_y_ind):
return False
# pointwise comparison
return list(self) == list(other)

def __iter__(self):
npoints = self.root_graph.GetN()
for ind in range(npoints):
res = tuple((arr[ind] for arr in self._arrays()))
yield res

def __len__(self):
return self.root_graph.GetN()

def _update_context(self, context):
error_x_ind = self._error_x_ind
error_y_ind = self._error_y_ind
if error_x_ind:
lena.context.update_recursively(
context, "error.x.index", error_x_ind
)
if error_y_ind:
lena.context.update_recursively(
context, "error.y.index", error_y_ind
)


class ROOTGraphErrors():
"""Element to convert graphs to :class:`.root_graph_errors`."""

def __call__(self, value):
"""Convert data part of the value
(which must be a :class:`.graph`)
to :class:`.root_graph_errors`.
"""
graph, context = lena.flow.get_data_context(value)
return (root_graph_errors(graph), context)
2 changes: 1 addition & 1 deletion tests/structures/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_graph_error_fields():
# one error field works
gr1 = graph(copy.deepcopy([xs, ys, [1, 2]]),
field_names="x, y, error_x", scale=2)
assert gr1._parsed_error_names == [('error', 'x', '')]
assert gr1._parsed_error_names == [('error', 'x', '', 2)]

# wrong order of fields raises
with pytest.raises(lena.core.LenaValueError) as exc:
Expand Down
69 changes: 69 additions & 0 deletions tests/structures/test_root_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

import lena.core
from lena.structures import graph, root_graph_errors, ROOTGraphErrors


def test_root_graph_errors():
# no errors
gr0 = graph([[0, 1, 2], [2, 3, 4]])
rgr0 = root_graph_errors(gr0)
assert list(rgr0) == [(0, 2), (1, 3), (2, 4)]
assert len(rgr0) == 3
context0 = {}
rgr0._update_context(context0)
assert context0 == {}

# x error
gr1 = graph([[0, 1], [2, 3], [0.1, 0.1]], field_names="x,y,error_x")
rgr1 = root_graph_errors(gr1)
assert list(rgr1) == [(0, 2, 0.1), (1, 3, 0.1)]
assert len(rgr1) == 2
context1 = {}
rgr1._update_context(context1)
assert context1 == {"error": {"x": {"index": 2}}}

# x and y errors
gr2 = graph([[0, 1], [2, 3], [0.1, 0.1], [0.2, 0.2]],
field_names="x,y,error_x,error_y")
rgr2 = root_graph_errors(gr2)

assert list(rgr2) == [(0, 2, 0.1, 0.2), (1, 3, 0.1, 0.2)]
assert len(rgr2) == 2
context2 = {}
rgr2._update_context(context2)
assert context2 == {
"error": {
"x": {"index": 2},
"y": {"index": 3}
}
}

# test comparison
# different errors give different graphs
coords = [[0, 1], [2, 3], [0.1, 0.1]]
grx = graph(coords, field_names="x,y,error_x")
gry = graph(coords, field_names="x,y,error_y")
rgrx = root_graph_errors(grx)
rgry = root_graph_errors(gry)
# no, no idea why this is not covered in Python 2...
assert rgrx != rgry
# not root_graph_errors returns False
assert rgrx != coords

# error suffixes are not allowed
gre = graph(coords, field_names="x,y,error_x_low")
with pytest.raises(lena.core.LenaValueError):
root_graph_errors(gre)

# only 2-dimensional graphs are allowed
gr3d = graph(coords, field_names="x,y,z")
with pytest.raises(lena.core.LenaValueError):
root_graph_errors(gr3d)


def test_ROOTGraphErrors():
el = ROOTGraphErrors()
gr = graph([[0, 1], [1, 2]])
rgr = root_graph_errors(gr)
assert el(gr) == (rgr, {})

0 comments on commit 71d169a

Please sign in to comment.