Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
wichtounet committed Jul 20, 2017
1 parent 3fecc3d commit ec0e2ee
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions include/mnist/mnist_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ struct MNIST_dataset {
* \param images The container to fill with the images
* \param path The path to the image file
* \param limit The maximum number of elements to read (0: no limit)
* \param start The elements to ignore at the beginning
* \param func The functor to create the image object
*/
template <typename Container>
bool read_mnist_image_file_flat(Container& images, const std::string& path, std::size_t limit) {
bool read_mnist_image_file_flat(Container& images, const std::string& path, std::size_t limit, std::size_t start = 0) {
auto buffer = read_mnist_file(path, 0x803);

if (buffer) {
Expand All @@ -91,6 +92,9 @@ bool read_mnist_image_file_flat(Container& images, const std::string& path, std:
count = limit;
}

// Ignore "start" first elements
image_buffer += start * (rows * columns);

for (size_t i = 0; i < count; ++i) {
for (size_t j = 0; j < rows * columns; ++j) {
images(i)[j] = *image_buffer++;
Expand Down Expand Up @@ -209,9 +213,10 @@ bool read_mnist_label_file_flat(Container& labels, const std::string& path, std:
* \param labels The container to fill with the labels
* \param path The path to the label file
* \param limit The maximum number of elements to read (0: no limit)
* \param start The elements to avoid at the beginning
*/
template <typename Container>
bool read_mnist_label_file_categorical(Container& labels, const std::string& path, std::size_t limit = 0) {
bool read_mnist_label_file_categorical(Container& labels, const std::string& path, std::size_t limit = 0, std::size_t start = 0) {
auto buffer = read_mnist_file(path, 0x801);

if (buffer) {
Expand All @@ -226,6 +231,9 @@ bool read_mnist_label_file_categorical(Container& labels, const std::string& pat
count = limit;
}

// Ignore "start" first elements
label_buffer += start;

for (size_t i = 0; i < count; ++i) {
labels(i)(static_cast<size_t>(*label_buffer++)) = 1;
}
Expand Down

0 comments on commit ec0e2ee

Please sign in to comment.