-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MPS] Introduce torch.mps.Event() APIs
- Implement MPSEventPool to recycle events and hook it into MPSAllocator - Implement python bindings with torch.mps.Event class using the MPSEventPool backend. The current member functions of the Event class are, record(), wait(), synchronize() and query() - Add API to measure elapsed time between two event recordings - Added test case to test_mps.py - Replace PyLong_From with THPUtils_packUInt() to avoid overflow errors
- Loading branch information
Showing
15 changed files
with
566 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright © 2023 Apple Inc. | ||
|
||
#pragma once | ||
|
||
#include <ATen/mps/MPSStream.h> | ||
#include <ctime> | ||
#include <stack> | ||
|
||
namespace at::mps { | ||
|
||
// NOTE: don't create instances of this class directly. | ||
// Use MPSEventPool to acquire instances of MPSEvent. | ||
class MPSEvent { | ||
public: | ||
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing); | ||
~MPSEvent(); | ||
|
||
// records an event on the stream | ||
void record(bool needsLock, bool syncEvent = false); | ||
// makes all future work submitted to the stream wait for this event. | ||
bool wait(bool needsLock, bool syncEvent = false); | ||
// schedules a notifyListener callback for the event. | ||
bool notify(bool needsLock, MTLSharedEventNotificationBlock block); | ||
// checks if events are already signaled. | ||
bool query() const; | ||
// blocks the CPU thread until all the GPU work that were scheduled | ||
// prior to recording this event are completed. | ||
bool synchronize(); | ||
// resets this event with new parameters in case it gets reused from the event pool | ||
void reset(MPSStream* stream, bool enable_timing); | ||
// returns the unique ID of the event instance | ||
id_t getID() const { return m_id; } | ||
// returns the completion timestamp of the event | ||
uint64_t getCompletionTime() const { return m_completion_time; } | ||
|
||
private: | ||
id_t m_id; | ||
// enables measuring the completion time of the notifyListener of this event | ||
bool m_enable_timing; | ||
uint64_t m_signalCounter = 0; | ||
MPSStream* m_stream = nullptr; | ||
MTLSharedEvent_t m_event = nullptr; | ||
MTLSharedEventListener* m_listener = nullptr; | ||
// used to sync the events created on this Stream with CPU | ||
std::mutex m_cpu_sync_mutex{}; | ||
std::condition_variable m_cpu_sync_cv{}; | ||
// CondVar predicate to sync the events created on this Stream with CPU | ||
bool m_cpu_sync_completed = false; | ||
// used to compute elapsed time | ||
uint64_t m_completion_time = 0; | ||
|
||
void recordLocked(bool syncEvent); | ||
bool waitLocked(bool syncEvent); | ||
bool notifyLocked(MTLSharedEventNotificationBlock block); | ||
static uint64_t getTime() { | ||
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW); | ||
} | ||
}; | ||
|
||
typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr; | ||
|
||
class MPSEventPool { | ||
public: | ||
explicit MPSEventPool(MPSStream* default_stream); | ||
~MPSEventPool(); | ||
|
||
MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream); | ||
void emptyCache(); | ||
|
||
// these are mainly used for MPSHooks and torch.mps.Event() bindings | ||
id_t acquireEvent(bool enable_timing); | ||
void releaseEvent(id_t event_id); | ||
void recordEvent(id_t event_id, bool syncEvent); | ||
void waitForEvent(id_t event_id, bool syncEvent); | ||
void synchronizeEvent(id_t event_id); | ||
bool queryEvent(id_t event_id); | ||
// returns elapsed time between two recorded events in milliseconds | ||
double elapsedTime(id_t start_event_id, id_t end_event_id); | ||
|
||
private: | ||
MPSStream* m_default_stream = nullptr; | ||
std::recursive_mutex m_mutex; | ||
std::stack<std::unique_ptr<MPSEvent>> m_pool{}; | ||
// dictionary to associate event IDs with event objects | ||
// used to retain in-use events out of the pool | ||
// for torch.mps.Event() bindings. | ||
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{}; | ||
uint64_t m_event_counter = 0; | ||
std::function<void(MPSEvent*)> m_default_deleter; | ||
|
||
MPSEvent* getInUseEvent(id_t event_id, bool locked = true); | ||
}; | ||
|
||
// shared_ptr is used to get MPSEventPool destroyed after dependent instances | ||
std::shared_ptr<MPSEventPool> getMPSEventPool(); | ||
|
||
} // namespace at::mps |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
// Copyright © 2023 Apple Inc. | ||
|
||
#include <ATen/mps/MPSEvent.h> | ||
|
||
namespace at::mps { | ||
|
||
MPSEvent::MPSEvent(id_t ID, MPSStream* stream, bool enable_timing) | ||
: m_id(ID), m_enable_timing(enable_timing), m_stream(stream), m_event([stream->device() newSharedEvent]) {} | ||
|
||
MPSEvent::~MPSEvent() { | ||
if (m_event) { | ||
[m_event release]; | ||
m_event = nil; | ||
} | ||
if (m_listener) { | ||
[m_listener release]; | ||
m_listener = nil; | ||
} | ||
} | ||
|
||
void MPSEvent::recordLocked(bool syncEvent) { | ||
// active encoders must end before encoding or waiting | ||
m_stream->endKernelCoalescing(); | ||
++m_signalCounter; | ||
if (m_enable_timing) { | ||
notifyLocked(^(id<MTLSharedEvent>, uint64_t) { | ||
m_completion_time = getTime(); | ||
}); | ||
} | ||
id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer(); | ||
[commandBuffer encodeSignalEvent:m_event value:m_signalCounter]; | ||
if (syncEvent) { | ||
m_stream->synchronize(SyncType::COMMIT); | ||
} | ||
} | ||
|
||
bool MPSEvent::waitLocked(bool syncEvent) { | ||
// check if event is not recorded yet | ||
if (m_event.signaledValue >= m_signalCounter) { | ||
return false; | ||
} | ||
// active encoders must end before encoding or waiting | ||
m_stream->endKernelCoalescing(); | ||
id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer(); | ||
[commandBuffer encodeWaitForEvent:m_event value:m_signalCounter]; | ||
if (syncEvent) { | ||
m_stream->synchronize(SyncType::COMMIT); | ||
} | ||
return true; | ||
} | ||
|
||
bool MPSEvent::notifyLocked(MTLSharedEventNotificationBlock block) { | ||
// check if event is not recorded yet | ||
if (m_event.signaledValue >= m_signalCounter) { | ||
return false; | ||
} | ||
if (!m_listener) { | ||
m_listener = [[MTLSharedEventListener alloc] init]; | ||
} | ||
[m_event notifyListener:m_listener atValue:m_signalCounter block:block]; | ||
return true; | ||
} | ||
|
||
void MPSEvent::record(bool needsLock, bool syncEvent) { | ||
if (!needsLock) { | ||
recordLocked(syncEvent); | ||
return; | ||
} | ||
dispatch_sync(m_stream->queue(), ^() { | ||
@autoreleasepool { | ||
recordLocked(syncEvent); | ||
} | ||
}); | ||
} | ||
|
||
bool MPSEvent::wait(bool needsLock, bool syncEvent) { | ||
__block bool waited = false; | ||
if (!needsLock) { | ||
return waitLocked(syncEvent); | ||
} | ||
dispatch_sync(m_stream->queue(), ^() { | ||
@autoreleasepool { | ||
waited = waitLocked(syncEvent); | ||
} | ||
}); | ||
return waited; | ||
} | ||
|
||
bool MPSEvent::notify(bool needsLock, MTLSharedEventNotificationBlock block) { | ||
if (!needsLock) { | ||
return notifyLocked(block); | ||
} | ||
__block bool scheduledNotify = false; | ||
dispatch_sync(m_stream->queue(), ^() { | ||
@autoreleasepool { | ||
scheduledNotify = notifyLocked(block); | ||
} | ||
}); | ||
return scheduledNotify; | ||
} | ||
|
||
bool MPSEvent::synchronize() { | ||
bool scheduledNotify = notifyLocked(^(id<MTLSharedEvent>, uint64_t) { | ||
m_completion_time = getTime(); | ||
std::lock_guard<std::mutex> lock(m_cpu_sync_mutex); | ||
m_cpu_sync_completed = true; | ||
m_cpu_sync_cv.notify_one(); | ||
}); | ||
|
||
if (scheduledNotify) { | ||
std::unique_lock<std::mutex> lock(m_cpu_sync_mutex); | ||
m_cpu_sync_cv.wait(lock, [&] { return m_cpu_sync_completed; }); | ||
m_cpu_sync_completed = false; | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
bool MPSEvent::query() const { | ||
// return false if not recorded or signaled yet | ||
return m_signalCounter && (m_event.signaledValue >= m_signalCounter); | ||
} | ||
|
||
void MPSEvent::reset(MPSStream* stream, bool enable_timing) { | ||
if (stream != m_stream) { | ||
m_signalCounter = 0; | ||
m_event.signaledValue = 0; | ||
m_stream = stream; | ||
} | ||
// reset record time | ||
m_completion_time = 0; | ||
m_enable_timing = enable_timing; | ||
}; | ||
|
||
//----------------------------------------------------------------- | ||
// MPSEventPool | ||
//----------------------------------------------------------------- | ||
|
||
MPSEventPool::MPSEventPool(MPSStream* default_stream) : m_default_stream(default_stream) { | ||
// default deleter to return the event back to pool after it gets released | ||
m_default_deleter = [&](MPSEvent* event) { | ||
std::lock_guard<std::recursive_mutex> lock(m_mutex); | ||
m_pool.push(std::unique_ptr<MPSEvent>(event)); | ||
}; | ||
} | ||
|
||
MPSEventPool::~MPSEventPool() { | ||
emptyCache(); | ||
} | ||
|
||
MPSEventPtr MPSEventPool::acquireEvent(bool enable_timing, MPSStream* stream) { | ||
if (!stream) { | ||
stream = m_default_stream; | ||
} | ||
{ | ||
std::lock_guard<std::recursive_mutex> lock(m_mutex); | ||
if (!m_pool.empty()) { | ||
auto event = m_pool.top().release(); | ||
m_pool.pop(); | ||
event->reset(stream, enable_timing); | ||
return MPSEventPtr(event, m_default_deleter); | ||
} | ||
} | ||
auto new_event = std::make_unique<MPSEvent>(++m_event_counter, stream, enable_timing); | ||
return MPSEventPtr(new_event.release(), m_default_deleter); | ||
} | ||
|
||
void MPSEventPool::emptyCache() { | ||
std::lock_guard<std::recursive_mutex> lock(m_mutex); | ||
while (!m_pool.empty()) { | ||
m_pool.pop(); | ||
} | ||
} | ||
|
||
id_t MPSEventPool::acquireEvent(bool enable_timing) { | ||
std::lock_guard<std::recursive_mutex> lock(m_mutex); | ||
MPSEventPtr event = acquireEvent(enable_timing, nullptr); | ||
TORCH_INTERNAL_ASSERT(event); | ||
id_t event_id = event->getID(); | ||
m_in_use_events.emplace(event_id, std::move(event)); | ||
return event_id; | ||
} | ||
|
||
void MPSEventPool::releaseEvent(id_t event_id) { | ||
std::lock_guard<std::recursive_mutex> lock(m_mutex); | ||
TORCH_CHECK(m_in_use_events.count(event_id) > 0, "Invalid Event ID: ", event_id); | ||
// returns the event back to the MPSEventPool | ||
m_in_use_events.erase(event_id); | ||
} | ||
|
||
void MPSEventPool::recordEvent(id_t event_id, bool syncEvent) { | ||
MPSEvent* event = getInUseEvent(event_id); | ||
event->record(/*needsLock*/ true, syncEvent); | ||
} | ||
|
||
void MPSEventPool::waitForEvent(id_t event_id, bool syncEvent) { | ||
MPSEvent* event = getInUseEvent(event_id); | ||
event->wait(/*needsLock*/ true, syncEvent); | ||
} | ||
|
||
void MPSEventPool::synchronizeEvent(id_t event_id) { | ||
MPSEvent* event = getInUseEvent(event_id); | ||
event->synchronize(); | ||
} | ||
|
||
bool MPSEventPool::queryEvent(id_t event_id) { | ||
MPSEvent* event = getInUseEvent(event_id); | ||
return event->query(); | ||
} | ||
|
||
double MPSEventPool::elapsedTime(id_t start_event_id, id_t end_event_id) { | ||
// first make sure notifyListeners are called to capture events' completion times | ||
dispatch_sync(m_default_stream->queue(), ^() { | ||
m_default_stream->synchronize(SyncType::COMMIT_AND_WAIT); | ||
}); | ||
std::lock_guard<std::recursive_mutex> lock(m_mutex); | ||
MPSEvent* start_event = getInUseEvent(start_event_id, false); | ||
MPSEvent* end_event = getInUseEvent(end_event_id, false); | ||
|
||
const uint64_t start_time = start_event->getCompletionTime(); | ||
const uint64_t end_time = end_event->getCompletionTime(); | ||
TORCH_CHECK(start_time > 0 && end_time > 0, "Events were not created with argument 'enable_timing=True'"); | ||
TORCH_CHECK( | ||
end_time > start_time, "End event ", end_event_id, " was not recorded after start event ", start_event_id); | ||
return double(end_time - start_time) * 1e-6; | ||
} | ||
|
||
MPSEvent* MPSEventPool::getInUseEvent(id_t event_id, bool locked) { | ||
if (locked) { | ||
m_mutex.lock(); | ||
} | ||
TORCH_CHECK(m_in_use_events.count(event_id) > 0, "Invalid Event ID: ", event_id); | ||
MPSEvent* event = m_in_use_events[event_id].get(); | ||
if (locked) { | ||
m_mutex.unlock(); | ||
} | ||
return event; | ||
} | ||
|
||
std::shared_ptr<MPSEventPool> getMPSEventPool() { | ||
static std::shared_ptr<MPSEventPool> event_pool = std::make_shared<MPSEventPool>(getDefaultMPSStream()); | ||
return event_pool; | ||
} | ||
|
||
} // namespace at::mps |
Oops, something went wrong.