Skip to content

Commit

Permalink
[ProgressBar] Add a boolean flag to the progress bar.
Browse files Browse the repository at this point in the history
This enable us to create more complex behaviours. The progress
bar now takes also a std::function (lambda) that can be used
to stop the loop execution.
  • Loading branch information
geektoni committed Jun 2, 2017
1 parent 48c1bd3 commit a17321f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 32 deletions.
114 changes: 82 additions & 32 deletions src/shogun/base/progress.h
Expand Up @@ -36,6 +36,7 @@
#ifndef __SG_PROGRESS_H__
#define __SG_PROGRESS_H__

#include <functional>
#include <iterator>
#include <memory>
#include <string>
Expand Down Expand Up @@ -140,6 +141,11 @@ namespace shogun
lock.unlock();
}

void premature_end()
{
m_current_value.store(m_max_value - 1);
}

/** @return last progress as a percentage */
inline float64_t get_last_progress() const
{
Expand Down Expand Up @@ -336,31 +342,42 @@ namespace shogun
PRange(Range<T> range, const SGIO& io) : m_range(range)
{
set_up_range();
m_printer =
std::make_shared<ProgressPrinter>(io, end_range, begin_range);
m_printer = std::make_shared<ProgressPrinter>(
io, m_end_range, m_begin_range);
}
PRange(Range<T> range, const SGIO& io, const SG_PRG_MODE mode)
: m_range(range)
PRange(Range<T> range, const SGIO& io, std::function<bool()> condition)
: m_range(range), m_condition(condition)
{
set_up_range();
m_printer = std::make_shared<ProgressPrinter>(
io, end_range, begin_range, mode);
io, m_end_range, m_begin_range);
}
PRange(Range<T> range, const SGIO& io, const std::string prefix)
: m_range(range)
PRange(
Range<T> range, const SGIO& io, const SG_PRG_MODE mode,
std::function<bool()> condition)
: m_range(range), m_condition(condition)
{
set_up_range();
m_printer = std::make_shared<ProgressPrinter>(
io, end_range, begin_range, prefix);
io, m_end_range, m_begin_range, mode);
}
PRange(
Range<T> range, const SGIO& io, const std::string prefix,
const SG_PRG_MODE mode)
: m_range(range)
std::function<bool()> condition)
: m_range(range), m_condition(condition)
{
set_up_range();
m_printer = std::make_shared<ProgressPrinter>(
io, m_end_range, m_begin_range, prefix);
}
PRange(
Range<T> range, const SGIO& io, const std::string prefix,
const SG_PRG_MODE mode, std::function<bool()> condition)
: m_range(range), m_condition(condition)
{
set_up_range();
m_printer = std::make_shared<ProgressPrinter>(
io, end_range, begin_range, prefix, mode);
io, m_end_range, m_begin_range, prefix, mode);
}

/** @class Wrapper for Range<T>::Iterator spawned by @ref PRange. */
Expand All @@ -369,16 +386,19 @@ namespace shogun
public:
PIterator(
typename Range<T>::Iterator value,
std::shared_ptr<ProgressPrinter> shrd_ptr)
: m_value(value), m_printer(shrd_ptr)
std::shared_ptr<ProgressPrinter> shrd_ptr,
std::function<bool()> condition)
: m_value(value), m_printer(shrd_ptr), m_condition(condition)
{
}
PIterator(const PIterator& other)
: m_value(other.m_value), m_printer(other.m_printer)
: m_value(other.m_value), m_printer(other.m_printer),
m_condition(other.m_condition)
{
}
PIterator(PIterator&& other)
: m_value(other.m_value), m_printer(other.m_printer)
: m_value(other.m_value), m_printer(other.m_printer),
m_condition(other.m_condition)
{
}
PIterator& operator=(const PIterator&) = delete;
Expand All @@ -390,7 +410,7 @@ namespace shogun
m_value++;
return *this;
}
PIterator& operator++(int)
PIterator operator++(int)
{
PIterator tmp(*this);
++*this;
Expand All @@ -405,19 +425,23 @@ namespace shogun
}
bool operator!=(const PIterator& other)
{
return this->m_value != other.m_value;
bool result = evaluate_condition();
return (this->m_value != other.m_value) && result;
}
bool operator==(const PIterator& other)
{
return this->m_value == other.m_value;
bool result = evaluate_condition();
return (this->m_value == other.m_value) && result;
}
bool operator>(const PIterator& other)
{
return this->m_value > other.m_value;
bool result = evaluate_condition();
return (this->m_value > other.m_value) && result;
}
bool operator<(const PIterator& other)
{
return !(*this > other);
bool result = evaluate_condition();
return !(this->m_value > other.m_value) && result;
}

T operator-(PIterator& other)
Expand All @@ -431,24 +455,46 @@ namespace shogun
return this->m_value += other;
}

inline bool check_condition() const
{
return m_condition();
}

private:
bool evaluate_condition()
{
if (!m_condition())
{
m_printer->premature_end();
m_printer->print_progress();
}
return m_condition();
}

/* The wrapped range */
typename Range<T>::Iterator m_value;
/* The ProgressPrinter object which will be used to show the
* progress bar*/
std::shared_ptr<ProgressPrinter> m_printer;
std::function<bool()> m_condition;
};

/** Create the iterator that corresponds to the start of the range*/
PIterator begin() const
{
return PIterator(m_range.begin(), m_printer);
return PIterator(m_range.begin(), m_printer, m_condition);
}

/** Create the iterator that corresponds to the start of the range*/
PIterator begin(std::function<bool()> condition) const
{
return PIterator(m_range.begin(), m_printer, condition);
}

/** Create the iterator that corresponds to the end of the iterator*/
PIterator end() const
{
return PIterator(m_range.end(), m_printer);
return PIterator(m_range.end(), m_printer, m_condition);
}

/** @return last progress as a percentage */
Expand All @@ -466,16 +512,17 @@ namespace shogun
private:
void set_up_range()
{
begin_range = *(m_range.begin());
end_range = *(m_range.end());
m_begin_range = *(m_range.begin());
m_end_range = *(m_range.end());
}

/** Range we iterate over */
Range<T> m_range;
/** Observer that will print the actual progress bar */
std::shared_ptr<ProgressPrinter> m_printer;
float64_t begin_range;
float64_t end_range;
float64_t m_begin_range;
float64_t m_end_range;
std::function<bool()> m_condition = []() { return true; };
};

/** Creates @ref PRange given a range.
Expand All @@ -488,10 +535,11 @@ namespace shogun
* @param io SGIO object
*/
template <typename T>
inline PRange<T>
progress(Range<T> range, const SGIO& io, SG_PRG_MODE mode = UTF8)
inline PRange<T> progress(
Range<T> range, const SGIO& io, SG_PRG_MODE mode = UTF8,
std::function<bool()> condition = []() { return true; })
{
return PRange<T>(range, io, mode);
return PRange<T>(range, io, mode, condition);
}

/** Creates @ref PRange given a range that uses the global SGIO
Expand All @@ -503,9 +551,11 @@ namespace shogun
* @param range range used
*/
template <typename T>
inline PRange<T> progress(Range<T> range, SG_PRG_MODE mode = UTF8)
inline PRange<T> progress(
Range<T> range, SG_PRG_MODE mode = UTF8,
std::function<bool()> condition = []() { return true; })
{
return PRange<T>(range, *sg_io, mode);
return PRange<T>(range, *sg_io, mode, condition);
}
};
#endif /* __SG_PROGRESS_H__ */
#endif /* __SG_PROGRESS_H__ */
14 changes: 14 additions & 0 deletions tests/unit/base/PRange_unittest.cc
Expand Up @@ -81,6 +81,20 @@ TEST(PRange, progress_correct_bounds_negative)
}
}

TEST(PRange, lambda_stop)
{
int test = 6;
/* Stops before the 4th iteration */
for (auto i :
progress(range(0, 6), range_io, UTF8, [&]() { return test > 3; }))
{
(void)i;
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
test--;
}
EXPECT_EQ(test, 3);
}

TEST(PRange, DISABLED_progress_incorrect_bounds_positive)
{
range_io.enable_progress();
Expand Down

0 comments on commit a17321f

Please sign in to comment.