Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions test/cpp/api/dataloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1620,3 +1620,44 @@ TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
}
}
}

TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
  const size_t kPrefetchCount = 2;
  const size_t kBatchSize = 5;

  DummyChunkDataReader reader;
  samplers::SequentialSampler sequential(0);
  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          reader,
          sequential,
          sequential,
          datasets::ChunkDatasetOptions(kPrefetchCount, kBatchSize));

  // Grab a reference to the dataset's internal chunk sampler so we can
  // observe its progress while the loader drains the dataset.
  samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();

  auto data_loader = torch::data::make_data_loader(
      dataset.map(transforms::BatchLambda<std::vector<int>, int>(
          [](std::vector<int> batch) {
            return std::accumulate(batch.begin(), batch.end(), 0);
          })),
      DataLoaderOptions(kBatchSize).workers(0));

  // Nothing has been consumed yet, so the sampler index starts at 0.
  ASSERT_EQ(chunk_sampler.index(), 0);

  size_t sum = 0;
  for (const auto& batch_sum : *data_loader) {
    sum += batch_sum;
  }
  ASSERT_EQ(sum, 595); // sum([0, 35))
  // 3 chunks, and when exhausted the value is already incremented.
  ASSERT_EQ(chunk_sampler.index(), 3);
}
95 changes: 43 additions & 52 deletions torch/csrc/api/include/torch/data/datasets/chunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,10 @@ class BatchDataBuffer {
using BatchRequestType = typename ExampleSampler::BatchRequestType;

BatchDataBuffer(
size_t num_chunks,
size_t batch_size,
ExampleSampler& example_sampler,
size_t queue_capacity)
: remaining_chunk_count_(num_chunks),
batch_size_(batch_size),
: batch_size_(batch_size),
example_sampler_(example_sampler),
queue_capacity_(queue_capacity),
stop_(false) {}
Expand All @@ -62,11 +60,10 @@ class BatchDataBuffer {
// loaded (i.e. the dataset is exhausted for this epoch)
return (
this->total_example_count_in_queue_ >= batch_size_ ||
this->remaining_chunk_count_ == 0);
this->stop_.load());
});
if (batch_queue_.empty()) {
AT_ASSERT(remaining_chunk_count_ == 0);

AT_ASSERT(this->stop_.load());
// All batches have been retrieved. Return an empty batch.
return nullopt;
}
Expand All @@ -84,26 +81,18 @@ class BatchDataBuffer {
return batch.batch_data;
}

// skip one chunk
void skip_chunk() {
std::unique_lock<std::mutex> lock(queue_mutex_);
AT_ASSERT(remaining_chunk_count_ > 0);
remaining_chunk_count_--;
lock.unlock();
cv_read_.notify_all();
}

/// Push preloaded chunks to batch queue. Called from the ChunkDataset worker
/// threads.
void add_chunk_data(UnwrappedBatchType data) {
std::unique_lock<std::mutex> lock(queue_mutex_);
cv_write_.wait(lock, [this] {
// stop loading if we have preloaded enough data.
return this->total_example_count_in_queue_ < this->queue_capacity_ || stop_.load();
return this->total_example_count_in_queue_ < this->queue_capacity_ ||
stop_.load();
});

if (stop_.load()){
// When stop_ is true, it means this current thread needs to be tore down.
// When stop_ is true, it means no further chunk loading is necessary.
// Return without any further processing.
return;
}
Expand Down Expand Up @@ -150,10 +139,6 @@ class BatchDataBuffer {
batch_queue_.emplace(std::move(current_batch));
}
total_example_count_in_queue_ += data_size;

AT_ASSERT(remaining_chunk_count_ > 0);
remaining_chunk_count_--;

lock.unlock();
cv_read_.notify_all();
}
Expand All @@ -175,9 +160,6 @@ class BatchDataBuffer {
}

batch_queue_.emplace(e_ptr);

AT_ASSERT(remaining_chunk_count_ > 0);
remaining_chunk_count_--;
lock.unlock();
cv_read_.notify_all();
}
Expand All @@ -187,13 +169,10 @@ class BatchDataBuffer {

// notify all writers, wake them from wait to exit current method.
cv_write_.notify_all();
// notify all readers too.
cv_read_.notify_all();
}

/// count of remaining chunk to be loaded. It is initialized with the total
/// chunk count and it decreases when a chunk data is retrieved. When this reaches
/// to 0, no more chunk needs to be loaded.
size_t remaining_chunk_count_ = 0;


/// The batch size is needed to create batches from the chunk data. Similar to
/// regular dataloader where the batches are created with prefetches,
/// BatchDataBuffer perform the batch creation using the provided batch size.
Expand Down Expand Up @@ -310,8 +289,8 @@ class ChunkDataset final
chunk_sampler_(std::move(chunk_sampler)),
example_sampler_(std::move(example_sampler)),
options_(std::move(options)),
quit_worker_(false) {
}
quit_worker_(false),
running_preloaders_(0) {}

virtual ~ChunkDataset() {
free_workers();
Expand Down Expand Up @@ -344,24 +323,23 @@ class ChunkDataset final

chunk_reader_.reset();

size_t chunks_to_load = chunk_reader_.chunk_count();
chunk_sampler_.reset(chunks_to_load);
chunk_sampler_.reset(chunk_reader_.chunk_count());

// Throw out any existing cached batch in the buffer and re-creates a new
// chunk buffer.
batch_buffer_ = torch::make_unique<
detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
chunks_to_load,
options_.batch_size_,
example_sampler_,
options_.cache_size_);

// create new workers for this new epoch.
quit_worker_ = false;

AT_ASSERT(running_preloaders_ == 0);
for (size_t i = 0; i < options_.preloader_count_; ++i) {
preload_threads_.emplace_back(
[this, i]() { this->preloader(i); });
preload_threads_.emplace_back([this, i]() { this->preloader(i); });
++running_preloaders_;
}
}

Expand All @@ -370,39 +348,45 @@ class ChunkDataset final
return torch::nullopt;
}

// provide a references to chunk sampler. Used mainly in distributed data
// loading to set the epoch number for the sampler.
ChunkSamplerType& chunk_sampler() {
return chunk_sampler_;
}

private:
/// running on worker thread to preload chunk data.
void preloader(size_t id) {
while (!quit_worker_.load()) {
try {
size_t chunk_id = 0;
if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
chunk_id = chunk_sampler_result.value()[0];
} else {
break;
{
std::lock_guard<std::mutex> lock(chunk_index_guard_);
if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
chunk_id = chunk_sampler_result.value()[0];
} else {
break;
}
}
UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_id);
if (data.empty()) {
// if the chunk is empty, skip the current chunk data and move on to
// the next.
batch_buffer_->skip_chunk();
}
else {
if (!data.empty()) { // skip empty chunks.
batch_buffer_->add_chunk_data(std::move(data));
}
} catch (...) {
batch_buffer_->add_chunk_data(std::current_exception());
}
}
--running_preloaders_;
if (running_preloaders_.load() == 0) {
// all preloaders are completed, so we can notify the batch_buffer.
batch_buffer_->stop();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we only stop when the sampler is exhausted. It is also possible that the program wants to exit in the middle of a sweep. In this scenario, stop_ is never switched to true, which could cause a hang in the join() called from the destructor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The free_workers() call makes the threads exit, and then it triggers stop(). This happens at every reset() and also in the destructor of the chunk dataset.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go through this scenario: in the middle of a sweep, the worker is waiting inside add_chunk_data() because the current buffer contains enough data. At this point, the user decides to exit the current sweep and start a new one. Upon exiting, no more get_batch() is called and the worker thread keeps waiting. At reset(), ChunkDataset calls free_workers(), which calls join() to wait for the workers to finish. Because the worker is still in the condition-variable wait and no notification is ever triggered, the join() will hang the program.
The original code called stop() in free_workers() before join(), which breaks the wait, sends the notification, and resolves the hang.

}
}

/// Block the current thread until the workers finish execution and exit.
void free_workers() {
if (!quit_worker_.load()) {
quit_worker_ = true;
if(batch_buffer_){
batch_buffer_->stop();
}
quit_worker_ = true;
for (auto& worker_thread : preload_threads_) {
worker_thread.join();
}
Expand All @@ -416,7 +400,7 @@ class ChunkDataset final
ChunkReader chunk_reader_;

// chunk sampler to shuffle different chunks
samplers::LockedSampler<ChunkSamplerType> chunk_sampler_;
ChunkSamplerType chunk_sampler_;

// example sampler to shuffle examples in a specific chunk
ExampleSamplerType example_sampler_;
Expand All @@ -433,6 +417,13 @@ class ChunkDataset final

// indicate whether the worker thread can be teared down
std::atomic<bool> quit_worker_;

// keep track of running preloaders to notify batch buffer. A value 0
// indicates that the chunk loading is completed.
std::atomic<size_t> running_preloaders_;

// mutex to synchronize chunk sampler next() call.
std::mutex chunk_index_guard_;
};
} // namespace datasets
} // namespace data
Expand Down
35 changes: 0 additions & 35 deletions torch/csrc/api/include/torch/data/samplers/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,41 +42,6 @@ class Sampler {
TORCH_API virtual void load(serialize::InputArchive& archive) = 0;
};

/// Wraps a provided sampler to make it thread safe.
template <typename OriginalSampler>
class LockedSampler
    : public Sampler<typename OriginalSampler::BatchRequestType> {
 public:
  using BatchRequestType = typename OriginalSampler::BatchRequestType;

  explicit LockedSampler(OriginalSampler sampler)
      : sampler_(std::move(sampler)) {}

  /// Resets the wrapped sampler while holding the lock.
  void reset(optional<size_t> new_size) override {
    std::lock_guard<std::mutex> guard(mutex_);
    sampler_.reset(new_size);
  }

  /// Retrieves the next batch request from the wrapped sampler under the lock.
  optional<BatchRequestType> next(size_t batch_size) override {
    std::lock_guard<std::mutex> guard(mutex_);
    return sampler_.next(batch_size);
  }

  /// Serializes the wrapped sampler's state under the lock.
  void save(serialize::OutputArchive& archive) const override {
    std::lock_guard<std::mutex> guard(mutex_);
    sampler_.save(archive);
  }

  /// Restores the wrapped sampler's state under the lock.
  void load(serialize::InputArchive& archive) override {
    std::lock_guard<std::mutex> guard(mutex_);
    sampler_.load(archive);
  }

 private:
  // `mutable` so that the const member function `save()` can still lock it.
  mutable std::mutex mutex_;
  OriginalSampler sampler_;
};
} // namespace samplers
} // namespace data
} // namespace torch