Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MPS] Introduce torch.mps.Event() APIs #102121

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 23 additions & 4 deletions aten/src/ATen/detail/MPSHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct TORCH_API MPSHooksInterface {
// Helper used by the default hook implementations in this struct: it expands
// to a TORCH_CHECK(false, ...) whose message reads
// "Cannot execute <func>() without MPS backend.", so calling any MPSHooks
// function while the MPS backend is not present fails and names the function.
// Callers pass __func__ as `func`.
#define FAIL_MPSHOOKS_FUNC(func) \
TORCH_CHECK(false, "Cannot execute ", func ,"() without MPS backend.");
TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");

virtual ~MPSHooksInterface() = default;

Expand Down Expand Up @@ -64,16 +64,35 @@ struct TORCH_API MPSHooksInterface {
// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, reporting this function's name. The parameter name is
// commented out because the stub never reads it.
virtual void setMemoryFraction(double /*ratio*/) const {
FAIL_MPSHOOKS_FUNC(__func__);
}

// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, reporting this function's name.
// The parameter names are commented out (as in setMemoryFraction) because the
// stub never reads them; this keeps the names visible as documentation while
// avoiding unused-parameter warnings.
virtual void profilerStartTrace(const std::string& /*mode*/, bool /*waitUntilCompleted*/) const {
  FAIL_MPSHOOKS_FUNC(__func__);
}

// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, reporting this function's name.
virtual void profilerStopTrace() const {
FAIL_MPSHOOKS_FUNC(__func__);
}
// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, so no event ID is ever returned from this stub.
// The parameter name is commented out (as in setMemoryFraction) because the
// stub never reads it; this avoids unused-parameter warnings.
virtual uint32_t acquireEvent(bool /*enable_timing*/) const {
  FAIL_MPSHOOKS_FUNC(__func__);
}
// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, reporting this function's name.
// The parameter name is commented out (as in setMemoryFraction) because the
// stub never reads it; this avoids unused-parameter warnings.
virtual void releaseEvent(uint32_t /*event_id*/) const {
  FAIL_MPSHOOKS_FUNC(__func__);
}
// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, reporting this function's name.
// The parameter name is commented out (as in setMemoryFraction) because the
// stub never reads it; this avoids unused-parameter warnings.
virtual void recordEvent(uint32_t /*event_id*/) const {
  FAIL_MPSHOOKS_FUNC(__func__);
}
// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, reporting this function's name.
// The parameter name is commented out (as in setMemoryFraction) because the
// stub never reads it; this avoids unused-parameter warnings.
virtual void waitForEvent(uint32_t /*event_id*/) const {
  FAIL_MPSHOOKS_FUNC(__func__);
}
// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, reporting this function's name.
// The parameter name is commented out (as in setMemoryFraction) because the
// stub never reads it; this avoids unused-parameter warnings.
virtual void synchronizeEvent(uint32_t /*event_id*/) const {
  FAIL_MPSHOOKS_FUNC(__func__);
}
// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, so no query result is ever returned from this stub.
// The parameter name is commented out (as in setMemoryFraction) because the
// stub never reads it; this avoids unused-parameter warnings.
virtual bool queryEvent(uint32_t /*event_id*/) const {
  FAIL_MPSHOOKS_FUNC(__func__);
}
// Default implementation for builds without the MPS backend: always fails via
// FAIL_MPSHOOKS_FUNC, so no elapsed time is ever returned from this stub.
// The parameter names are commented out (as in setMemoryFraction) because the
// stub never reads them; this avoids unused-parameter warnings.
virtual double elapsedTimeOfEvents(uint32_t /*start_event_id*/, uint32_t /*end_event_id*/) const {
  FAIL_MPSHOOKS_FUNC(__func__);
}

#undef FAIL_MPSHOOKS_FUNC
#undef FAIL_MPSHOOKS_FUNC
};

// Empty struct with no members.
// NOTE(review): presumably the (empty) argument type used when MPS hooks are
// created through a registry -- confirm against the registry declaration.
struct TORCH_API MPSHooksArgs {};
Expand Down
100 changes: 100 additions & 0 deletions aten/src/ATen/mps/MPSEvent.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <ATen/mps/MPSStream.h>

#include <condition_variable>
#include <cstdint>
#include <ctime>
#include <functional>
#include <memory>
#include <mutex>
#include <stack>
#include <unordered_map>

namespace at::mps {

// An event associated with an MPSStream that can be recorded, waited on,
// queried, and synchronized with the CPU.
// NOTE: don't create instances of this class directly.
// Use MPSEventPool to acquire instances of MPSEvent.
class MPSEvent {
 public:
  explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
  ~MPSEvent();

  // The class declares a destructor, holds raw Metal handles, and owns a
  // mutex and condition variable, so it is already non-copyable. The copy
  // operations are deleted explicitly to make that contract visible.
  MPSEvent(const MPSEvent&) = delete;
  MPSEvent& operator=(const MPSEvent&) = delete;

  // records an event on the stream.
  // NOTE(review): needsLock presumably selects whether the call must take the
  // required lock itself before running recordLocked() -- confirm in the
  // implementation file.
  void record(bool needsLock, bool syncEvent = false);
  // makes all future work submitted to the stream wait for this event.
  bool wait(bool needsLock, bool syncEvent = false);
  // schedules a notifyListener callback for the event.
  bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
  // checks if events are already signaled.
  bool query() const;
  // blocks the CPU thread until all the GPU work that was scheduled
  // prior to recording this event is completed.
  bool synchronize();
  // resets this event with new parameters in case it gets reused from the event pool
  void reset(MPSStream* stream, bool enable_timing);
  // returns the unique ID of the event instance
  id_t getID() const { return m_id; }
  // returns the completion timestamp of the event
  uint64_t getCompletionTime() const { return m_completion_time; }
  // if already recorded, waits for m_cpu_sync_cv to be signaled
  void waitForCpuSync();

 private:
  // Default member initializers keep every field in a defined state even if a
  // constructor path does not set it; the constructor is expected to
  // overwrite m_id and m_enable_timing with its arguments.
  id_t m_id = 0;
  // enables measuring the completion time of the notifyListener of this event
  bool m_enable_timing = false;
  uint64_t m_signalCounter = 0;
  MPSStream* m_stream = nullptr;
  MTLSharedEvent_t m_event = nullptr;
  MTLSharedEventListener* m_listener = nullptr;
  // used to sync the events created on this Stream with CPU
  std::mutex m_cpu_sync_mutex{};
  std::condition_variable m_cpu_sync_cv{};
  // CondVar predicate to sync the events created on this Stream with CPU
  bool m_cpu_sync_completed = false;
  // used to compute elapsed time
  uint64_t m_completion_time = 0;

  // NOTE(review): the *Locked variants are presumably the bodies of
  // record()/wait()/notify() that run once the lock is held -- confirm.
  void recordLocked(bool syncEvent);
  bool waitLocked(bool syncEvent);
  bool notifyLocked(MTLSharedEventNotificationBlock block);
  void notifyCpuSync();
  // returns the current CLOCK_MONOTONIC_RAW reading from
  // clock_gettime_nsec_np().
  static uint64_t getTime() {
    return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
  }
};

// Owning handle to an MPSEvent whose deleter is a type-erased callable that
// receives the raw MPSEvent pointer.
using MPSEventPtr = std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>>;

// Pool that hands out MPSEvent instances.
// It offers two ways to obtain an event (see the declarations below): one
// returns an owning MPSEventPtr, the other returns an event ID that is later
// passed back to the ID-based methods (used for MPSHooks and the
// torch.mps.Event() bindings).
class MPSEventPool {
public:
explicit MPSEventPool(MPSStream* default_stream);
~MPSEventPool();

// acquires an event for `stream` and returns it as an MPSEventPtr
// (a unique_ptr with a std::function deleter).
MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
// NOTE(review): presumably releases the idle events held in m_pool -- confirm.
void emptyCache();

// these are mainly used for MPSHooks and torch.mps.Event() bindings
// acquireEvent() returns the ID of the acquired event; the remaining methods
// take such an ID to address the event.
id_t acquireEvent(bool enable_timing);
void releaseEvent(id_t event_id);
void recordEvent(id_t event_id, bool syncEvent);
void waitForEvent(id_t event_id, bool syncEvent);
void synchronizeEvent(id_t event_id);
bool queryEvent(id_t event_id);
// returns elapsed time between two recorded events in milliseconds
double elapsedTime(id_t start_event_id, id_t end_event_id);

private:
MPSStream* m_default_stream = nullptr;
// NOTE(review): presumably guards the pool's state; being recursive, it can be
// re-acquired by the same thread -- confirm which members it protects.
std::recursive_mutex m_mutex;
// stack of owned MPSEvent objects.
// NOTE(review): presumably the events that are idle and available for reuse.
std::stack<std::unique_ptr<MPSEvent>> m_pool{};
// dictionary to associate event IDs with event objects
// used to retain in-use events out of the pool
// for torch.mps.Event() bindings.
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
// NOTE(review): presumably the counter from which new event IDs are derived.
uint64_t m_event_counter = 0;
// deleter with the same signature as the one stored in MPSEventPtr.
// NOTE(review): presumably hands released events back to the pool -- confirm.
std::function<void(MPSEvent*)> m_default_deleter;

// looks up an in-use event by its ID.
// NOTE(review): the meaning of `locked` (default true) is not visible here --
// confirm in the implementation file.
MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
};

// returns the MPSEventPool as a shared_ptr.
// shared_ptr is used so that the MPSEventPool gets destroyed only after the
// instances that depend on it.
std::shared_ptr<MPSEventPool> getMPSEventPool();

} // namespace at::mps