Skip to content

Commit

Permalink
feat: make timer optional. Optimise perf when no observers or checkpo…
Browse files Browse the repository at this point in the history
…ints
  • Loading branch information
sdd committed Apr 29, 2021
1 parent 6deb382 commit a67f56a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 23 deletions.
52 changes: 37 additions & 15 deletions src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ pub struct Executor<O: ArgminOp, S> {
checkpoint: ArgminCheckpoint,
/// Indicates whether Ctrl-C functionality should be active or not
ctrlc: bool,
/// Indicates whether to time execution or not
timer: bool,
}

impl<O, S> Executor<O, S>
Expand All @@ -54,6 +56,7 @@ where
observers: Observer::new(),
checkpoint: ArgminCheckpoint::default(),
ctrlc: true,
timer: true,
}
}

Expand Down Expand Up @@ -104,7 +107,7 @@ where

/// Run the executor
pub fn run(mut self) -> Result<ArgminResult<O>, Error> {
let total_time = instant::Instant::now();
let total_time = if self.timer { Some(instant::Instant::now()) } else { None };

let running = Arc::new(AtomicBool::new(true));

Expand All @@ -131,17 +134,22 @@ where

// let mut op_wrapper = OpWrapper::new(&self.op);
let init_data = self.solver.init(&mut self.op, &self.state)?;

let mut logs = make_kv!("max_iters" => self.state.get_max_iters(););


// If init() returned something, deal with it
if let Some(data) = init_data {
if let Some(data) = &init_data {
self.update(&data)?;
logs = logs.merge(&mut data.get_kv());
}

// Observe after init
self.observers.observe_init(S::NAME, &logs)?;
if !self.observers.is_empty() {
let mut logs = make_kv!("max_iters" => self.state.get_max_iters(););

if let Some(data) = init_data {
logs = logs.merge(&mut data.get_kv());
}

// Observe after init
self.observers.observe_init(S::NAME, &logs)?;
}

self.state.set_func_counts(&self.op);

Expand All @@ -162,29 +170,37 @@ where
}

// Start time measurement
let start = instant::Instant::now();
let start = if self.timer { Some(instant::Instant::now()) } else { None };

let data = self.solver.next_iter(&mut self.op, &self.state)?;

self.state.set_func_counts(&self.op);

// End time measurement
let duration = start.elapsed();
let duration = if self.timer { Some(start.unwrap().elapsed()) } else { None };

self.update(&data)?;

let log = data.get_kv().merge(&mut make_kv!(
"time" => duration.as_secs() as f64 + f64::from(duration.subsec_nanos()) * 1e-9;
));
if !self.observers.is_empty() {
let mut log = data.get_kv();

self.observers.observe_iter(&self.state, &log)?;
if self.timer {
let duration = duration.unwrap();
log = log.merge(&mut make_kv!(
"time" => duration.as_secs() as f64 + f64::from(duration.subsec_nanos()) * 1e-9;
));
}
self.observers.observe_iter(&self.state, &log)?;
}

// increment iteration number
self.state.increment_iter();

self.checkpoint.store_cond(&self, self.state.get_iter())?;

self.state.time(total_time.elapsed());
if self.timer {
total_time.map(|total_time| self.state.time(Some(total_time.elapsed())));
}

// Check if termination occured inside next_iter()
if self.state.terminated() {
Expand Down Expand Up @@ -270,4 +286,10 @@ where
self.ctrlc = ctrlc;
self
}

/// Turn timer on or off (default: on)
pub fn timer(mut self, timer: bool) -> Self {
self.timer = timer;
self
}
}
8 changes: 4 additions & 4 deletions src/core/iterstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct IterState<O: ArgminOp> {
/// Number of modify evaluations so far
pub modify_func_count: u64,
/// Time required so far
pub time: instant::Duration,
pub time: Option<instant::Duration>,
/// Reason of termination
pub termination_reason: TerminationReason,
}
Expand Down Expand Up @@ -137,7 +137,7 @@ impl<O: ArgminOp> IterState<O> {
hessian_func_count: 0,
jacobian_func_count: 0,
modify_func_count: 0,
time: instant::Duration::new(0, 0),
time: Some(instant::Duration::new(0, 0)),
termination_reason: TerminationReason::NotTerminated,
}
}
Expand Down Expand Up @@ -213,7 +213,7 @@ impl<O: ArgminOp> IterState<O> {
TerminationReason,
"Set termination_reason"
);
setter!(time, instant::Duration, "Set time required so far");
setter!(time, Option<instant::Duration>, "Set time required so far");
getter!(param, O::Param, "Returns current parameter vector");
getter!(prev_param, O::Param, "Returns previous parameter vector");
getter!(best_param, O::Param, "Returns best parameter vector");
Expand Down Expand Up @@ -270,7 +270,7 @@ impl<O: ArgminOp> IterState<O> {
TerminationReason,
"Get termination_reason"
);
getter!(time, instant::Duration, "Get time required so far");
getter!(time, Option<instant::Duration>, "Get time required so far");
getter_option!(grad, O::Param, "Returns gradient");
getter_option!(prev_grad, O::Param, "Returns previous gradient");
getter_option!(hessian, O::Hessian, "Returns current Hessian");
Expand Down
5 changes: 5 additions & 0 deletions src/core/observers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ impl<O: ArgminOp> Observer<O> {
self.observers.push((Arc::new(Mutex::new(observer)), mode));
self
}

/// Returns true if `observers` is empty
pub fn is_empty(&self) -> bool {
self.observers.is_empty()
}
}

/// By implementing `Observe` for `Observer` we basically allow a set of `Observer`s to be used
Expand Down
14 changes: 10 additions & 4 deletions src/core/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ impl ArgminCheckpoint {
pub fn name(&self) -> String {
self.name.clone()
}

/// Get filename for checkpoint
#[inline]
fn filename(&self) -> String {
let mut filename = self.name();
filename.push_str(".arg");
filename
}

/// Set mode of checkpoint
#[inline]
Expand All @@ -126,11 +134,9 @@ impl ArgminCheckpoint {
/// Write checkpoint based on the desired `CheckpointMode`
#[inline]
pub fn store_cond<T: Serialize>(&self, executor: &T, iter: u64) -> Result<(), Error> {
let mut filename = self.name();
filename.push_str(".arg");
match self.mode {
CheckpointMode::Always => self.store(executor, filename)?,
CheckpointMode::Every(it) if iter % it == 0 => self.store(executor, filename)?,
CheckpointMode::Always => self.store(executor, self.filename())?,
CheckpointMode::Every(it) if iter % it == 0 => self.store(executor, self.filename())?,
CheckpointMode::Never | CheckpointMode::Every(_) => {}
};
Ok(())
Expand Down

0 comments on commit a67f56a

Please sign in to comment.