Skip to content

Commit

Permalink
Merge pull request #3829 from geektoni/boolean_check_progress
Browse files Browse the repository at this point in the history
[ProgressBar] Add a boolean flag to the progress bar.
  • Loading branch information
vigsterkr committed Jun 5, 2017
2 parents 48c1bd3 + a17321f commit c211d75
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 32 deletions.
114 changes: 82 additions & 32 deletions src/shogun/base/progress.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#ifndef __SG_PROGRESS_H__
#define __SG_PROGRESS_H__

#include <functional>
#include <iterator>
#include <memory>
#include <string>
Expand Down Expand Up @@ -140,6 +141,11 @@ namespace shogun
lock.unlock();
}

void premature_end()
{
m_current_value.store(m_max_value - 1);
}

/** @return last progress as a percentage */
inline float64_t get_last_progress() const
{
Expand Down Expand Up @@ -336,31 +342,42 @@ namespace shogun
PRange(Range<T> range, const SGIO& io) : m_range(range)
{
set_up_range();
m_printer =
std::make_shared<ProgressPrinter>(io, end_range, begin_range);
m_printer = std::make_shared<ProgressPrinter>(
io, m_end_range, m_begin_range);
}
PRange(Range<T> range, const SGIO& io, const SG_PRG_MODE mode)
: m_range(range)
PRange(Range<T> range, const SGIO& io, std::function<bool()> condition)
: m_range(range), m_condition(condition)
{
set_up_range();
m_printer = std::make_shared<ProgressPrinter>(
io, end_range, begin_range, mode);
io, m_end_range, m_begin_range);
}
PRange(Range<T> range, const SGIO& io, const std::string prefix)
: m_range(range)
PRange(
Range<T> range, const SGIO& io, const SG_PRG_MODE mode,
std::function<bool()> condition)
: m_range(range), m_condition(condition)
{
set_up_range();
m_printer = std::make_shared<ProgressPrinter>(
io, end_range, begin_range, prefix);
io, m_end_range, m_begin_range, mode);
}
PRange(
Range<T> range, const SGIO& io, const std::string prefix,
const SG_PRG_MODE mode)
: m_range(range)
std::function<bool()> condition)
: m_range(range), m_condition(condition)
{
set_up_range();
m_printer = std::make_shared<ProgressPrinter>(
io, m_end_range, m_begin_range, prefix);
}
PRange(
Range<T> range, const SGIO& io, const std::string prefix,
const SG_PRG_MODE mode, std::function<bool()> condition)
: m_range(range), m_condition(condition)
{
set_up_range();
m_printer = std::make_shared<ProgressPrinter>(
io, end_range, begin_range, prefix, mode);
io, m_end_range, m_begin_range, prefix, mode);
}

/** @class Wrapper for Range<T>::Iterator spawned by @ref PRange. */
Expand All @@ -369,16 +386,19 @@ namespace shogun
public:
PIterator(
typename Range<T>::Iterator value,
std::shared_ptr<ProgressPrinter> shrd_ptr)
: m_value(value), m_printer(shrd_ptr)
std::shared_ptr<ProgressPrinter> shrd_ptr,
std::function<bool()> condition)
: m_value(value), m_printer(shrd_ptr), m_condition(condition)
{
}
PIterator(const PIterator& other)
: m_value(other.m_value), m_printer(other.m_printer)
: m_value(other.m_value), m_printer(other.m_printer),
m_condition(other.m_condition)
{
}
PIterator(PIterator&& other)
: m_value(other.m_value), m_printer(other.m_printer)
: m_value(other.m_value), m_printer(other.m_printer),
m_condition(other.m_condition)
{
}
PIterator& operator=(const PIterator&) = delete;
Expand All @@ -390,7 +410,7 @@ namespace shogun
m_value++;
return *this;
}
PIterator& operator++(int)
PIterator operator++(int)
{
PIterator tmp(*this);
++*this;
Expand All @@ -405,19 +425,23 @@ namespace shogun
}
bool operator!=(const PIterator& other)
{
return this->m_value != other.m_value;
bool result = evaluate_condition();
return (this->m_value != other.m_value) && result;
}
bool operator==(const PIterator& other)
{
return this->m_value == other.m_value;
bool result = evaluate_condition();
return (this->m_value == other.m_value) && result;
}
bool operator>(const PIterator& other)
{
return this->m_value > other.m_value;
bool result = evaluate_condition();
return (this->m_value > other.m_value) && result;
}
bool operator<(const PIterator& other)
{
return !(*this > other);
bool result = evaluate_condition();
return !(this->m_value > other.m_value) && result;
}

T operator-(PIterator& other)
Expand All @@ -431,24 +455,46 @@ namespace shogun
return this->m_value += other;
}

inline bool check_condition() const
{
return m_condition();
}

private:
bool evaluate_condition()
{
if (!m_condition())
{
m_printer->premature_end();
m_printer->print_progress();
}
return m_condition();
}

/* The wrapped range */
typename Range<T>::Iterator m_value;
/* The ProgressPrinter object which will be used to show the
* progress bar*/
std::shared_ptr<ProgressPrinter> m_printer;
std::function<bool()> m_condition;
};

/** Create the iterator that corresponds to the start of the range*/
PIterator begin() const
{
return PIterator(m_range.begin(), m_printer);
return PIterator(m_range.begin(), m_printer, m_condition);
}

/** Create the iterator that corresponds to the start of the range*/
PIterator begin(std::function<bool()> condition) const
{
return PIterator(m_range.begin(), m_printer, condition);
}

/** Create the iterator that corresponds to the end of the iterator*/
PIterator end() const
{
return PIterator(m_range.end(), m_printer);
return PIterator(m_range.end(), m_printer, m_condition);
}

/** @return last progress as a percentage */
Expand All @@ -466,16 +512,17 @@ namespace shogun
private:
void set_up_range()
{
begin_range = *(m_range.begin());
end_range = *(m_range.end());
m_begin_range = *(m_range.begin());
m_end_range = *(m_range.end());
}

/** Range we iterate over */
Range<T> m_range;
/** Observer that will print the actual progress bar */
std::shared_ptr<ProgressPrinter> m_printer;
float64_t begin_range;
float64_t end_range;
float64_t m_begin_range;
float64_t m_end_range;
std::function<bool()> m_condition = []() { return true; };
};

/** Creates @ref PRange given a range.
Expand All @@ -488,10 +535,11 @@ namespace shogun
* @param io SGIO object
*/
template <typename T>
inline PRange<T>
progress(Range<T> range, const SGIO& io, SG_PRG_MODE mode = UTF8)
inline PRange<T> progress(
Range<T> range, const SGIO& io, SG_PRG_MODE mode = UTF8,
std::function<bool()> condition = []() { return true; })
{
return PRange<T>(range, io, mode);
return PRange<T>(range, io, mode, condition);
}

/** Creates @ref PRange given a range that uses the global SGIO
Expand All @@ -503,9 +551,11 @@ namespace shogun
* @param range range used
*/
template <typename T>
inline PRange<T> progress(Range<T> range, SG_PRG_MODE mode = UTF8)
inline PRange<T> progress(
Range<T> range, SG_PRG_MODE mode = UTF8,
std::function<bool()> condition = []() { return true; })
{
return PRange<T>(range, *sg_io, mode);
return PRange<T>(range, *sg_io, mode, condition);
}
};
#endif /* __SG_PROGRESS_H__ */
#endif /* __SG_PROGRESS_H__ */
14 changes: 14 additions & 0 deletions tests/unit/base/PRange_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ TEST(PRange, progress_correct_bounds_negative)
}
}

TEST(PRange, lambda_stop)
{
int test = 6;
/* Stops before the 4th iteration */
for (auto i :
progress(range(0, 6), range_io, UTF8, [&]() { return test > 3; }))
{
(void)i;
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
test--;
}
EXPECT_EQ(test, 3);
}

TEST(PRange, DISABLED_progress_incorrect_bounds_positive)
{
range_io.enable_progress();
Expand Down

0 comments on commit c211d75

Please sign in to comment.