Added support for regression prediction, i.e. raw values instead of cla… #7

Merged: 3 commits, Jan 7, 2018
60 changes: 60 additions & 0 deletions include/dll/dbn_impl.hpp
@@ -1594,6 +1594,66 @@ struct dbn final {

return fine_tune_ae(*generator, max_epochs);
}

// Fine tune for regression

/*!
* \brief Fine tune the network for regression.
* \param generator Generator for samples and outputs
* \param max_epochs The maximum number of epochs to train the network for.
* \return The final average loss
*/
template <typename Generator, cpp_enable_iff(is_generator<Generator>)>
weight fine_tune_reg(Generator& generator, size_t max_epochs) {
dll::auto_timer timer("net:train:ft:reg");

validate_generator(generator);

dll::dbn_trainer<this_type> trainer;
return trainer.train(*this, generator, max_epochs);
}

/*!
* \brief Fine tune the network for regression.
* \param inputs A container containing all the samples
* \param outputs A container containing the correct results
* \param max_epochs The maximum number of epochs to train the network for.
* \return The final average loss
*/
template <typename Inputs, typename Outputs>
weight fine_tune_reg(const Inputs& inputs, const Outputs& outputs, size_t max_epochs) {
// Create generator around the containers
cpp_assert(inputs.size() == outputs.size(), "The number of inputs does not match the number of outputs for training.");
auto generator = dll::make_generator(
inputs, outputs,
inputs.size(), output_size(), reg_generator_t{});
generator->set_safe();

return fine_tune_reg(*generator, max_epochs);
}

/*!
* \brief Fine tune the network for regression.
* \param in_first Iterator to the first sample
 * \param in_last Iterator one past the last sample
* \param out_first Iterator to the first output
 * \param out_last Iterator one past the last output
* \param max_epochs The maximum number of epochs to train the network for.
* \return The final average loss
*/
template <typename InIterator, typename OutIterator>
weight fine_tune_reg(InIterator&& in_first, InIterator&& in_last, OutIterator&& out_first, OutIterator&& out_last, size_t max_epochs) {
// Create generator around the iterators
cpp_assert(std::distance(in_first, in_last) == std::distance(out_first, out_last), "The number of inputs does not match the number of outputs for training.");
auto generator = make_generator(
std::forward<InIterator>(in_first), std::forward<InIterator>(in_last),
std::forward<OutIterator>(out_first), std::forward<OutIterator>(out_last),
std::distance(in_first, in_last), output_size(), reg_generator_t{});

generator->set_safe();

return fine_tune_reg(*generator, max_epochs);
}
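
For context, a minimal usage sketch of the new container overload (not part of the diff): the wrapper function, the std::vector element types, and the default epoch count are placeholder assumptions, and only fine_tune_reg itself comes from this pull request.

// Hypothetical usage sketch for the container overload added above. Net is
// assumed to be a dll network type (e.g. a dll::dbn_desc<...>::dbn_t)
// configured for regression; samples and targets are containers of equal size.
#include <cstddef>
#include <vector>

template <typename Net>
auto train_regression(Net& net,
                      const std::vector<std::vector<float>>& samples,
                      const std::vector<std::vector<float>>& targets,
                      std::size_t max_epochs = 25) {
    // The overload wraps both containers in a data generator (reg_generator_t)
    // and returns the final average loss after training.
    return net.fine_tune_reg(samples, targets, max_epochs);
}

If a data generator has already been built, the first overload can instead be called with it directly, as the container and iterator overloads do internally.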

template <size_t I, typename Input>
auto prepare_output() const {