diff --git a/workbench/src/rbm_dae.cpp b/workbench/src/rbm_dae.cpp index 9205c0e4..88fcdcca 100644 --- a/workbench/src/rbm_dae.cpp +++ b/workbench/src/rbm_dae.cpp @@ -55,6 +55,31 @@ void rbm_dae_batch(const D& dataset){ ae->pretrain_denoising_auto(dataset.training_images, 50, 0.3); } +template +void rbm_cdae_batch(const D& dataset){ + std::cout << " Test RBM Denoising Auto-Encoder" << std::endl; + + using network_t = dll::dbn_desc< + dll::dbn_layers< + dll::rbm_desc<28 * 28, 200, dll::momentum, dll::batch_size<25>>::layer_t, + dll::rbm_desc<200, 100, dll::momentum, dll::batch_size<25>>::layer_t + >, + dll::batch_size<50>, + dll::batch_mode>::dbn_t; + + auto ae = std::make_unique(); + + ae->display(); + + ae->template layer_get<0>().learning_rate = 0.001; + ae->template layer_get<0>().initial_momentum = 0.9; + + ae->template layer_get<1>().learning_rate = 0.001; + ae->template layer_get<1>().initial_momentum = 0.9; + + ae->pretrain_denoising_auto(dataset.training_images, 10, 0.3); +} + } //end of anonymous namespace int main(int /*argc*/, char* /*argv*/ []) { @@ -65,6 +90,7 @@ int main(int /*argc*/, char* /*argv*/ []) { std::cout << n << " samples to test" << std::endl; mnist::binarize_dataset(dataset); + rbm_cdae_batch(dataset); rbm_dae_batch(dataset); rbm_dae(dataset);