-
Notifications
You must be signed in to change notification settings - Fork 22.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[C++ Frontend] Implement DataLoader #11918
Conversation
dae4b10
to
561fa55
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
goldsborough has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
7d5e865
to
d87492e
Compare
DataLoader( | ||
DatasetType dataset, | ||
DataLoaderOptions options, | ||
SamplerType sampler) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
template < | ||
typename DatasetType, | ||
typename SamplerType = samplers::RandomSampler<>> | ||
class DataLoader { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
template < | ||
typename DatasetType, | ||
typename SamplerType = samplers::RandomSampler<>, |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
namespace datasets { | ||
class TensorDataset : public Dataset<TensorDataset, TensorExample> { | ||
public: | ||
explicit TensorDataset(std::vector<Tensor> tensors) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
class DataShuttle { | ||
public: | ||
using JobType = Job; | ||
using ResultType = Result; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
using typename BatchDataset<S, B>::BatchType; | ||
using ExampleType = E; | ||
|
||
virtual ExampleType index(size_t index) = 0; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Tensor apply_batch(std::vector<Tensor> tensors) override { | ||
return torch::stack(tensors); | ||
} | ||
}; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
OutputBatchType output_batch; | ||
torch::detail::reserve_capacity(output_batch, input_batch.size()); | ||
for (auto&& input : input_batch) { | ||
output_batch.insert(output_batch.end(), apply(std::move(input))); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
return torch::from_blob(buffer.data(), buffer.size(), torch::kByte) | ||
.reshape({count, 1, kImageRows, kImageColumns}) | ||
.to(torch::kFloat32) | ||
.div(255); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
std::vector<char> buffer(count); | ||
labels.read(buffer.data(), count); | ||
return torch::from_blob(buffer.data(), count, torch::kByte).to(torch::kInt64); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
78b49ed
to
4c702b9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly LGTM. I have some comments and suggestions with varying levels of priority, so it would be great if you could take a look at them and see if you want to address those.
enforce_ordering(options.enforce_ordering_) {} | ||
|
||
size_t batch_size; | ||
bool drop_last; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
/// (`max_jobs` => `2 * workers`). In the spirit of properly using the C++ type | ||
/// system, `DataLoaderOptions` allows only setting values. To access values, | ||
/// you must create a `FullDataLoaderOptions` from a `DataLoaderOptions` | ||
/// instance, which will do any necessary coalescing. |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
/// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(), | ||
/// output_iterator)` are supported too. | ||
Iterator<Batch> begin() { | ||
reset(); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
/// Joins the `DataLoader`'s worker threads and drains internal queues. | ||
void join() { | ||
if (joined_) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
SourceDataset dataset; | ||
AppliedTransform transform; | ||
}; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
/// Fetches the next batch. | ||
void next() override { | ||
// If we didn't get the very first batch yet, get it now. | ||
lazy_initialize(); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
/// A `Sampler` is an object that yields indices with which to index into a | ||
/// dataset. | ||
class Sampler { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
void reset() override { | ||
index_ = 0; | ||
} |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
/// | ||
/// auto dataset = datasets::MNIST("path/to/mnist") | ||
/// .map(transforms::Collate<Example<>>([](std::vector<Example<>> e) { | ||
/// return std::move(e.front()); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
/// A `Collation` for `Example<Tensor, Tensor>` types that stacks all data | ||
/// tensors into one tensor, and all target (label) tensors into one tensor. | ||
template <> | ||
struct Stack<Example<>> : public Collation<Example<>> { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
/// Resets the `Sampler`'s internal state. | ||
/// Typically called before a new epoch. | ||
virtual void reset() = 0; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
0c7fe9c
to
d6f81b9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I try to compile this PR, I get the following error:
FAILED: caffe2/torch/CMakeFiles/torch.dir/csrc/api/src/data/datasets/mnist.cpp.o
/usr/bin/c++ -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMAGMA_V2 -DONNX_NAMESPACE=onnx_torch -DTH_BLAS_MKL -DUSE_CUDA -DUSE_GCC_ATOMICS=1 -D_FILE_OFFSET_BITS=64 -D_THP_CORE -Dtorch_EXPORTS -I../aten/src -I. -I../ -I../third_party/protobuf/src -I../cmake/../third_party/benchmark/include -Icaffe2/contrib/aten -I../third_party/onnx -Ithird_party/onnx -I../torch/csrc/api -I../torch/csrc/api/include -I../torch/../aten/src/TH -Icaffe2/torch/../aten/src/TH -I../torch/../aten/src -Icaffe2/torch/../aten/src -Iaten/src -I../torch/../aten/src/THC -Icaffe2/torch/../aten/src/THC -Icaffe2/torch/../aten/src/ATen -I../torch/csrc -I../c10/.. -isystem third_party/gloo -isystem ../cmake/../third_party/gloo -isystem /home/thiagofc/PyTorch/Miniconda3-4.5.4-Linux-x86_64/envs/PyTorch_dev_3.6/include -isystem ../cmake/../third_party/googletest/googletest/include -isystem ../cmake/../third_party/eigen -isystem /home/thiagofc/PyTorch/Miniconda3-4.5.4-Linux-x86_64/envs/PyTorch_dev_3.6/include/python3.6m -isystem /home/thiagofc/PyTorch/Miniconda3-4.5.4-Linux-x86_64/envs/PyTorch_dev_3.6/lib/python3.6/site-packages/numpy/core/include -isystem ../cmake/../third_party/pybind11/include -isystem ../cmake/../third_party/cub -isystem /usr/local/cuda-9.0/include --std=c++11 -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -O2 -fPIC -Wno-narrowing -Wall -Wextra -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -Wno-unused-but-set-variable -Wno-maybe-uninitialized -DHAVE_AVX_CPU_DEFINITION -DHAVE_AVX2_CPU_DEFINITION -O3 -fPIC -DCUDA_HAS_FP16=1 -DHAVE_GCC_GET_CPUID -DUSE_AVX -DUSE_AVX2 -DTH_HAVE_THREAD -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-missing-field-initializers -Wno-write-strings -Wno-unknown-pragmas -Wno-missing-braces -Wno-maybe-uninitialized -std=gnu++11 -MD -MT caffe2/torch/CMakeFiles/torch.dir/csrc/api/src/data/datasets/mnist.cpp.o -MF caffe2/torch/CMakeFiles/torch.dir/csrc/api/src/data/datasets/mnist.cpp.o.d -o caffe2/torch/CMakeFiles/torch.dir/csrc/api/src/data/datasets/mnist.cpp.o -c ../torch/csrc/api/src/data/datasets/mnist.cpp
../torch/csrc/api/src/data/datasets/mnist.cpp:6:24: fatal error: ATen/Error.h: No such file or directory
compilation terminated.
Aten/Error.h seems to be removed from master and this PR too:
$ find . -name Error.h
./third_party/ComputeLibrary/arm_compute/core/Error.h
./third_party/ComputeLibrary/arm_compute/graph/Error.h
d6f81b9
to
2fabdad
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
goldsborough is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
When using the updated
|
This PR implements a DataLoader API for the C++ frontend.
The components present in this API largely match the Python API. It consists of:
Dataset
s: Conceptually a function from a set of indices to a batch of examples;Transform
s: A functional transformation of a dataset. AMap<D, T>
for DatasetD
and transformT
is itself a dataset;Sampler
s: Specify a strategy for generating indices for a new batch;DataLoader
, with the ability to automatically parallelize fetching of samples across multiple worker threads;Note that collation functions fall naturally out of the
Map<Dataset, Transform>
abstraction.Things that are missing right now that maybe should be added:
The API was designed to be generalizable to almost any kind of dataset, transform or sampling strategy, while providing a convenient API out of the box. To achieve this, it is quite heavily templatized on various possible input types.
There are many parts to this PR! Right now, I would like feedback on:
I haven't added too many comments yet, as this is fresh out of the oven. Let me know if anything is unclear from the code itself.
There also aren't any tests yet. I will write a comprehensive test suite once we agree on the API and implementation.
@apaszke @ezyang @teng-li @pietern