Skip to content

Commit

Permalink
Merge pull request #36 from sile/warm-starting-problem
Browse files Browse the repository at this point in the history
Add `WarmStartingProblem`
  • Loading branch information
sile committed Feb 27, 2021
2 parents f9732d9 + 2daef32 commit 30cca7a
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 2 deletions.
7 changes: 6 additions & 1 deletion kurobako_core/src/problem.rs
Expand Up @@ -57,6 +57,12 @@ impl ProblemSpecBuilder {
self
}

/// Sets variables of the value domain of this problem.
pub fn values(mut self, vars: Vec<VariableBuilder>) -> Self {
self.values = vars;
self
}

/// Sets the evaluable steps of this problem.
pub fn steps<I>(mut self, steps: I) -> Self
where
Expand Down Expand Up @@ -346,7 +352,6 @@ enum EvaluableStepsInner {
impl EvaluableStepsInner {
fn new(steps: Vec<u64>) -> Result<Self> {
track_assert!(!steps.is_empty(), ErrorKind::InvalidInput);
track_assert!(steps[0] > 0, ErrorKind::InvalidInput);

for (a, b) in steps.iter().zip(steps.iter().skip(1)) {
track_assert!(a < b, ErrorKind::InvalidInput);
Expand Down
1 change: 1 addition & 0 deletions kurobako_problems/src/lib.rs
Expand Up @@ -8,4 +8,5 @@ pub mod hpobench;
pub mod nasbench;
pub mod sigopt;
pub mod surrogate;
pub mod warm_starting;
pub mod zdt;
136 changes: 136 additions & 0 deletions kurobako_problems/src/warm_starting.rs
@@ -0,0 +1,136 @@
//! A problem for warm-starting optimizations.
use kurobako_core::json::JsonRecipe;
use kurobako_core::problem::{
BoxEvaluator, BoxProblem, BoxProblemFactory, Evaluator, Problem, ProblemFactory, ProblemRecipe,
ProblemSpec, ProblemSpecBuilder,
};
use kurobako_core::registry::FactoryRegistry;
use kurobako_core::rng::ArcRng;
use kurobako_core::trial::{Params, Values};
use kurobako_core::{ErrorKind, Result};
use serde::{Deserialize, Serialize};
use structopt::StructOpt;

/// Recipe of `WarmStartingProblem`.
#[derive(Debug, Clone, StructOpt, Serialize, Deserialize)]
#[structopt(rename_all = "kebab-case")]
pub struct WarmStartingProblemRecipe {
/// Source problem recipe JSON.
pub source: JsonRecipe,

/// Target problem recipe JSON.
pub target: JsonRecipe,
}

impl ProblemRecipe for WarmStartingProblemRecipe {
type Factory = WarmStartingProblemFactory;

fn create_factory(&self, registry: &FactoryRegistry) -> Result<Self::Factory> {
let source_factory = track!(registry.create_problem_factory_from_json(&self.source))?;
let target_factory = track!(registry.create_problem_factory_from_json(&self.target))?;

Ok(WarmStartingProblemFactory {
source_factory,
target_factory,
})
}
}

/// Factory of `WarmStartingProblem`.
#[derive(Debug)]
pub struct WarmStartingProblemFactory {
source_factory: BoxProblemFactory,
target_factory: BoxProblemFactory,
}
impl ProblemFactory for WarmStartingProblemFactory {
type Problem = WarmStartingProblem;

fn specification(&self) -> Result<ProblemSpec> {
let source_spec = track!(self.source_factory.specification())?;
let target_spec = track!(self.target_factory.specification())?;
track_assert_eq!(
source_spec.params_domain,
target_spec.params_domain,
ErrorKind::InvalidInput
);
track_assert_eq!(
source_spec.values_domain,
target_spec.values_domain,
ErrorKind::InvalidInput
);

let spec = ProblemSpecBuilder::new(&format!("{} with warm starting", target_spec.name))
.params(
target_spec
.params_domain
.variables()
.iter()
.map(|p| p.clone().into())
.collect(),
)
.values(
target_spec
.values_domain
.variables()
.iter()
.map(|p| p.clone().into())
.collect(),
)
.steps(std::iter::once(0).chain(target_spec.steps.iter()));
track!(spec.finish())
}

fn create_problem(&self, rng: ArcRng) -> Result<Self::Problem> {
let source_spec = track!(self.source_factory.specification())?;
let source_last_step = source_spec.steps.last();

let source_problem = track!(self.source_factory.create_problem(rng.clone()))?;
let target_problem = track!(self.target_factory.create_problem(rng))?;
Ok(WarmStartingProblem {
source_last_step,
source_problem,
target_problem,
})
}
}

/// Problem that uses a random forest surrogate model to evaluate parameters.
#[derive(Debug)]
pub struct WarmStartingProblem {
source_last_step: u64,
source_problem: BoxProblem,
target_problem: BoxProblem,
}

impl Problem for WarmStartingProblem {
type Evaluator = WarmStartingEvaluator;

fn create_evaluator(&self, params: Params) -> Result<Self::Evaluator> {
let source_evaluator = track!(self.source_problem.create_evaluator(params.clone()))?;
let target_evaluator = track!(self.target_problem.create_evaluator(params))?;
Ok(WarmStartingEvaluator {
source_last_step: self.source_last_step,
source_evaluator,
target_evaluator,
})
}
}

/// Evaluator of `WarmStartingProblem`.
#[derive(Debug)]
pub struct WarmStartingEvaluator {
source_last_step: u64,
source_evaluator: BoxEvaluator,
target_evaluator: BoxEvaluator,
}

impl Evaluator for WarmStartingEvaluator {
fn evaluate(&mut self, next_step: u64) -> Result<(u64, Values)> {
if next_step == 0 {
let (_, values) = track!(self.source_evaluator.evaluate(self.source_last_step))?;
Ok((0, values))
} else {
track!(self.target_evaluator.evaluate(next_step))
}
}
}
4 changes: 3 additions & 1 deletion src/problem.rs
Expand Up @@ -6,7 +6,7 @@ use kurobako_core::problem::{
use kurobako_core::registry::FactoryRegistry;
use kurobako_core::rng::ArcRng;
use kurobako_core::Result;
use kurobako_problems::{hpobench, nasbench, sigopt, surrogate, zdt};
use kurobako_problems::{hpobench, nasbench, sigopt, surrogate, warm_starting, zdt};
use serde::{Deserialize, Serialize};
use structopt::StructOpt;

Expand Down Expand Up @@ -86,6 +86,7 @@ enum InnerRecipe {
Rank(self::rank::RankProblemRecipe),
Average(self::average::AverageProblemRecipe),
Ln(self::ln::LnProblemRecipe),
WarmStarting(warm_starting::WarmStartingProblemRecipe),
}
impl ProblemRecipe for InnerRecipe {
type Factory = BoxProblemFactory;
Expand All @@ -102,6 +103,7 @@ impl ProblemRecipe for InnerRecipe {
Self::Rank(p) => track!(p.create_factory(registry).map(BoxProblemFactory::new)),
Self::Average(p) => track!(p.create_factory(registry).map(BoxProblemFactory::new)),
Self::Ln(p) => track!(p.create_factory(registry).map(BoxProblemFactory::new)),
Self::WarmStarting(p) => track!(p.create_factory(registry).map(BoxProblemFactory::new)),
}
}
}
Expand Down

0 comments on commit 30cca7a

Please sign in to comment.