Update on "Extend SampleInput str representation with tensor data."
As in the title. The aim of this addition is to make it easier to debug certain CI failures that cannot be reproduced locally. For instance, currently we see messages like
```
Exception: Caused by sample input at index 0: SampleInput(input=Tensor[size=(20,), device="cuda:0", dtype=torch.float64], args=(), kwargs={}, broadcasts_input=False, name='')
```
which is not really useful without showing the actual sample data (the listed sample parameters can often be determined by other means). With the data included, the sample values can be related to the `index` part of error messages such as:
```
Mismatched elements: 2 / 20 (10.0%)
Greatest absolute difference: nan at index (10,) (up to 1e-05 allowed)
Greatest relative difference: nan at index (10,) (up to 1e-07 allowed)
```
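
For illustration, here is a minimal sketch of how such a data-bearing summary could be rendered; the helper name `format_tensor`, the flattening, and the 20-element cut-off are assumptions made for this sketch, not the actual implementation in the PR:
```
import torch

def format_tensor(t: torch.Tensor, max_elems: int = 20) -> str:
    # Hypothetical helper: render a tensor like the summaries above,
    # i.e. size/device/dtype plus the first few flattened values.
    data = t.flatten()[:max_elems].tolist()
    return f'Tensor[size={tuple(t.shape)}, device="{t.device}", dtype={t.dtype}, data={data}]'

# With the data visible, the value behind a mismatch at, say, index (10,)
# can be read off directly from the printed sample input.
print(format_tensor(torch.arange(25, dtype=torch.float64).reshape(5, 5)))
```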

As an example of the usefulness of this PR, consider the following failure message:
```
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 ('RERUN', {'yellow': True}) [1.5510s] [ 70%]
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 ('RERUN', {'yellow': True}) [0.0473s] [ 70%]
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 FAILED [0.0493s] [ 70%]

==================================== RERUNS ====================================
__ TestInductorOpInfoCPU.test_comprehensive_polygamma_polygamma_n_0_cpu_int32 __
Traceback (most recent call last):
<snip>
AssertionError: Tensor-likes are not close!

Mismatched elements: 9 / 25 (36.0%)
Greatest absolute difference: inf at index (0, 0) (up to 1e-05 allowed), inf vs 20177651499008.0
Greatest relative difference: inf at index (0, 0) (up to 1.3e-06 allowed)

The above exception was the direct cause of the following exception:

<snip>
Exception: Caused by sample input at index 0: SampleInput(input=Tensor[size=(5, 5), device="cpu", dtype=torch.int32, data=[-8, 6, 9, 0, 0, 5, 5, 7, 6, 5, 1, -5, 2, -1, 8, -4, 0, -6, 3, -5]], args=(1), kwargs={}, broadcasts_input=False, name='')
```
from which we learn that the `torch.polygamma` result is actually correct, because `polygamma(0, -8) -> inf`, while the reference value used (20177651499008.0) is wrong (see #106692 for more details).
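
The pole behind that `inf` can be checked directly; a small sketch, assuming a recent PyTorch build and reusing the first few sample values shown above:
```
import torch

# polygamma(0, x) is the digamma function, which has poles at non-positive
# integers; -8 is such a pole, so a divergent value there is expected.
x = torch.tensor([-8.0, 6.0, 9.0], dtype=torch.float64)
print(torch.polygamma(0, x))
# The first element prints as inf, matching the "Greatest absolute difference:
# inf at index (0, 0)" report above, so it is the reference value
# 20177651499008.0 that is off, not the PyTorch result.
```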





[ghstack-poisoned]
pearu committed Feb 10, 2024
2 parents fac43ea + e2800ca commit 4504eee
Showing 6 changed files with 132 additions and 115 deletions.
92 changes: 50 additions & 42 deletions c10/cuda/CUDACachingAllocator.cpp
@@ -202,7 +202,7 @@ struct BlockPool {
struct ExpandableSegment;

struct Block {
int device; // gpu
c10::DeviceIndex device; // gpu
cudaStream_t stream; // allocation stream
stream_set stream_uses; // streams on which the block was used
size_t size; // block size in bytes
@@ -229,7 +229,7 @@ struct Block {
ExpandableSegment* expandable_segment_{nullptr};

Block(
int device,
c10::DeviceIndex device,
cudaStream_t stream,
size_t size,
BlockPool* pool,
@@ -244,7 +244,7 @@ struct Block {
gc_count_base(0) {}

// constructor for search key
Block(int device, cudaStream_t stream, size_t size)
Block(c10::DeviceIndex device, cudaStream_t stream, size_t size)
: device(device),
stream(stream),
stream_uses(),
@@ -383,10 +383,10 @@ Instead these mapping have to be done manually. The allocator now has an

struct ExpandableSegment {
ExpandableSegment(
int device,
c10::DeviceIndex device,
cudaStream_t stream,
size_t size,
std::vector<int> peers)
std::vector<c10::DeviceIndex> peers)
: device_(device),
stream_(stream),
max_handles_(0),
@@ -473,7 +473,7 @@ struct ExpandableSegment {
return max_handles_ * segment_size_;
}

void addPeer(int device) {
void addPeer(c10::DeviceIndex device) {
peers_.push_back(device);
forEachAllocatedRange(
[&](size_t begin, size_t end) { setAccess(device, begin, end); });
@@ -487,7 +487,7 @@ struct ExpandableSegment {
}

private:
void setAccess(int device, size_t begin, size_t end) {
void setAccess(c10::DeviceIndex device, size_t begin, size_t end) {
CUmemAccessDesc desc;
desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
desc.location.id = device;
@@ -545,23 +545,23 @@ struct ExpandableSegment {
return SegmentRange(
ptr() + segment_size_ * begin, segment_size_ * (end - begin));
}
int device_;
c10::DeviceIndex device_;
cudaStream_t stream_;
CUdeviceptr ptr_{};
size_t max_handles_;
size_t segment_size_;
std::vector<c10::optional<CUmemGenericAllocationHandle>> handles_;
// devices on which this memory should be mapped in addition
// to the device where the physical memory lives (device_).
std::vector<int> peers_;
std::vector<c10::DeviceIndex> peers_;
};
#else
struct ExpandableSegment {
ExpandableSegment(
int device,
c10::DeviceIndex device,
cudaStream_t stream,
size_t size,
const std::vector<int>& peers) {
const std::vector<c10::DeviceIndex>& peers) {
TORCH_INTERNAL_ASSERT(false, "expandable segment not supported");
}
SegmentRange map(SegmentRange range) {
@@ -576,15 +576,15 @@ struct ExpandableSegment {
size_t size() const {
return 0;
}
void addPeer(int device) {}
void addPeer(c10::DeviceIndex device) {}
};
#endif

// BlockState, BlockPoolState, and PrivatePoolState contain the information
// needed to reconstruct a private pool to a previous state. See note
// [Checkpointing PrivatePoolState]
struct BlockState {
int device = 0;
c10::DeviceIndex device = 0;
cudaStream_t stream = nullptr;
stream_set stream_uses = {};
size_t size = 0;
@@ -638,7 +638,7 @@ static bool BlockComparatorAddress(const Block* a, const Block* b) {

struct AllocParams {
AllocParams(
int device,
c10::DeviceIndex device,
size_t size,
cudaStream_t stream,
BlockPool* pool,
@@ -650,7 +650,7 @@ struct AllocParams {
block(nullptr),
err(cudaSuccess) {}

int device() const {
c10::DeviceIndex device() const {
return search_key.device;
}
cudaStream_t stream() const {
@@ -680,7 +680,7 @@ class EventPool {
// TODO: Explicit device count
EventPool() : pools_(at::cuda::device_count()) {}

Event get(int device) {
Event get(c10::DeviceIndex device) {
TORCH_INTERNAL_ASSERT(0 <= device);
TORCH_INTERNAL_ASSERT(device < static_cast<int>(pools_.size()));
auto& pool = pools_[device];
@@ -804,7 +804,7 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
} // anonymous namespace
} // namespace Native

static std::string reportProcessMemoryInfo(int device) {
static std::string reportProcessMemoryInfo(c10::DeviceIndex device) {
#ifdef PYTORCH_C10_DRIVER_API_SUPPORTED
void* nvml_handle = DriverAPI::get_nvml_handle();
if (!nvml_handle) {
@@ -902,7 +902,7 @@ class DeviceCachingAllocator {

// all live expandable segments
std::vector<ExpandableSegment*> expandable_segments_;
std::vector<int> devices_with_peer_access_;
std::vector<c10::DeviceIndex> devices_with_peer_access_;

bool set_fraction = false;

@@ -1007,7 +1007,10 @@ class DeviceCachingAllocator {
// All public methods (except the above) acquire the allocator mutex.
// Thus, do not call a public method from another public method.

Block* malloc(int device, size_t orig_size, cudaStream_t stream) {
Block* malloc(
c10::DeviceIndex device,
size_t orig_size,
cudaStream_t stream) {
// done outside the lock because we don't know what locks the recorder needs
// to have...
auto context = maybeGatherContext(RecordContext::STATE);
@@ -1095,7 +1098,7 @@ class DeviceCachingAllocator {
.current,
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current,
c10::Device(c10::DeviceType::CUDA, static_cast<DeviceIndex>(device)));
c10::Device(c10::DeviceType::CUDA, device));

auto allocated_bytes =
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
@@ -1860,7 +1863,7 @@ class DeviceCachingAllocator {
}
}

void addPeerAccess(int dev_to_access) {
void addPeerAccess(c10::DeviceIndex dev_to_access) {
if (std::find(
devices_with_peer_access_.begin(),
devices_with_peer_access_.end(),
@@ -1927,7 +1930,7 @@ class DeviceCachingAllocator {
// where there is enough free address space to fit size
// may be composed of free and unmapped segments
Block* find_expandable_block(
int device,
c10::DeviceIndex device,
cudaStream_t stream,
BlockPool* pool,
size_t size) {
@@ -2034,7 +2037,7 @@ class DeviceCachingAllocator {
}

Block* try_allocate_expandable_block(
int device,
c10::DeviceIndex device,
cudaStream_t stream,
BlockPool* pool,
size_t size,
@@ -2660,7 +2663,7 @@ class DeviceCachingAllocator {
}
}

EventPool::Event create_event_internal(int idx) {
EventPool::Event create_event_internal(c10::DeviceIndex idx) {
// Leak the event pool to avoid shutdown issues.
static auto* event_pool = new EventPool();
return event_pool->get(idx);
@@ -2701,8 +2704,7 @@ class DeviceCachingAllocator {
for (auto& stream : streams) {
C10_CUDA_CHECK(c10::cuda::SetDevice(stream.device_index()));

EventPool::Event event =
create_event_internal(static_cast<int>(stream.device_index()));
EventPool::Event event = create_event_internal(stream.device_index());
C10_CUDA_CHECK(cudaEventRecord(*event, stream.stream()));

block->event_count++;
@@ -2780,7 +2782,7 @@ class DeviceCachingAllocator {
int64_t addr,
size_t size,
cudaStream_t stream,
int device,
c10::DeviceIndex device,
std::shared_ptr<GatheredContext> context) {
if (!record_history && !trace_trackers_.size())
return;
@@ -2895,7 +2897,11 @@ class NativeCachingAllocator : public CUDAAllocator {
}

/** allocates a block which is safe to use from the provided stream */
void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) {
void malloc(
void** devPtr,
c10::DeviceIndex device,
size_t size,
cudaStream_t stream) {
TORCH_INTERNAL_ASSERT(
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
"Allocator not initialized for device ",
@@ -2927,7 +2933,7 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[block->device]->free(block);
}

void setMemoryFraction(double fraction, int device) override {
void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
TORCH_INTERNAL_ASSERT(
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
"Allocator not initialized for device ",
@@ -2960,7 +2966,7 @@ class NativeCachingAllocator : public CUDAAllocator {
}

bool checkPoolLiveAllocations(
int device,
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) override {
return device_allocator[device]->checkPoolLiveAllocations(
@@ -3029,8 +3035,9 @@ class NativeCachingAllocator : public CUDAAllocator {
return result;
}

std::shared_ptr<AllocatorState> getCheckpointState(int device, MempoolId_t id)
override {
std::shared_ptr<AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) override {
return device_allocator[device]->getCheckpointState(id);
}

@@ -3047,7 +3054,7 @@
* functions for all allocated blocks in the new checkpoint state.
*/
CheckpointDelta setCheckpointPoolState(
int device,
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> as) override {
std::shared_ptr<PrivatePoolState> pps =
std::dynamic_pointer_cast<PrivatePoolState>(as);
@@ -3117,10 +3124,10 @@ class NativeCachingAllocator : public CUDAAllocator {
return &local_raw_delete;
}
}
void cacheInfo(int dev_id, size_t* largestBlock) override {
device_allocator[dev_id]->cacheInfo(largestBlock);
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
device_allocator[device]->cacheInfo(largestBlock);
}
void assertValidDevice(int device) {
void assertValidDevice(c10::DeviceIndex device) {
const auto device_num = device_allocator.size();
TORCH_CHECK(
0 <= device && device < static_cast<int64_t>(device_num),
@@ -3129,36 +3136,37 @@
": did you call init?");
}

DeviceStats getDeviceStats(int device) override {
DeviceStats getDeviceStats(c10::DeviceIndex device) override {
assertValidDevice(device);
return device_allocator[device]->getStats();
}

void resetAccumulatedStats(int device) override {
void resetAccumulatedStats(c10::DeviceIndex device) override {
assertValidDevice(device);
device_allocator[device]->resetAccumulatedStats();
}

void resetPeakStats(int device) override {
void resetPeakStats(c10::DeviceIndex device) override {
assertValidDevice(device);
device_allocator[device]->resetPeakStats();
}
// CUDAGraph interactions
void beginAllocateToPool(
int device,
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) override {
assertValidDevice(device);
device_allocator[device]->beginAllocateToPool(
std::move(mempool_id), std::move(filter));
}

void endAllocateToPool(int device, MempoolId_t mempool_id) override {
void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id)
override {
assertValidDevice(device);
device_allocator[device]->endAllocateToPool(mempool_id);
}

void releasePool(int device, MempoolId_t mempool_id) override {
void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
assertValidDevice(device);
device_allocator[device]->releasePool(std::move(mempool_id));
}