Skip to content

Commit

Permalink
Add conversion of TH1F to PlottableHistogram
Browse files Browse the repository at this point in the history
  • Loading branch information
pieterdavid committed May 27, 2021
1 parent 8a49256 commit 7645233
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions src/uhi/numpy_plottable.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,123 @@ def variances(self) -> Optional[np.ndarray]:
_: PlottableHistogram = cast(NumPyPlottableHistogram, None)


def _roottarray_asnumpy(
tarr: Any, shape: Optional[Tuple[int, ...]] = None
) -> np.ndarray:
llv = tarr.GetArray()
arr: np.ndarray = np.frombuffer(llv, dtype=llv.typecode, count=tarr.GetSize())
if shape is not None:
return np.reshape(arr, shape, order="F")
else:
return arr


class ROOTAxis:
def __init__(self, tAxis: Any) -> None:
self.tAx = tAxis

def __len__(self) -> int:
return self.tAx.GetNbins() # type: ignore

def __getitem__(self, index: int) -> Any:
pass

def __eq__(self, other: Any) -> bool:
if not isinstance(other, ROOTAxis):
return NotImplemented
return len(self) == len(other) and all(
aEdges == bEdges for aEdges, bEdges in zip(self, other)
)

def __iter__(self) -> Union[Iterator[Tuple[float, float]], Iterator[str]]:
pass

@staticmethod
def create(tAx: Any) -> Union["DiscreteROOTAxis", "ContinuousROOTAxis"]:
if all(tAx.GetBinLabel(i + 1) for i in range(tAx.GetNbins())):
return DiscreteROOTAxis(tAx)
else:
return ContinuousROOTAxis(tAx)


class ContinuousROOTAxis(ROOTAxis):
@property
def traits(self) -> PlottableTraits:
return Traits(circular=False, discrete=False)

def __getitem__(self, index: int) -> Tuple[float, float]:
return (self.tAx.GetBinLowEdge(index + 1), self.tAx.GetBinUpEdge(index + 1))

def __iter__(self) -> Iterator[Tuple[float, float]]:
for i in range(len(self)):
yield self[i]


class DiscreteROOTAxis(ROOTAxis):
@property
def traits(self) -> PlottableTraits:
return Traits(circular=False, discrete=True)

def __getitem__(self, index: int) -> str:
return self.tAx.GetBinLabel(index + 1) # type: ignore

def __iter__(self) -> Iterator[str]:
for i in range(len(self)):
yield self[i]


class ROOTPlottableHistogram:
def __init__(self, thist: Any) -> None:
self.thist: Any = thist
nDim = thist.GetDimension()
self._shape: Tuple[int, ...] = tuple(
getattr(thist, f"GetNbins{ax}")() + 2 for ax in "XYZ"[:nDim]
)
self.axes: Tuple[Union[ContinuousROOTAxis, DiscreteROOTAxis], ...] = tuple(
ROOTAxis.create(getattr(thist, f"Get{ax}axis")()) for ax in "XYZ"[:nDim]
)

@property
def hasWeights(self) -> bool:
return bool(self.thist.GetSumw2() and self.thist.GetSumw2N())

@property
def kind(self) -> str:
return Kind.COUNT

def values(self) -> np.ndarray:
return _roottarray_asnumpy(self.thist, shape=self._shape)[ # type: ignore
tuple([slice(1, -1)] * len(self._shape))
]

def variances(self) -> np.ndarray:
if self.hasWeights:
return _roottarray_asnumpy(self.thist.GetSumw2(), shape=self._shape)[ # type: ignore
tuple([slice(1, -1)] * len(self._shape))
]
else:
return self.values()

def counts(self) -> np.ndarray:
if self.hasWeights:
sumw = self.values()
return np.divide( # type: ignore
sumw ** 2,
self.variances(),
out=np.zeros_like(sumw, dtype=np.float64),
where=sumw != 0,
)
else:
return self.values()


if TYPE_CHECKING:
# Verify that the above class is a valid PlottableHistogram
_axis = cast(ContinuousROOTAxis, None)
_axis2: PlottableAxisGeneric[str] = cast(DiscreteROOTAxis, None)
_ = cast(ROOTPlottableHistogram, None)


def ensure_plottable_histogram(hist: Any) -> PlottableHistogram:
"""
Ensure a histogram follows the PlottableHistogram Protocol.
Expand Down Expand Up @@ -206,5 +323,8 @@ def ensure_plottable_histogram(hist: Any) -> PlottableHistogram:
# Standard tuple
return NumPyPlottableHistogram(*(np.asarray(h) for h in hist))

elif hasattr(hist, "InheritsFrom") and hist.InheritsFrom("TH1"):
return ROOTPlottableHistogram(hist)

else:
raise TypeError(f"Can't be used on this type of object: {hist!r}")

0 comments on commit 7645233

Please sign in to comment.