Permalink
Show file tree
Hide file tree
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
feat: adaptive integration of density functions using a binary search…
… approach that tries to achieve good resolution around the maximum (#486) * transfer adaptive integration from varlociraptor * make ln_integrate func public * import logprob from crate * add example/testcase Co-authored-by: Johannes Köster <johannes.koester@tu-dortmund.de>
- Loading branch information
1 parent
93289d5
commit 207b76f
Showing
2 changed files
with
147 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
// Copyright 2021-2022 Johannes Köster. | ||
// Licensed under the MIT license (http://opensource.org/licenses/MIT) | ||
// This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
use std::cmp; | ||
use std::collections::HashMap; | ||
use std::convert::Into; | ||
use std::hash::Hash; | ||
use std::{ | ||
fmt::Debug, | ||
ops::{Add, Div, Mul, Sub}, | ||
}; | ||
|
||
use crate::stats::probs::LogProb; | ||
use itertools::Itertools; | ||
use itertools_num::linspace; | ||
use ordered_float::NotNan; | ||
|
||
/// Integrate over an interval of type T with a given density function while trying to minimize | ||
/// the number of grid points evaluated and still hit the maximum likelihood point. | ||
/// This is achieved via a binary search over the grid points. | ||
/// The assumption is that the density is unimodal. If that is not the case, | ||
/// the binary search will not find the maximum and the integral can miss features. | ||
/// | ||
/// # Example | ||
/// | ||
/// ```rust | ||
/// use bio::stats::probs::adaptive_integration::ln_integrate_exp; | ||
/// use bio::stats::probs::{Prob, LogProb}; | ||
/// use statrs::distribution::{Normal, Continuous}; | ||
/// use statrs::statistics::Distribution; | ||
/// use ordered_float::NotNan; | ||
/// use approx::abs_diff_eq; | ||
/// | ||
/// let ndist = Normal::new(0.0, 1.0).unwrap(); | ||
/// | ||
/// let integral = ln_integrate_exp( | ||
/// |x| LogProb::from(Prob(ndist.pdf(*x))), | ||
/// NotNan::new(-1.0).unwrap(), | ||
/// NotNan::new(1.0).unwrap(), | ||
/// NotNan::new(0.01).unwrap() | ||
/// ); | ||
/// abs_diff_eq!(integral.exp(), 0.682, epsilon=0.01); | ||
/// ``` | ||
pub fn ln_integrate_exp<T, F>( | ||
mut density: F, | ||
min_point: T, | ||
max_point: T, | ||
max_resolution: T, | ||
) -> LogProb | ||
where | ||
T: Copy | ||
+ Add<Output = T> | ||
+ Sub<Output = T> | ||
+ Div<Output = T> | ||
+ Div<NotNan<f64>, Output = T> | ||
+ Mul<Output = T> | ||
+ Into<f64> | ||
+ From<f64> | ||
+ Ord | ||
+ Debug | ||
+ Hash, | ||
F: FnMut(T) -> LogProb, | ||
f64: From<T>, | ||
{ | ||
let mut probs = HashMap::new(); | ||
|
||
let mut grid_point = |point, probs: &mut HashMap<_, _>| { | ||
probs.insert(point, density(point)); | ||
point | ||
}; | ||
let middle_grid_point = |left: T, right: T| (right + left) / NotNan::new(2.0).unwrap(); | ||
// METHOD: | ||
// Step 1: perform binary search for maximum likelihood point | ||
// Remember all points. | ||
let mut left = grid_point(min_point, &mut probs); | ||
let mut right = grid_point(max_point, &mut probs); | ||
let mut first_middle = None; | ||
let mut middle = None; | ||
|
||
while (((right - left) >= max_resolution) && left < right) || middle.is_none() { | ||
middle = Some(grid_point(middle_grid_point(left, right), &mut probs)); | ||
|
||
if first_middle.is_none() { | ||
first_middle = middle; | ||
} | ||
|
||
let left_prob = probs.get(&left).unwrap(); | ||
let right_prob = probs.get(&right).unwrap(); | ||
|
||
if left_prob > right_prob { | ||
// investigate left window more closely | ||
right = middle.unwrap(); | ||
} else { | ||
// investigate right window more closely | ||
left = middle.unwrap(); | ||
} | ||
} | ||
// METHOD: add additional grid point in the initially abandoned arm | ||
if middle < first_middle { | ||
grid_point( | ||
middle_grid_point(first_middle.unwrap(), max_point), | ||
&mut probs, | ||
); | ||
} else { | ||
grid_point( | ||
middle_grid_point(min_point, first_middle.unwrap()), | ||
&mut probs, | ||
); | ||
} | ||
// METHOD additionally investigate small interval around the optimum | ||
for point in linspace( | ||
cmp::max( | ||
middle.unwrap() - (max_resolution.into() * 3.0).into(), | ||
min_point, | ||
) | ||
.into(), | ||
middle.unwrap().into(), | ||
4, | ||
) | ||
.take(3) | ||
.chain( | ||
linspace( | ||
middle.unwrap().into(), | ||
cmp::min( | ||
middle.unwrap() + (max_resolution.into() * 3.0).into(), | ||
max_point, | ||
) | ||
.into(), | ||
4, | ||
) | ||
.skip(1), | ||
) { | ||
grid_point(point.into(), &mut probs); | ||
} | ||
|
||
let sorted_grid_points: Vec<f64> = probs.keys().sorted().map(|point| (*point).into()).collect(); | ||
|
||
// METHOD: | ||
// Step 2: integrate over grid points visited during the binary search. | ||
LogProb::ln_trapezoidal_integrate_grid_exp::<f64, _>( | ||
|_, g| *probs.get(&T::from(g)).unwrap(), | ||
&sorted_grid_points, | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters