/
data_comparison.go
68 lines (63 loc) · 1.97 KB
/
data_comparison.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
package inference
import (
"github.com/umbralcalc/stochadex/pkg/simulator"
"gonum.org/v1/gonum/mat"
)
// LikelihoodDistribution is the interface that must be implemented in
// order to create a likelihood that connects derived statistics from the
// probabilistic reweighting to observed actual data values.
//
// In this file it is consumed by DataComparisonIteration, which forwards its
// own Configure call to the implementation and calls EvaluateLogLike from
// Iterate.
type LikelihoodDistribution interface {
// Configure receives the partition index and the simulator settings.
// DataComparisonIteration.Configure calls it with exactly the arguments it
// was given itself.
Configure(partitionIndex int, settings *simulator.Settings)
// EvaluateLogLike returns a single float64 for the supplied mean vector,
// covariance and data values (by its name, the log-likelihood of data).
// The covariance argument may not hold a usable matrix:
// DataComparisonIteration.Iterate only builds one when a
// "covariance_matrix" parameter is present.
EvaluateLogLike(mean *mat.VecDense, covariance mat.Symmetric, data []float64) float64
// GenerateNewSamples returns a slice of values for the supplied mean vector
// and covariance (by its name, newly drawn samples). It is not called
// anywhere in this file.
GenerateNewSamples(mean *mat.VecDense, covariance mat.Symmetric) []float64
}
// DataComparisonIteration allows for any data linking log-likelihood to be used
// as a comparison distribution between data values, a mean vector and covariance
// matrix.
//
// Likelihood must be set before Configure or Iterate is called, because both
// methods invoke it without a nil check.
type DataComparisonIteration struct {
// Likelihood is the distribution that Configure forwards to and that
// Iterate evaluates via EvaluateLogLike.
Likelihood LikelihoodDistribution
// burnInSteps is read from the "burn_in_steps" int parameter in Configure.
// While the current step number is below it, Iterate skips the likelihood
// evaluation and returns the value at (0, 0) of this partition's state
// history instead.
burnInSteps int
// cumulative is set in Configure when the first "cumulative" int parameter
// equals 1. When true, Iterate adds the value at (0, 0) of this partition's
// state history to the newly evaluated log-likelihood.
cumulative bool
}
// Configure reads this partition's integer parameters from settings and then
// configures the wrapped Likelihood with the same arguments.
//
// Parameters read from settings.OtherParams[partitionIndex].IntParams:
//   - "cumulative" (optional): cumulative mode is enabled only when the first
//     entry equals 1; a missing or empty entry leaves it disabled.
//   - "burn_in_steps" (required): its first entry is stored as the burn-in
//     step count. Configure panics with a descriptive message when the entry
//     is missing or empty.
func (d *DataComparisonIteration) Configure(
	partitionIndex int,
	settings *simulator.Settings,
) {
	intParams := settings.OtherParams[partitionIndex].IntParams

	// Check the length as well as presence so that an empty slice is treated
	// like a missing key rather than causing an index-out-of-range panic.
	c, ok := intParams["cumulative"]
	d.cumulative = ok && len(c) > 0 && c[0] == 1

	// Indexing a missing map entry would panic with an opaque runtime error;
	// fail with a message that names the parameter instead.
	burnIn := intParams["burn_in_steps"]
	if len(burnIn) == 0 {
		panic(`inference: DataComparisonIteration requires the "burn_in_steps" int param`)
	}
	d.burnInSteps = int(burnIn[0])

	d.Likelihood.Configure(partitionIndex, settings)
}
// Iterate evaluates the configured likelihood against the latest data values
// and returns the result as a single-element slice.
//
// While timestepsHistory.CurrentStepNumber is below the configured burn-in
// step count, no evaluation takes place and the value at (0, 0) of this
// partition's state history is returned unchanged.
//
// Float parameters read from params:
//   - "mean": the mean vector; its length fixes the dimension used below.
//   - "covariance_matrix" (optional): backing data handed to mat.NewSymDense
//     for a matrix of that dimension. When absent, a nil mat.Symmetric is
//     passed to the likelihood.
//   - "latest_data_values": the data passed to EvaluateLogLike.
//
// In cumulative mode the value at (0, 0) of this partition's state history is
// added to the newly evaluated log-likelihood before it is returned.
func (d *DataComparisonIteration) Iterate(
	params *simulator.OtherParams,
	partitionIndex int,
	stateHistories []*simulator.StateHistory,
	timestepsHistory *simulator.CumulativeTimestepsHistory,
) []float64 {
	if timestepsHistory.CurrentStepNumber < d.burnInSteps {
		return []float64{stateHistories[partitionIndex].Values.At(0, 0)}
	}

	mean := params.FloatParams["mean"]
	dims := len(mean)

	// Declare the covariance with the interface type: a nil *mat.SymDense
	// stored in a mat.Symmetric would compare as non-nil, hiding the absence
	// of a covariance from implementations that check `covariance == nil`.
	var cov mat.Symmetric
	if cVals, ok := params.FloatParams["covariance_matrix"]; ok {
		cov = mat.NewSymDense(dims, cVals)
	}

	like := d.Likelihood.EvaluateLogLike(
		mat.NewVecDense(dims, mean),
		cov,
		params.FloatParams["latest_data_values"],
	)
	if d.cumulative {
		like += stateHistories[partitionIndex].Values.At(0, 0)
	}
	return []float64{like}
}