Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ class RooJSONFactoryWSTool {
void importVariable(const RooFit::Detail::JSONNode &p);
void importDependants(const RooFit::Detail::JSONNode &n);

void exportVariable(const RooAbsArg *v, RooFit::Detail::JSONNode &p);
void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n);
void exportVariable(const RooAbsArg *v, RooFit::Detail::JSONNode &n, bool storeConstant, bool storeBins);
void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n, bool storeConstant, bool storeBins);

void exportAllObjects(RooFit::Detail::JSONNode &n);

Expand Down
12 changes: 11 additions & 1 deletion roofit/hs3/src/JSONFactories_HistFactory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,16 @@ void collectElements(RooArgSet &elems, RooAbsArg *arg)
}
}

bool allRooRealVar(const RooAbsCollection &list)
{
for (auto *var : list) {
if (!dynamic_cast<RooRealVar *>(var)) {
return false;
}
}
return true;
}

struct Sample {
std::string name;
std::vector<double> hist;
Expand Down Expand Up @@ -920,7 +930,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
addNormFactor(par, sample, ws);
} else if (auto hf = dynamic_cast<const RooHistFunc *>(e)) {
updateObservables(hf->dataHist());
} else if (auto phf = dynamic_cast<ParamHistFunc *>(e)) {
} else if (ParamHistFunc *phf = dynamic_cast<ParamHistFunc *>(e); phf && allRooRealVar(phf->paramList())) {
phfs.push_back(phf);
} else if (auto fip = dynamic_cast<RooStats::HistFactory::FlexibleInterpVar *>(e)) {
// some (modified) histfactory models have several instances of FlexibleInterpVar
Expand Down
112 changes: 112 additions & 0 deletions roofit/hs3/src/JSONFactories_RooFitCore.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <RooAbsCachedPdf.h>
#include <RooAddPdf.h>
#include <RooAddModel.h>
#include <RooBinning.h>
#include <RooBinSamplingPdf.h>
#include <RooBinWidthFunction.h>
#include <RooCategory.h>
Expand All @@ -33,6 +34,7 @@
#include <RooLegacyExpPoly.h>
#include <RooLognormal.h>
#include <RooMultiVarGaussian.h>
#include <RooStats/HistFactory/ParamHistFunc.h>
#include <RooPoisson.h>
#include <RooPolynomial.h>
#include <RooPolyVar.h>
Expand Down Expand Up @@ -532,6 +534,71 @@ class RooMultiVarGaussianFactory : public RooFit::JSONIO::Importer {
}
};

class ParamHistFuncFactory : public RooFit::JSONIO::Importer {
public:
bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
{
std::string name(RooJSONFactoryWSTool::name(p));
RooArgList varList = tool->requestArgList<RooRealVar>(p, "variables");
if (!p.has_child("axes")) {
std::stringstream ss;
ss << "No axes given in '" << name << "'"
<< ". Using default binning (uniform; nbins=100). If needed, export the Workspace to JSON with a newer "
<< "Root version that supports custom ParamHistFunc binnings(>=6.38.00)." << std::endl;
RooJSONFactoryWSTool::warning(ss.str());
tool->wsEmplace<ParamHistFunc>(name, varList, tool->requestArgList<RooAbsReal>(p, "parameters"));
return true;
}
tool->wsEmplace<ParamHistFunc>(name, readBinning(p, varList), tool->requestArgList<RooAbsReal>(p, "parameters"));
return true;
}

private:
RooArgList readBinning(const JSONNode &topNode, const RooArgList &varList) const
{
// Temporary map from variable name → RooRealVar
std::map<std::string, std::unique_ptr<RooRealVar>> varMap;

// Build variables from JSON
for (const JSONNode &node : topNode["axes"].children()) {
const std::string name = node["name"].val();
std::unique_ptr<RooRealVar> obs;

if (node.has_child("edges")) {
std::vector<double> edges;
for (const auto &bound : node["edges"].children()) {
edges.push_back(bound.val_double());
}
obs = std::make_unique<RooRealVar>(name.c_str(), name.c_str(), edges.front(), edges.back());
RooBinning bins(obs->getMin(), obs->getMax());
for (auto b : edges)
bins.addBoundary(b);
obs->setBinning(bins);
} else {
obs = std::make_unique<RooRealVar>(name.c_str(), name.c_str(), node["min"].val_double(),
node["max"].val_double());
obs->setBins(node["nbins"].val_int());
}

varMap[name] = std::move(obs);
}

// Now build the final list following the order in varList
RooArgList vars;
for (int i = 0; i < varList.getSize(); ++i) {
const auto *refVar = dynamic_cast<RooRealVar *>(varList.at(i));
if (!refVar)
continue;

auto it = varMap.find(refVar->GetName());
if (it != varMap.end()) {
vars.addOwned(std::move(it->second)); // preserve ownership
}
}
return vars;
}
};

///////////////////////////////////////////////////////////////////////////////////////////////////////
// specialized exporter implementations
///////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -696,6 +763,7 @@ class RooFormulaArgStreamer : public RooFit::JSONIO::Exporter {
expr.ReplaceAll("TMath::Sin", "sin");
expr.ReplaceAll("TMath::Sqrt", "sqrt");
expr.ReplaceAll("TMath::Power", "pow");
expr.ReplaceAll("TMath::Erf", "erf");
}
};
template <class RooArg_t>
Expand Down Expand Up @@ -952,6 +1020,47 @@ class RooExtendPdfStreamer : public RooFit::JSONIO::Exporter {
}
};

class ParamHistFuncStreamer : public RooFit::JSONIO::Exporter {
public:
std::string const &key() const override;
bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override
{
auto *pdf = static_cast<const ParamHistFunc *>(func);
elem["type"] << key();
RooJSONFactoryWSTool::fillSeq(elem["variables"], pdf->dataVars());
RooJSONFactoryWSTool::fillSeq(elem["parameters"], pdf->paramList());
writeBinningInfo(pdf, elem);
return true;
}

private:
void writeBinningInfo(const ParamHistFunc *pdf, JSONNode &elem) const
{
auto &observablesNode = elem["axes"].set_seq();
// axes have to be ordered to get consistent bin indices
for (auto *var : static_range_cast<RooRealVar *>(pdf->dataVars())) {
std::string name = var->GetName();
RooJSONFactoryWSTool::testValidName(name, false);
JSONNode &obsNode = observablesNode.append_child().set_map();
obsNode["name"] << name;
if (var->getBinning().isUniform()) {
obsNode["min"] << var->getMin();
obsNode["max"] << var->getMax();
obsNode["nbins"] << var->getBins();
} else {
auto &edges = obsNode["edges"];
edges.set_seq();
double val = var->getBinning().binLow(0);
edges.append_child() << val;
for (int i = 0; i < var->getBinning().numBins(); ++i) {
val = var->getBinning().binHigh(i);
edges.append_child() << val;
}
}
}
}
};

#define DEFINE_EXPORTER_KEY(class_name, name) \
std::string const &class_name::key() const \
{ \
Expand Down Expand Up @@ -989,6 +1098,7 @@ DEFINE_EXPORTER_KEY(RooRealIntegralStreamer, "integral");
DEFINE_EXPORTER_KEY(RooDerivativeStreamer, "derivative");
DEFINE_EXPORTER_KEY(RooFFTConvPdfStreamer, "fft_conv_pdf");
DEFINE_EXPORTER_KEY(RooExtendPdfStreamer, "extend_pdf");
DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "step");

///////////////////////////////////////////////////////////////////////////////////////////////////////
// instantiate all importers and exporters
Expand Down Expand Up @@ -1021,6 +1131,7 @@ STATIC_EXECUTE([]() {
registerImporter<RooDerivativeFactory>("derivative", false);
registerImporter<RooFFTConvPdfFactory>("fft_conv_pdf", false);
registerImporter<RooExtendPdfFactory>("extend_pdf", false);
registerImporter<ParamHistFuncFactory>("step", false);

registerExporter<RooAddPdfStreamer<RooAddPdf>>(RooAddPdf::Class(), false);
registerExporter<RooAddPdfStreamer<RooAddModel>>(RooAddModel::Class(), false);
Expand All @@ -1047,6 +1158,7 @@ STATIC_EXECUTE([]() {
registerExporter<RooDerivativeStreamer>(RooDerivative::Class(), false);
registerExporter<RooFFTConvPdfStreamer>(RooFFTConvPdf::Class(), false);
registerExporter<RooExtendPdfStreamer>(RooExtendPdf::Class(), false);
registerExporter<ParamHistFuncStreamer>(ParamHistFunc::Class(), false);
});

} // namespace
14 changes: 7 additions & 7 deletions roofit/hs3/src/RooFitHS3_wsexportkeys.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ auto RooFitHS3_wsexportkeys = R"({
"sigmaR": "sigma_R"
}
},
"RooEffProd": {
"type": "efficiency_product_pdf_dist",
"proxies": {
"pdf": "pdf",
"eff": "eff"
}
},
"RooGamma": {
"type": "gamma_dist",
"proxies": {
Expand All @@ -79,13 +86,6 @@ auto RooFitHS3_wsexportkeys = R"({
"sigma": "sigma"
}
},
"ParamHistFunc": {
"type": "step",
"proxies": {
"dataVars": "variables",
"paramSet": "parameters"
}
},
"RooLandau": {
"type": "landau_dist",
"proxies": {
Expand Down
14 changes: 7 additions & 7 deletions roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ auto RooFitHS3_wsfactoryexpressions = R"({
"coefficients"
]
},
"efficiency_product_pdf_dist": {
"class": "RooEffProd",
"arguments": [
"pdf",
"eff"
]
},
"gamma_dist": {
"class": "RooGamma",
"arguments": [
Expand Down Expand Up @@ -112,13 +119,6 @@ auto RooFitHS3_wsfactoryexpressions = R"({
"observables"
]
},
"step": {
"class": "ParamHistFunc",
"arguments": [
"variables",
"parameters"
]
},
"sum": {
"class": "RooAddition",
"arguments": [
Expand Down
30 changes: 16 additions & 14 deletions roofit/hs3/src/RooJSONFactoryWSTool.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ RooAbsReal *RooJSONFactoryWSTool::requestImpl<RooAbsReal>(const std::string &obj
* @param node The JSONNode to which the variable will be exported.
* @return void
*/
void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node)
void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node, bool storeConstant, bool storeBins)
{
auto *cv = dynamic_cast<const RooConstVar *>(v);
auto *rrv = dynamic_cast<const RooRealVar *>(v);
Expand All @@ -984,10 +984,10 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node)
var["const"] << true;
} else if (rrv) {
var["value"] << rrv->getVal();
if (rrv->isConstant()) {
if (rrv->isConstant() && storeConstant) {
var["const"] << rrv->isConstant();
}
if (rrv->getBins() != 100) {
if (rrv->getBins() != 100 && storeBins) {
var["nbins"] << rrv->getBins();
}
_domains->readVariable(*rrv);
Expand All @@ -1004,12 +1004,12 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node)
* @param n The JSONNode to which the variables will be exported.
* @return void
*/
void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n)
void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n, bool storeConstant, bool storeBins)
{
// export a list of RooRealVar objects
n.set_seq();
for (RooAbsArg *arg : allElems) {
exportVariable(arg, n);
exportVariable(arg, n, storeConstant, storeBins);
}
}

Expand Down Expand Up @@ -1070,7 +1070,7 @@ void RooJSONFactoryWSTool::exportObject(RooAbsArg const &func, std::set<std::str
// categories are created by the respective RooSimultaneous, so we're skipping the export here
return;
} else if (dynamic_cast<RooRealVar const *>(&func) || dynamic_cast<RooConstVar const *>(&func)) {
exportVariable(&func, *_varsNode);
exportVariable(&func, *_varsNode, true, false);
return;
}

Expand Down Expand Up @@ -1554,18 +1554,14 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data)

// this really is an unbinned dataset
output["type"] << "unbinned";
exportVariables(variables, output["axes"]);
exportVariables(variables, output["axes"], false, true);
auto &coords = output["entries"].set_seq();
std::vector<double> weightVals;
bool hasNonUnityWeights = false;
for (int i = 0; i < data.numEntries(); ++i) {
data.get(i);
coords.append_child().fill_seq(variables, [](auto x) { return static_cast<RooRealVar *>(x)->getVal(); });
std::string datasetName = data.GetName();
/*if (datasetName.find("combData_ZvvH126.5") != std::string::npos) {
file << dynamic_cast<RooAbsReal *>(data.get(i)->find("atlas_invMass_PttEtaConvVBFCat1"))->getVal() <<
std::endl;
}*/
if (data.isWeighted()) {
weightVals.push_back(data.weight());
if (data.weight() != 1.)
Expand All @@ -1575,7 +1571,6 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data)
if (data.isWeighted() && hasNonUnityWeights) {
output["weights"].fill_seq(weightVals);
}
// file.close();
}

/**
Expand Down Expand Up @@ -1960,7 +1955,8 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n)
snapshotSorted.sort();
std::string name(snsh->GetName());
if (name != "default_values") {
this->exportVariables(snapshotSorted, appendNamedChild(n["parameter_points"], name)["parameters"]);
this->exportVariables(snapshotSorted, appendNamedChild(n["parameter_points"], name)["parameters"], true,
false);
}
}
_varsNode = nullptr;
Expand Down Expand Up @@ -2240,8 +2236,14 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)
combineDatasets(*_rootnodeInput, datasets);

for (auto const &d : datasets) {
if (d)
if (d) {
_workspace.import(*d);
for (auto const &obs : *d->get()) {
if (auto *rrv = dynamic_cast<RooRealVar *>(obs)) {
_workspace.var(rrv->GetName())->setBinning(rrv->getBinning());
}
}
}
}

_rootnodeInput = nullptr;
Expand Down
6 changes: 6 additions & 0 deletions roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ inline RooFit::Detail::JSONNode &operator<<(RooFit::Detail::JSONNode &n, std::sp
return n;
}

inline RooFit::Detail::JSONNode &operator<<(RooFit::Detail::JSONNode &n, std::span<const int> v)
{
n.fill_seq(v);
return n;
}

template <class Key, class T, class Hash, class KeyEqual, class Allocator>
RooFit::Detail::JSONNode &
operator<<(RooFit::Detail::JSONNode &n, const std::unordered_map<Key, T, Hash, KeyEqual, Allocator> &m)
Expand Down
Loading