
Added support for regression prediction, ie raw values instead of cla… #7

Merged
merged 3 commits into from Jan 7, 2018

Conversation

charles-r-earp (Contributor)

…ssification, as a minor modification of fine_tune_ae().

@charles-r-earp (Contributor Author)

I just reused the auto_encoder code, but instead of using the samples as the target output, I used an additional set of iterators for the targets.

This is the test code.

#include <dll/network.hpp>
#include <dll/neural/dense_layer.hpp>

#include <iostream>
#include <memory>
#include <random>
#include <vector>

int main() {
    // Minimal network: a single dense layer mapping 2 inputs to 1 output
    using network_t = dll::dbn_desc<
        dll::network_layers<
            dll::dense_layer<2, 1>
        >
    >::network_t;

    int nsamples = 100;

    std::default_random_engine generator;
    std::uniform_real_distribution<float> distribution(-10, 10);

    std::vector<etl::fast_dyn_matrix<float, 2>> in_data(nsamples);
    std::vector<etl::fast_dyn_matrix<float, 1>> out_data(nsamples);

    auto line = [](float x) {
        return x + 1;
    };

    // Random 2D points, with target 1 if the point lies below the line y = x + 1
    float x, y;
    for (int i = 0; i < nsamples; ++i) {
        in_data[i][0] = x = distribution(generator);
        in_data[i][1] = y = distribution(generator);
        out_data[i][0] = line(x) > y ? 1 : 0;
    }

    std::unique_ptr<network_t> net = std::make_unique<network_t>();

    // Train on the raw target values supplied by the second iterator range
    net->fine_tune_reg(in_data.cbegin(), in_data.cend(), out_data.cbegin(), out_data.cend(), 1);

    std::cout << "dbn_regression success" << std::endl;

    return 0;
}

@wichtounet (Owner)

Hi,

Thank you very much for this Pull Request. Sorry for the delay in answering; I'm not very active during the holidays.

Everything seems in order.

If I understand correctly: this enables training regression models using the same kind of generator as the auto-encoder, but without enforcing an auto-encoder model. Right?

Thanks

Baptiste
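
For context, a minimal sketch of the difference in call shape. It assumes fine_tune_ae keeps its existing single-range signature, and reuses the names from the test program above:

// Call-shape comparison (illustrative only):
size_t max_epochs = 1;

// Auto-encoder fine-tuning: the samples themselves are the targets.
net->fine_tune_ae(in_data.cbegin(), in_data.cend(), max_epochs);

// Regression fine-tuning: targets come from a second iterator range,
// so the output size no longer has to match the input size.
net->fine_tune_reg(in_data.cbegin(), in_data.cend(),
                   out_data.cbegin(), out_data.cend(), max_epochs);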


validate_generator(generator);

//cpp_assert(dll::input_size(layer_get<0>()) == dll::output_size(layer_get<layers - 1>()), "The network is not build as an autoencoder");
wichtounet (Owner) commented on this line:

You can directly remove this instead of commenting

*/
template <typename InIterator, typename OutIterator>
weight fine_tune_reg(InIterator&& in_first, InIterator&& in_last, OutIterator&& out_first, OutIterator&& out_last, size_t max_epochs) {
//return "";
wichtounet (Owner) commented on this line:

Remove this as well

@charles-r-earp (Contributor Author)

Removed the commented code. The methods use a custom generator type reg_generator_t, which currently matches categorical_generator_t.

Only issue is the return value.

return trainer.train(*this, generator, max_epochs);

For regression, average loss seems like a good fit, but the trainer is treating it like a categorical generator, with classification error and NaN loss. So the return value and the printed output are invalid, but as far as I can tell the training does work.
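
To illustrate the kind of return value being argued for here, a minimal sketch of an average MSE over prediction/target ranges. This is a hypothetical standalone helper, not the library's trainer code, and it assumes 1-dimensional targets as in the test above:

#include <cstddef>

// Hypothetical helper: mean squared error averaged over two equally sized ranges.
template <typename PredIterator, typename TargetIterator>
float average_mse(PredIterator pred_first, PredIterator pred_last, TargetIterator target_first) {
    float total = 0.0f;
    std::size_t n = 0;
    for (; pred_first != pred_last; ++pred_first, ++target_first, ++n) {
        // Assumes 1-dimensional outputs, hence the [0] access
        const float diff = (*pred_first)[0] - (*target_first)[0];
        total += diff * diff;
    }
    return n ? total / n : 0.0f;
}
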
@charles-r-earp (Contributor Author)

Correct. My only issue is the return value.

@wichtounet (Owner)

wichtounet commented Jan 4, 2018

Seems fine now 👍

What is your issue with the return value? It seems correct to me.

@wichtounet (Owner)

If there are no more issues, I'll merge this now.

@wichtounet wichtounet merged commit 600ea98 into wichtounet:master Jan 7, 2018
@wichtounet (Owner)

If you are interested, I added a few (dumb) tests for regressions: https://github.com/wichtounet/dll/blob/master/test/src/unit/reg.cpp

There you can also see the configuration of the error to use MSE rather than CCE.
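
For reference, a sketch of what such a configuration might look like in the network descriptor. The dll::loss descriptor and the loss_function::MSE enumerator names are assumptions based on the comment above; the linked reg.cpp test is the authoritative example:

// Assumed descriptor names, shown on the toy network from the test above.
using network_t = dll::dbn_desc<
    dll::network_layers<
        dll::dense_layer<2, 1>
    >
    , dll::loss<dll::loss_function::MSE> // MSE error instead of the default CCE
>::network_t;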
