Skip to content

Commit

Permalink
[MPS] Introduce torch.mps.Event() APIs
Browse files Browse the repository at this point in the history
- Implement MPSEventPool to recycle events and hook it into MPSAllocator
- Implement python bindings with torch.mps.Event class using the MPSEventPool backend.
The current member functions of the Event class are, record(), wait(), synchronize() and query()
- Add API to measure elapsed time between two event recordings
- Added test case to test_mps.py
- Replace PyLong_From with THPUtils_packUInt() to avoid overflow errors
  • Loading branch information
razarmehr committed May 24, 2023
1 parent 76af221 commit a870c7e
Show file tree
Hide file tree
Showing 15 changed files with 566 additions and 115 deletions.
27 changes: 23 additions & 4 deletions aten/src/ATen/detail/MPSHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct TORCH_API MPSHooksInterface {
// this fails the implementation if MPSHooks functions are called, but
// MPS backend is not present.
#define FAIL_MPSHOOKS_FUNC(func) \
TORCH_CHECK(false, "Cannot execute ", func ,"() without MPS backend.");
TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");

virtual ~MPSHooksInterface() = default;

Expand Down Expand Up @@ -64,16 +64,35 @@ struct TORCH_API MPSHooksInterface {
virtual void setMemoryFraction(double /*ratio*/) const {
FAIL_MPSHOOKS_FUNC(__func__);
}

virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
FAIL_MPSHOOKS_FUNC(__func__);
}

virtual void profilerStopTrace() const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual uint32_t acquireEvent(bool enable_timing) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual void releaseEvent(uint32_t event_id) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual void recordEvent(uint32_t event_id) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual void waitForEvent(uint32_t event_id) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual void synchronizeEvent(uint32_t event_id) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual bool queryEvent(uint32_t event_id) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
FAIL_MPSHOOKS_FUNC(__func__);
}

#undef FAIL_MPSHOOKS_FUNC
#undef FAIL_MPSHOOKS_FUNC
};

struct TORCH_API MPSHooksArgs {};
Expand Down
97 changes: 97 additions & 0 deletions aten/src/ATen/mps/MPSEvent.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <ATen/mps/MPSStream.h>
#include <ctime>
#include <stack>

namespace at::mps {

// NOTE: don't create instances of this class directly.
// Use MPSEventPool to acquire instances of MPSEvent.
class MPSEvent {
public:
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
~MPSEvent();

// records an event on the stream
void record(bool needsLock, bool syncEvent = false);
// makes all future work submitted to the stream wait for this event.
bool wait(bool needsLock, bool syncEvent = false);
// schedules a notifyListener callback for the event.
bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
// checks if events are already signaled.
bool query() const;
// blocks the CPU thread until all the GPU work that were scheduled
// prior to recording this event are completed.
bool synchronize();
// resets this event with new parameters in case it gets reused from the event pool
void reset(MPSStream* stream, bool enable_timing);
// returns the unique ID of the event instance
id_t getID() const { return m_id; }
// returns the completion timestamp of the event
uint64_t getCompletionTime() const { return m_completion_time; }

private:
id_t m_id;
// enables measuring the completion time of the notifyListener of this event
bool m_enable_timing;
uint64_t m_signalCounter = 0;
MPSStream* m_stream = nullptr;
MTLSharedEvent_t m_event = nullptr;
MTLSharedEventListener* m_listener = nullptr;
// used to sync the events created on this Stream with CPU
std::mutex m_cpu_sync_mutex{};
std::condition_variable m_cpu_sync_cv{};
// CondVar predicate to sync the events created on this Stream with CPU
bool m_cpu_sync_completed = false;
// used to compute elapsed time
uint64_t m_completion_time = 0;

void recordLocked(bool syncEvent);
bool waitLocked(bool syncEvent);
bool notifyLocked(MTLSharedEventNotificationBlock block);
static uint64_t getTime() {
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
}
};

typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;

class MPSEventPool {
public:
explicit MPSEventPool(MPSStream* default_stream);
~MPSEventPool();

MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
void emptyCache();

// these are mainly used for MPSHooks and torch.mps.Event() bindings
id_t acquireEvent(bool enable_timing);
void releaseEvent(id_t event_id);
void recordEvent(id_t event_id, bool syncEvent);
void waitForEvent(id_t event_id, bool syncEvent);
void synchronizeEvent(id_t event_id);
bool queryEvent(id_t event_id);
// returns elapsed time between two recorded events in milliseconds
double elapsedTime(id_t start_event_id, id_t end_event_id);

private:
MPSStream* m_default_stream = nullptr;
std::recursive_mutex m_mutex;
std::stack<std::unique_ptr<MPSEvent>> m_pool{};
// dictionary to associate event IDs with event objects
// used to retain in-use events out of the pool
// for torch.mps.Event() bindings.
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
uint64_t m_event_counter = 0;
std::function<void(MPSEvent*)> m_default_deleter;

MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
};

// shared_ptr is used to get MPSEventPool destroyed after dependent instances
std::shared_ptr<MPSEventPool> getMPSEventPool();

} // namespace at::mps
245 changes: 245 additions & 0 deletions aten/src/ATen/mps/MPSEvent.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
// Copyright © 2023 Apple Inc.

#include <ATen/mps/MPSEvent.h>

namespace at::mps {

MPSEvent::MPSEvent(id_t ID, MPSStream* stream, bool enable_timing)
: m_id(ID), m_enable_timing(enable_timing), m_stream(stream), m_event([stream->device() newSharedEvent]) {}

MPSEvent::~MPSEvent() {
if (m_event) {
[m_event release];
m_event = nil;
}
if (m_listener) {
[m_listener release];
m_listener = nil;
}
}

void MPSEvent::recordLocked(bool syncEvent) {
// active encoders must end before encoding or waiting
m_stream->endKernelCoalescing();
++m_signalCounter;
if (m_enable_timing) {
notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
m_completion_time = getTime();
});
}
id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
[commandBuffer encodeSignalEvent:m_event value:m_signalCounter];
if (syncEvent) {
m_stream->synchronize(SyncType::COMMIT);
}
}

bool MPSEvent::waitLocked(bool syncEvent) {
// check if event is not recorded yet
if (m_event.signaledValue >= m_signalCounter) {
return false;
}
// active encoders must end before encoding or waiting
m_stream->endKernelCoalescing();
id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
[commandBuffer encodeWaitForEvent:m_event value:m_signalCounter];
if (syncEvent) {
m_stream->synchronize(SyncType::COMMIT);
}
return true;
}

bool MPSEvent::notifyLocked(MTLSharedEventNotificationBlock block) {
// check if event is not recorded yet
if (m_event.signaledValue >= m_signalCounter) {
return false;
}
if (!m_listener) {
m_listener = [[MTLSharedEventListener alloc] init];
}
[m_event notifyListener:m_listener atValue:m_signalCounter block:block];
return true;
}

void MPSEvent::record(bool needsLock, bool syncEvent) {
if (!needsLock) {
recordLocked(syncEvent);
return;
}
dispatch_sync(m_stream->queue(), ^() {
@autoreleasepool {
recordLocked(syncEvent);
}
});
}

bool MPSEvent::wait(bool needsLock, bool syncEvent) {
__block bool waited = false;
if (!needsLock) {
return waitLocked(syncEvent);
}
dispatch_sync(m_stream->queue(), ^() {
@autoreleasepool {
waited = waitLocked(syncEvent);
}
});
return waited;
}

bool MPSEvent::notify(bool needsLock, MTLSharedEventNotificationBlock block) {
if (!needsLock) {
return notifyLocked(block);
}
__block bool scheduledNotify = false;
dispatch_sync(m_stream->queue(), ^() {
@autoreleasepool {
scheduledNotify = notifyLocked(block);
}
});
return scheduledNotify;
}

bool MPSEvent::synchronize() {
bool scheduledNotify = notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
m_completion_time = getTime();
std::lock_guard<std::mutex> lock(m_cpu_sync_mutex);
m_cpu_sync_completed = true;
m_cpu_sync_cv.notify_one();
});

if (scheduledNotify) {
std::unique_lock<std::mutex> lock(m_cpu_sync_mutex);
m_cpu_sync_cv.wait(lock, [&] { return m_cpu_sync_completed; });
m_cpu_sync_completed = false;
return true;
}
return false;
}

bool MPSEvent::query() const {
// return false if not recorded or signaled yet
return m_signalCounter && (m_event.signaledValue >= m_signalCounter);
}

void MPSEvent::reset(MPSStream* stream, bool enable_timing) {
if (stream != m_stream) {
m_signalCounter = 0;
m_event.signaledValue = 0;
m_stream = stream;
}
// reset record time
m_completion_time = 0;
m_enable_timing = enable_timing;
};

//-----------------------------------------------------------------
// MPSEventPool
//-----------------------------------------------------------------

MPSEventPool::MPSEventPool(MPSStream* default_stream) : m_default_stream(default_stream) {
// default deleter to return the event back to pool after it gets released
m_default_deleter = [&](MPSEvent* event) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
m_pool.push(std::unique_ptr<MPSEvent>(event));
};
}

MPSEventPool::~MPSEventPool() {
emptyCache();
}

MPSEventPtr MPSEventPool::acquireEvent(bool enable_timing, MPSStream* stream) {
if (!stream) {
stream = m_default_stream;
}
{
std::lock_guard<std::recursive_mutex> lock(m_mutex);
if (!m_pool.empty()) {
auto event = m_pool.top().release();
m_pool.pop();
event->reset(stream, enable_timing);
return MPSEventPtr(event, m_default_deleter);
}
}
auto new_event = std::make_unique<MPSEvent>(++m_event_counter, stream, enable_timing);
return MPSEventPtr(new_event.release(), m_default_deleter);
}

void MPSEventPool::emptyCache() {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
while (!m_pool.empty()) {
m_pool.pop();
}
}

id_t MPSEventPool::acquireEvent(bool enable_timing) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
MPSEventPtr event = acquireEvent(enable_timing, nullptr);
TORCH_INTERNAL_ASSERT(event);
id_t event_id = event->getID();
m_in_use_events.emplace(event_id, std::move(event));
return event_id;
}

void MPSEventPool::releaseEvent(id_t event_id) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
TORCH_CHECK(m_in_use_events.count(event_id) > 0, "Invalid Event ID: ", event_id);
// returns the event back to the MPSEventPool
m_in_use_events.erase(event_id);
}

void MPSEventPool::recordEvent(id_t event_id, bool syncEvent) {
MPSEvent* event = getInUseEvent(event_id);
event->record(/*needsLock*/ true, syncEvent);
}

void MPSEventPool::waitForEvent(id_t event_id, bool syncEvent) {
MPSEvent* event = getInUseEvent(event_id);
event->wait(/*needsLock*/ true, syncEvent);
}

void MPSEventPool::synchronizeEvent(id_t event_id) {
MPSEvent* event = getInUseEvent(event_id);
event->synchronize();
}

bool MPSEventPool::queryEvent(id_t event_id) {
MPSEvent* event = getInUseEvent(event_id);
return event->query();
}

double MPSEventPool::elapsedTime(id_t start_event_id, id_t end_event_id) {
// first make sure notifyListeners are called to capture events' completion times
dispatch_sync(m_default_stream->queue(), ^() {
m_default_stream->synchronize(SyncType::COMMIT_AND_WAIT);
});
std::lock_guard<std::recursive_mutex> lock(m_mutex);
MPSEvent* start_event = getInUseEvent(start_event_id, false);
MPSEvent* end_event = getInUseEvent(end_event_id, false);

const uint64_t start_time = start_event->getCompletionTime();
const uint64_t end_time = end_event->getCompletionTime();
TORCH_CHECK(start_time > 0 && end_time > 0, "Events were not created with argument 'enable_timing=True'");
TORCH_CHECK(
end_time > start_time, "End event ", end_event_id, " was not recorded after start event ", start_event_id);
return double(end_time - start_time) * 1e-6;
}

MPSEvent* MPSEventPool::getInUseEvent(id_t event_id, bool locked) {
if (locked) {
m_mutex.lock();
}
TORCH_CHECK(m_in_use_events.count(event_id) > 0, "Invalid Event ID: ", event_id);
MPSEvent* event = m_in_use_events[event_id].get();
if (locked) {
m_mutex.unlock();
}
return event;
}

std::shared_ptr<MPSEventPool> getMPSEventPool() {
static std::shared_ptr<MPSEventPool> event_pool = std::make_shared<MPSEventPool>(getDefaultMPSStream());
return event_pool;
}

} // namespace at::mps

0 comments on commit a870c7e

Please sign in to comment.