diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp index 9601911e0dec8..bb8552d38d225 100644 --- a/test/cpp/api/dataloader.cpp +++ b/test/cpp/api/dataloader.cpp @@ -62,10 +62,12 @@ TEST(DataTest, TransformCallsGetApplyCorrectly) { // dummy chunk data reader with 3 chunks and 35 examples in total. Each chunk // contains 10, 5, 20 examples respectively. + struct DummyChunkDataReader - : public datasets::ChunkDataReader> { + : public datasets::ChunkDataReader { public: - using BatchType = std::vector; + using BatchType = datasets::ChunkDataReader::ChunkType; + using DataType = datasets::ChunkDataReader::ExampleType; /// Read an entire chunk. BatchType read_chunk(size_t chunk_index) override { @@ -1650,7 +1652,7 @@ TEST(DataLoaderTest, ChunkDataSetGetBatch) { for (auto iterator = data_loader->begin(); iterator != data_loader->end(); ++iterator, ++iteration_count) { - std::vector& batch = *iterator; + DummyChunkDataReader::BatchType& batch = *iterator; ASSERT_EQ(batch.size(), batch_size); // When prefetch_count is equal to 1 and no worker thread, the batch @@ -1709,9 +1711,9 @@ TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) { TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) { struct DummyEmptyChunkDataReader - : datasets::ChunkDataReader> { + : datasets::ChunkDataReader { public: - using BatchType = std::vector; + using BatchType = datasets::ChunkDataReader::ChunkType; BatchType read_chunk(size_t chunk_index) override { return {}; @@ -1752,9 +1754,9 @@ TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) { } TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) { - struct D : public datasets::ChunkDataReader> { + struct D : public datasets::ChunkDataReader { public: - using BatchType = std::vector; + using BatchType = datasets::ChunkDataReader::ChunkType; BatchType read_chunk(size_t chunk_index) override { BatchType batch_data(10, 0); @@ -1791,7 +1793,7 @@ TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) { for (auto iterator = data_loader->begin(); iterator != data_loader->end(); ++iterator) { - std::vector batch = *iterator; + DummyChunkDataReader::BatchType batch = *iterator; auto batch_size = batch.size(); if (batch_size == 17) { ASSERT_TRUE(batch.size() == 17 || batch.size() == 3); @@ -1825,8 +1827,8 @@ TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) { samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler(); auto data_loader = torch::data::make_data_loader( - dataset.map(transforms::BatchLambda, int>( - [](std::vector batch) { + dataset.map(transforms::BatchLambda( + [](DummyChunkDataReader::BatchType batch) { return std::accumulate(batch.begin(), batch.end(), 0); })), DataLoaderOptions(batch_size).workers(0)); @@ -1869,8 +1871,8 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) { samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler(); auto data_loader = torch::data::make_data_loader( - dataset.map(transforms::BatchLambda, int>( - [](std::vector batch) { + dataset.map(transforms::BatchLambda( + [](DummyChunkDataReader::BatchType batch) { return std::accumulate(batch.begin(), batch.end(), 0); })), DataLoaderOptions(batch_size).workers(0)); @@ -1878,4 +1880,4 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) { // to fill the batch buffer but it is not draining. Still we need to exit // cleanly. auto iterator = data_loader->begin(); -} +} \ No newline at end of file diff --git a/torch/csrc/api/include/torch/data/datasets/chunk.h b/torch/csrc/api/include/torch/data/datasets/chunk.h index b5d66f8738869..74142591bf997 100644 --- a/torch/csrc/api/include/torch/data/datasets/chunk.h +++ b/torch/csrc/api/include/torch/data/datasets/chunk.h @@ -1,6 +1,11 @@ #pragma once +#include +#include #include +#include +#include +#include namespace torch { namespace data { @@ -12,10 +17,11 @@ namespace datasets { /// A chunk could be an entire file, such as an audio data file or an image, /// or part of a file in the case of a large text-file split based on seek /// positions. -template >> +template > class ChunkDataReader { public: - using ChunkType = Chunk; + using ChunkType = ChunkType_; + using ExampleType = ExampleType_; /// Read an entire chunk. virtual ChunkType read_chunk(size_t chunk_index) = 0; @@ -34,7 +40,7 @@ namespace detail { /// return. If the cache is empty, it either waits to load more chunks or return /// null if all chunks are loaded. template < - typename UnwrappedBatch = std::vector>, + typename UnwrappedBatch, typename ExampleSampler = samplers::RandomSampler> class BatchDataBuffer { public: