Skip to content

Commit

Permalink
added first draft of train-test data split
Browse files Browse the repository at this point in the history
  • Loading branch information
lambday authored and karlnapf committed Jul 4, 2016
1 parent d35d8df commit af0b2bf
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 6 deletions.
79 changes: 75 additions & 4 deletions src/shogun/statistical_testing/internals/DataFetcher.cpp
Expand Up @@ -24,15 +24,16 @@
using namespace shogun;
using namespace internal;

/**
 * Default constructor (declared protected in the class).
 * Creates a fetcher with no samples attached: the sample count is zero,
 * the samples pointer is null and no train/test subset is marked as active.
 */
DataFetcher::DataFetcher()
	: m_num_samples(0), m_samples(nullptr), train_test_subset_used(false)
{
}

/**
 * Creates a fetcher for the given samples.
 *
 * Takes a reference on the features object, caches its number of vectors
 * as the current sample count and records that count as the total number
 * of samples in the train-test details.
 *
 * @param samples the features to fetch from; must not be null, otherwise
 * REQUIRE reports "Samples cannot be null!"
 */
DataFetcher::DataFetcher(CFeatures* samples)
	: m_samples(samples), train_test_subset_used(false)
{
	REQUIRE(m_samples!=nullptr, "Samples cannot be null!\n");
	SG_REF(m_samples);
	m_num_samples=m_samples->get_num_vectors();
	m_train_test_details.set_total_num_samples(m_num_samples);
}

DataFetcher::~DataFetcher()
Expand All @@ -46,12 +47,82 @@ const char* DataFetcher::get_name() const
return "DataFetcher";
}

/**
 * Sets how the total number of samples is divided between training and
 * test data.
 *
 * For a ratio r, the number of training samples is total*r/(r+1), truncated
 * on conversion to index_t; the remainder counts as test samples. A ratio
 * of 0 therefore means "no training samples".
 *
 * As a side effect, m_num_samples is reset to the total number of samples.
 * This method does not add or remove any feature subset itself.
 *
 * @param train_test_ratio ratio of training to test samples; must be
 * non-negative
 */
void DataFetcher::set_train_test_ratio(float64_t train_test_ratio)
{
	// Reject negative ratios before touching any state: they would yield a
	// negative training count, and -1 would divide by zero. The comparison is
	// also false for NaN, so NaN is rejected as well.
	REQUIRE(train_test_ratio>=0, "Train-test ratio (%f) cannot be negative!\n", train_test_ratio);

	m_num_samples=m_train_test_details.get_total_num_samples();
	index_t num_training_samples=m_num_samples*train_test_ratio/(train_test_ratio+1);
	m_train_test_details.set_num_training_samples(num_training_samples);
}

/**
 * @return the number of training samples divided by the number of test
 * samples, computed in floating point from the stored train-test details.
 * NOTE(review): when there are no test samples the divisor is zero; the
 * result is then whatever floating-point division by zero yields.
 */
float64_t DataFetcher::get_train_test_ratio() const
{
	const float64_t num_training=m_train_test_details.get_num_training_samples();
	const index_t num_test=m_train_test_details.get_num_test_samples();
	return num_training/num_test;
}

void DataFetcher::set_train_mode(bool train_mode)
{
index_t start_index=0;
if (train_mode)
{
m_num_samples=m_train_test_details.get_num_training_samples();
if (m_num_samples==0)
SG_SERROR("The number of training samples is 0! Please set a valid train-test ratio\n");
}
else
{
m_num_samples=m_train_test_details.get_num_test_samples();
start_index=m_train_test_details.get_num_training_samples();
if (start_index==0)
{
if (train_test_subset_used)
{
m_samples->remove_subset();
train_test_subset_used=false;
}
return;
}
}
SGVector<index_t> inds(m_num_samples);
std::iota(inds.data(), inds.data()+inds.size(), start_index);
if (train_test_subset_used)
{
m_samples->remove_subset();
}
m_samples->add_subset(inds);
train_test_subset_used=true;
}

// Placeholder for switching this fetcher into or out of cross-validation mode.
// The body contains no executable statements, so the call currently has no
// effect and xvalidation_mode is ignored.
// NOTE(review): the disabled snippet below iterates over `fetchers` and
// captures `train_mode`, neither of which is declared in this function; it
// looks like it was copied from DataManager - confirm before enabling it.
void DataFetcher::set_xvalidation_mode(bool xvalidation_mode)
{
// using fetcher_type=std::unique_ptr<DataFetcher>;
// std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f)
// {
// f->set_xvalidation_mode(xvalidation_mode);
// });
}

// Placeholder for querying the number of cross-validation folds.
// Always returns 0 for now.
// NOTE(review): the disabled snippet below refers to `fetchers` and forwards
// to get_train_test_ratio(); it looks like it was copied from DataManager -
// confirm the intended implementation before enabling it.
index_t DataFetcher::get_num_folds() const
{
// REQUIRE(fetchers[0]!=nullptr, "Please set the samples first!\n");
// return fetchers[0]->get_train_test_ratio();
return 0;
}

// Placeholder for selecting the cross-validation fold with index idx.
// The body contains no executable statements, so the call currently has no
// effect and idx is ignored.
// NOTE(review): the disabled snippet below iterates over `fetchers` and
// captures `train_mode`, neither of which is declared in this function; it
// looks like it was copied from DataManager - confirm before enabling it.
void DataFetcher::use_fold(index_t idx)
{
// using fetcher_type=std::unique_ptr<DataFetcher>;
// std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f)
// {
// f->use_fold(idx);
// });
}

void DataFetcher::start()
{
REQUIRE(m_num_samples>0, "Number of samples is 0!\n");
if (m_block_details.m_blocksize==0)
if (m_block_details.m_blocksize==0 || m_block_details.m_blocksize>m_num_samples)
{
SG_SINFO("Block details not set! Fetching entire data (%d samples)!\n", m_num_samples);
SG_SINFO("Block details invalid! Fetching entire data (%d samples)!\n", m_num_samples);
m_block_details.with_blocksize(m_num_samples);
}
m_block_details.m_total_num_blocks=m_num_samples/m_block_details.m_blocksize;
Expand Down
10 changes: 10 additions & 0 deletions src/shogun/statistical_testing/internals/DataFetcher.h
Expand Up @@ -30,6 +30,7 @@

#include <shogun/lib/common.h>
#include <shogun/statistical_testing/internals/BlockwiseDetails.h>
#include <shogun/statistical_testing/internals/TrainTestDetails.h>

#ifndef DATA_FETCHER_H__
#define DATA_FETCHER_H__
Expand All @@ -51,6 +52,13 @@ class DataFetcher
public:
DataFetcher(CFeatures* samples);
virtual ~DataFetcher();
void set_train_test_ratio(float64_t train_test_ratio);
float64_t get_train_test_ratio() const;
void set_train_mode(bool train_mode);
void set_xvalidation_mode(bool xvalidation_mode);
index_t get_num_folds() const;
void use_fold(index_t idx);

virtual void start();
virtual CFeatures* next();
virtual void reset();
Expand All @@ -61,9 +69,11 @@ class DataFetcher
protected:
DataFetcher();
BlockwiseDetails m_block_details;
TrainTestDetails m_train_test_details;
index_t m_num_samples;
private:
CFeatures* m_samples;
bool train_test_subset_used;
};

}
Expand Down
37 changes: 35 additions & 2 deletions src/shogun/statistical_testing/internals/DataManager.cpp
Expand Up @@ -41,6 +41,9 @@
using namespace shogun;
using namespace internal;

// TODO: add a nullptr check before calling methods on the actual fetchers.
// This covers the case where someone calls the other methods before setting the samples.

DataManager::DataManager(size_t num_distributions)
{
SG_SDEBUG("Data manager instance initialized with %d data sources!\n", num_distributions);
Expand Down Expand Up @@ -195,28 +198,58 @@ const index_t DataManager::blocksize_at(size_t i) const

/**
 * Forwards the given train-test ratio to every data fetcher.
 *
 * All fetchers are validated before any of them is modified, so a missing
 * fetcher results in an error without leaving the fetchers with different
 * ratios.
 *
 * @param train_test_ratio ratio of training to test samples, passed
 * unchanged to DataFetcher::set_train_test_ratio
 */
void DataManager::set_train_test_ratio(float64_t train_test_ratio)
{
	SG_SDEBUG("Entering!\n");
	// A fetcher slot is null until its samples have been set; dereferencing
	// it would be undefined behaviour.
	for (const auto& fetcher : fetchers)
	{
		REQUIRE(fetcher!=nullptr, "Please set the samples first!\n");
	}
	for (auto& fetcher : fetchers)
	{
		fetcher->set_train_test_ratio(train_test_ratio);
	}
	SG_SDEBUG("Leaving!\n");
}

/**
 * @return the train-test ratio reported by the first data fetcher.
 * Reports an error if there is no first fetcher or it has not been set yet.
 */
float64_t DataManager::get_train_test_ratio() const
{
	// Guard both against an empty fetcher list and against an unset first slot
	// before indexing and dereferencing.
	REQUIRE(!fetchers.empty() && fetchers[0]!=nullptr, "Please set the samples first!\n");
	return fetchers[0]->get_train_test_ratio();
}

void DataManager::set_train_mode(bool train_mode)
{
SG_SDEBUG("Entering!\n");
using fetcher_type=std::unique_ptr<DataFetcher>;
std::for_each(fetchers.begin(), fetchers.end(), [&train_mode](fetcher_type& f)
{
f->set_train_mode(train_mode);
});
SG_SDEBUG("Leaving!\n");
}

void DataManager::set_xvalidation_mode(bool xvalidation_mode)
{
SG_SDEBUG("Entering!\n");
using fetcher_type=std::unique_ptr<DataFetcher>;
std::for_each(fetchers.begin(), fetchers.end(), [&xvalidation_mode](fetcher_type& f)
{
f->set_xvalidation_mode(xvalidation_mode);
});
SG_SDEBUG("Leaving!\n");
}

/**
 * @return the number of cross-validation folds reported by the first data
 * fetcher. Reports an error if there is no first fetcher or it has not been
 * set yet.
 */
index_t DataManager::get_num_folds() const
{
	// Guard both against an empty fetcher list and against an unset first slot
	// before indexing and dereferencing.
	REQUIRE(!fetchers.empty() && fetchers[0]!=nullptr, "Please set the samples first!\n");
	// Ask for the fold count itself; forwarding to get_train_test_ratio() would
	// return an unrelated floating-point ratio truncated to index_t.
	return fetchers[0]->get_num_folds();
}

/**
 * Asks every data fetcher to use the cross-validation fold with the given
 * index.
 *
 * All fetchers are validated before any of them is called, so a missing
 * fetcher results in an error before any fetcher has been touched.
 *
 * @param idx fold index, passed unchanged to DataFetcher::use_fold
 */
void DataManager::use_fold(index_t idx)
{
	SG_SDEBUG("Entering!\n");
	// A fetcher slot is null until its samples have been set; dereferencing
	// it would be undefined behaviour.
	for (const auto& fetcher : fetchers)
	{
		REQUIRE(fetcher!=nullptr, "Please set the samples first!\n");
	}
	for (auto& fetcher : fetchers)
	{
		fetcher->use_fold(idx);
	}
	SG_SDEBUG("Leaving!\n");
}

void DataManager::start()
Expand Down
78 changes: 78 additions & 0 deletions src/shogun/statistical_testing/internals/TrainTestDetails.cpp
@@ -0,0 +1,78 @@
/*
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (w) 2016 Soumyajit De
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*/

#include <shogun/io/SGIO.h>
#include <shogun/statistical_testing/internals/TrainTestDetails.h>

using namespace shogun;
using namespace internal;

/** Starts with zero total samples and zero training samples. */
TrainTestDetails::TrainTestDetails()
	: m_total_num_samples(0),
	  m_num_training_samples(0)
{
}

/**
 * Stores the total number of samples. The stored number of training samples
 * is left untouched.
 *
 * @param total_num_samples total number of samples available
 */
void TrainTestDetails::set_total_num_samples(index_t total_num_samples)
{
	this->m_total_num_samples=total_num_samples;
}

/** @return the stored total number of samples */
index_t TrainTestDetails::get_total_num_samples() const
{
	return this->m_total_num_samples;
}

/**
 * Stores the number of training samples.
 *
 * @param num_training_samples number of samples reserved for training; must
 * lie between 0 and the stored total number of samples (inclusive),
 * otherwise REQUIRE reports an error and the stored value is unchanged
 */
void TrainTestDetails::set_num_training_samples(index_t num_training_samples)
{
	// A negative count would pass the upper-bound check below and make
	// get_num_test_samples() exceed the total.
	REQUIRE(num_training_samples>=0,
		"Number of training samples (%d) cannot be negative!\n", num_training_samples);
	REQUIRE(m_total_num_samples>=num_training_samples,
		"Number of training samples cannot be greater than the total number of samples!\n");
	m_num_training_samples=num_training_samples;
}

/** @return the stored number of training samples */
index_t TrainTestDetails::get_num_training_samples() const
{
	return this->m_num_training_samples;
}

/**
 * @return the number of test samples, i.e. the total number of samples minus
 * the number of training samples
 */
index_t TrainTestDetails::get_num_test_samples() const
{
	const index_t num_test_samples=m_total_num_samples-m_num_training_samples;
	return num_test_samples;
}
//
//bool TrainTestDetails::is_training_mode() const
//{
//}
//
//void TrainTestDetails::set_train_mode(bool train_mode)
//{
//}
//
//void TrainTestDetails::set_xvalidation_mode(bool xvalidation_mode)
//{
//}
72 changes: 72 additions & 0 deletions src/shogun/statistical_testing/internals/TrainTestDetails.h
@@ -0,0 +1,72 @@
/*
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (w) 2016 Soumyajit De
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*/

#include <shogun/lib/common.h>

#ifndef TRAIN_TEST_DETAILS_H__
#define TRAIN_TEST_DETAILS_H__

namespace shogun
{

namespace internal
{

/**
 * @brief Class that holds train-test details for the data-fetchers.
 * There is one instance of this class per fetcher.
 *
 * It stores the total number of samples and how many of them are used for
 * training; the number of test samples is the difference of the two.
 */
class TrainTestDetails
{
	friend class DataFetcher;
	friend class StreamingDataFetcher;

public:
	/** Creates details with zero total samples and zero training samples. */
	TrainTestDetails();

	/**
	 * Sets the total number of samples.
	 * @param total_num_samples total number of samples available
	 */
	void set_total_num_samples(index_t total_num_samples);

	/** @return the total number of samples */
	index_t get_total_num_samples() const;

	/**
	 * Sets the number of training samples.
	 * @param num_training_samples number of samples used for training; must
	 * not be greater than the total number of samples
	 */
	void set_num_training_samples(index_t num_training_samples);

	/** @return the number of training samples */
	index_t get_num_training_samples() const;

	/** @return total number of samples minus number of training samples */
	index_t get_num_test_samples() const;

	// bool is_training_mode() const;
	// void set_train_mode(bool train_mode);
	// void set_xvalidation_mode(bool xvalidation_mode);
private:
	/** total number of samples available */
	index_t m_total_num_samples;

	/** number of samples (out of the total) used for training */
	index_t m_num_training_samples;
};

}

}
#endif // TRAIN_TEST_DETAILS_H__

0 comments on commit af0b2bf

Please sign in to comment.