Skip to content

Commit

Permalink
Merge pull request #7 from charles-r-earp/regression
Browse files Browse the repository at this point in the history
Added support for regression prediction, ie raw values instead of classification values.
  • Loading branch information
wichtounet committed Jan 7, 2018
2 parents 54b4482 + 6e8c3bc commit 600ea98
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions include/dll/dbn_impl.hpp
Expand Up @@ -1594,6 +1594,66 @@ struct dbn final {

return fine_tune_ae(*generator, max_epochs);
}

// Fine tune for regression

/*!
* \brief Fine tune the network for regression.
* \param generator Generator for samples and outputs
* \param max_epochs The maximum number of epochs to train the network for.
* \return The final average loss
*/
template <typename Generator, cpp_enable_iff(is_generator<Generator>)>
weight fine_tune_reg(Generator& generator, size_t max_epochs) {
dll::auto_timer timer("net:train:ft:reg");

validate_generator(generator);

dll::dbn_trainer<this_type> trainer;
return trainer.train(*this, generator, max_epochs);
}

/*!
* \brief Fine tune the network for regression.
* \param inputs A container containing all the samples
* \param outputs A container containing the correct results
* \param max_epochs The maximum number of epochs to train the network for.
* \return The final average loss
*/
template <typename Inputs, typename Outputs>
weight fine_tune_reg(const Inputs& inputs, const Outputs& outputs, size_t max_epochs) {
// Create generator around the containers
cpp_assert(inputs.size() == outputs.size(), "The number of inputs does not match the number of outputs for training.");
auto generator = dll::make_generator(
inputs, outputs,
inputs.size(), output_size(), reg_generator_t{});
generator->set_safe();

return fine_tune_reg(*generator, max_epochs);
}

/*!
* \brief Fine tune the network for regression.
* \param in_first Iterator to the first sample
* \param in_last Iterator to the last sample
* \param out_first Iterator to the first output
* \param out_last Iterator to the last output
* \param max_epochs The maximum number of epochs to train the network for.
* \return The final average loss
*/
template <typename InIterator, typename OutIterator>
weight fine_tune_reg(InIterator&& in_first, InIterator&& in_last, OutIterator&& out_first, OutIterator&& out_last, size_t max_epochs) {
// Create generator around the iterators
cpp_assert(std::distance(in_first, in_last) == std::distance(out_first, out_last), "The number of inputs does not match the number of outputs for training.");
auto generator = make_generator(
std::forward<InIterator>(in_first), std::forward<InIterator>(in_last),
std::forward<OutIterator>(out_first), std::forward<OutIterator>(out_last),
std::distance(in_first, in_last), output_size(), reg_generator_t{});

generator->set_safe();

return fine_tune_reg(*generator, max_epochs);
}

template <size_t I, typename Input>
auto prepare_output() const {
Expand Down

0 comments on commit 600ea98

Please sign in to comment.