-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
MinuitFcnGrad.h
115 lines (88 loc) · 4.08 KB
/
MinuitFcnGrad.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*
* Project: RooFit
* Authors:
* PB, Patrick Bos, Netherlands eScience Center, p.bos@esciencecenter.nl
*
* Copyright (c) 2021, CERN
*
* Redistribution and use in source and binary forms,
* with or without modification, are permitted according to the terms
* listed in LICENSE (http://roofit.sourceforge.net/license.txt)
*/
#ifndef ROOT_ROOFIT_TESTSTATISTICS_MinuitFcnGrad
#define ROOT_ROOFIT_TESTSTATISTICS_MinuitFcnGrad
#include "RooArgList.h"
#include "RooRealVar.h"
#include <RooFit/TestStatistics/RooAbsL.h>
#include <RooFit/TestStatistics/LikelihoodWrapper.h>
#include <RooFit/TestStatistics/LikelihoodGradientWrapper.h>
#include "RooAbsMinimizerFcn.h"
#include <Fit/ParameterSettings.h>
#include <Fit/Fitter.h>
#include "Math/IFunction.h" // ROOT::Math::IMultiGradFunction
// forward declarations
class RooAbsReal;
class RooMinimizer;
namespace RooFit {
namespace TestStatistics {
/// Bookkeeping shared between MinuitFcnGrad and the wrappers it communicates with: one instance of this struct must
/// be shared among them. It records which quantities have already been evaluated for the parameter set that Minuit
/// most recently provided.
struct WrapperCalculationCleanFlags {
   /// Whether the likelihood has been calculated since the last parameter update.
   bool likelihood{false};
   /// Whether the gradient has been calculated since the last parameter update.
   bool gradient{false};

   /// Put both flags into the same state in one call.
   /// \param value State assigned to `likelihood` and `gradient` alike.
   void set_all(bool value) { likelihood = gradient = value; }
};
/// Function object handed to Minuit that implements both the ROOT::Math::IMultiGradFunction interface and the
/// RooAbsMinimizerFcn interface. It holds a LikelihoodWrapper and a LikelihoodGradientWrapper and forwards several
/// of its queries to them.
class MinuitFcnGrad : public ROOT::Math::IMultiGradFunction, public RooAbsMinimizerFcn {
public:
   /// Available likelihood calculation modes.
   /// NOTE(review): presumably selects which LikelihoodWrapper implementation the constructor creates -- confirm
   /// against the constructor definition, which is not part of this header.
   enum class LikelihoodMode { serial, multiprocess };
   /// Available gradient calculation modes (currently a single option).
   /// NOTE(review): presumably selects the LikelihoodGradientWrapper implementation -- confirm in the constructor.
   enum class LikelihoodGradientMode { multiprocess };

   /// Construct from a likelihood, the owning minimizer context and the Minuit parameter settings.
   /// \param _likelihood Likelihood to be minimized.
   /// \param context Minimizer this function belongs to.
   /// \param parameters Parameter settings, taken by non-const reference.
   /// \param likelihoodMode Requested likelihood calculation mode.
   /// \param likelihoodGradientMode Requested gradient calculation mode.
   /// \param verbose Verbosity switch; off by default.
   MinuitFcnGrad(const std::shared_ptr<RooFit::TestStatistics::RooAbsL> &_likelihood, RooMinimizer *context,
                 std::vector<ROOT::Fit::ParameterSettings> &parameters,
                 LikelihoodMode likelihoodMode,
                 LikelihoodGradientMode likelihoodGradientMode, bool verbose = false);

   /// Return a newly allocated copy of this object, made with the copy constructor. The returned raw pointer is not
   /// managed by this class.
   inline ROOT::Math::IMultiGradFunction *Clone() const override { return new MinuitFcnGrad(*this); }

   /// Overridden from RooAbsMinimizerFcn to include gradient strategy synchronization.
   Bool_t Synchronize(std::vector<ROOT::Fit::ParameterSettings> &parameter_settings, Bool_t optConst,
                      Bool_t verbose = kFALSE) override;

   // used inside Minuit:
   /// Forwards the question to the gradient wrapper's usesMinuitInternalValues().
   inline bool returnsInMinuit2ParameterSpace() const override { return gradient->usesMinuitInternalValues(); }

   /// Forwards the constant-term optimization request to the likelihood wrapper.
   inline void setOptimizeConstOnFunction(RooAbsArg::ConstOpCode opcode, Bool_t doAlsoTrackingOpt) override
   {
      likelihood->constOptimizeTestStatistic(opcode, doAlsoTrackingOpt);
   }

private:
   /// IMultiGradFunction override necessary for Minuit
   double DoEval(const double *x) const override;

public:
   /// IMultiGradFunction overrides necessary for Minuit
   void Gradient(const double *x, double *grad) const override;
   /// Gradient evaluation variant that additionally receives buffers holding results of a previous evaluation.
   /// NOTE(review): whether previous_grad, previous_g2 and previous_gstep are only read or also updated is not
   /// visible in this header -- confirm in the implementation.
   void GradientWithPrevResult(const double *x, double *grad, double *previous_grad, double *previous_g2,
                               double *previous_gstep) const override;

   /// Part of IMultiGradFunction interface, used widely both in Minuit and in RooFit.
   inline unsigned int NDim() const override { return _nDim; }

   /// Name of the function, taken from the likelihood wrapper.
   inline std::string getFunctionName() const override { return likelihood->GetName(); }
   /// Title of the function, taken from the likelihood wrapper.
   inline std::string getFunctionTitle() const override { return likelihood->GetTitle(); }
   /// Forwards the offsetting switch to the likelihood wrapper.
   inline void setOffsetting(Bool_t flag) override { likelihood->enableOffsetting(flag); }

private:
   /// This override should not be used in this class, so it throws.
   double DoDerivative(const double *x, unsigned int icoord) const override;

   /// NOTE(review): judging by the name, copies the parameter values received from Minuit into the RooFit side and
   /// reports whether anything changed; minuit_internal presumably tells whether x is given in Minuit-internal
   /// coordinates -- confirm in the implementation.
   bool syncParameterValuesFromMinuitCalls(const double *x, bool minuit_internal) const;

   // members
   std::shared_ptr<LikelihoodWrapper> likelihood;
   std::shared_ptr<LikelihoodGradientWrapper> gradient;

public:
   /// Calculation state flags (see WrapperCalculationCleanFlags). Declared mutable, so it can be modified from
   /// const member functions.
   mutable std::shared_ptr<WrapperCalculationCleanFlags> calculation_is_clean;

private:
   // NOTE(review): presumably the most recently seen parameter values in Minuit-internal and in external
   // coordinates respectively -- confirm in the implementation.
   mutable std::vector<double> minuit_internal_x_;
   mutable std::vector<double> minuit_external_x_;

public:
   /// NOTE(review): presumably flags that the Minuit-internal values and the RooFit parameter values are out of
   /// step -- confirm in the implementation. Starts out false.
   mutable bool minuit_internal_roofit_x_mismatch_ = false;
};
} // namespace TestStatistics
} // namespace RooFit
#endif // ROOT_ROOFIT_TESTSTATISTICS_MinuitFcnGrad