diff --git a/libkineto/include/GenericTraceActivity.h b/libkineto/include/GenericTraceActivity.h index 712f5deba..8b52a5325 100644 --- a/libkineto/include/GenericTraceActivity.h +++ b/libkineto/include/GenericTraceActivity.h @@ -20,6 +20,7 @@ namespace libkineto { // Link type, used in GenericTraceActivity.flow.type constexpr unsigned int kLinkFwdBwd = 1; +constexpr unsigned int kLinkAsyncCpuGpu = 2; // @lint-ignore-every CLANGTIDY cppcoreguidelines-non-private-member-variables-in-classes // @lint-ignore-every CLANGTIDY cppcoreguidelines-pro-type-member-init @@ -57,6 +58,10 @@ class GenericTraceActivity : public ITraceActivity { return activityType; } + const ITraceActivity* linkedActivity() const override { + return nullptr; + } + int flowType() const override { return flow.type; } @@ -65,12 +70,12 @@ class GenericTraceActivity : public ITraceActivity { return flow.id; } - const std::string name() const override { - return activityName; + bool flowStart() const override { + return flow.start; } - const ITraceActivity* linkedActivity() const override { - return flow.linkedActivity; + const std::string name() const override { + return activityName; } const TraceSpan* traceSpan() const override { @@ -103,13 +108,13 @@ class GenericTraceActivity : public ITraceActivity { ActivityType activityType; std::string activityName; struct Flow { - Flow(): linkedActivity(nullptr), id(0), type(0) {} - ITraceActivity* linkedActivity; // Only set in destination side. 
+ Flow(): id(0), type(0), start(0) {} // Ids must be unique within each type - uint32_t id : 28; + uint32_t id : 27; // Type will be used to connect flows between profilers, as // well as look up flow information (name etc) uint32_t type : 4; + uint32_t start : 1; } flow; private: diff --git a/libkineto/include/ITraceActivity.h b/libkineto/include/ITraceActivity.h index cd07f1f8e..2ccaba049 100644 --- a/libkineto/include/ITraceActivity.h +++ b/libkineto/include/ITraceActivity.h @@ -33,6 +33,7 @@ struct ITraceActivity { // Part of a flow, identified by flow id and type virtual int flowType() const = 0; virtual int flowId() const = 0; + virtual bool flowStart() const = 0; virtual ActivityType type() const = 0; virtual const std::string name() const = 0; // Optional linked activity diff --git a/libkineto/include/TraceSpan.h b/libkineto/include/TraceSpan.h index f33e18fb9..77145af47 100644 --- a/libkineto/include/TraceSpan.h +++ b/libkineto/include/TraceSpan.h @@ -36,8 +36,6 @@ struct TraceSpan { std::string name; // Prefix used to distinguish trace spans on the same timeline std::string prefix; - // Tracked by profiler for iteration trigger - bool tracked{false}; }; } // namespace libkineto diff --git a/libkineto/src/ActivityBuffers.h b/libkineto/src/ActivityBuffers.h index e482be217..ce3acc7d5 100644 --- a/libkineto/src/ActivityBuffers.h +++ b/libkineto/src/ActivityBuffers.h @@ -19,6 +19,16 @@ namespace KINETO_NAMESPACE { struct ActivityBuffers { std::list> cpu; std::unique_ptr gpu; + + // Add a wrapper object to the underlying struct stored in the buffer + template + const ITraceActivity& addActivityWrapper(const T& act) { + wrappers_.push_back(std::make_unique(act)); + return *wrappers_.back().get(); + } + + private: + std::vector> wrappers_; }; } // namespace KINETO_NAMESPACE diff --git a/libkineto/src/CuptiActivity.h b/libkineto/src/CuptiActivity.h index 053d42eb4..f1ce8bfa6 100644 --- a/libkineto/src/CuptiActivity.h +++ b/libkineto/src/CuptiActivity.h @@ -30,7 
+30,7 @@ struct TraceSpan; // Abstract base class, templated on Cupti activity type template struct CuptiActivity : public ITraceActivity { - explicit CuptiActivity(const T* activity, const ITraceActivity& linked) + explicit CuptiActivity(const T* activity, const ITraceActivity* linked) : activity_(*activity), linked_(linked) {} int64_t timestamp() const override { return nsToUs(unixEpochTimestamp(activity_.start)); @@ -39,27 +39,28 @@ struct CuptiActivity : public ITraceActivity { return nsToUs(activity_.end - activity_.start); } int64_t correlationId() const override {return activity_.correlationId;} - int flowType() const override {return 0;} - int flowId() const override {return 0;} + const ITraceActivity* linkedActivity() const override {return linked_;} + int flowType() const override {return kLinkAsyncCpuGpu;} + int flowId() const override {return correlationId();} const T& raw() const {return activity_;} - const ITraceActivity* linkedActivity() const override {return &linked_;} const TraceSpan* traceSpan() const override {return nullptr;} protected: const T& activity_; - const ITraceActivity& linked_; + const ITraceActivity* linked_{nullptr}; }; // CUpti_ActivityAPI - CUDA runtime activities struct RuntimeActivity : public CuptiActivity { explicit RuntimeActivity( const CUpti_ActivityAPI* activity, - const ITraceActivity& linked, + const ITraceActivity* linked, int32_t threadId) : CuptiActivity(activity, linked), threadId_(threadId) {} int64_t deviceId() const override {return processId();} int64_t resourceId() const override {return threadId_;} ActivityType type() const override {return ActivityType::CUDA_RUNTIME;} + bool flowStart() const override; const std::string name() const override {return runtimeCbidName(activity_.cbid);} void log(ActivityLogger& logger) const override; const std::string metadataJson() const override; @@ -72,11 +73,12 @@ struct RuntimeActivity : public CuptiActivity { // Can also be instantiated directly. 
template struct GpuActivity : public CuptiActivity { - explicit GpuActivity(const T* activity, const ITraceActivity& linked) + explicit GpuActivity(const T* activity, const ITraceActivity* linked) : CuptiActivity(activity, linked) {} int64_t deviceId() const override {return raw().deviceId;} int64_t resourceId() const override {return raw().streamId;} ActivityType type() const override; + bool flowStart() const override {return false;} const std::string name() const override; void log(ActivityLogger& logger) const override; const std::string metadataJson() const override; diff --git a/libkineto/src/CuptiActivity.tpp b/libkineto/src/CuptiActivity.tpp index b71f419b5..2884fcc77 100644 --- a/libkineto/src/CuptiActivity.tpp +++ b/libkineto/src/CuptiActivity.tpp @@ -67,7 +67,7 @@ inline ActivityType GpuActivity::type() const { } inline void RuntimeActivity::log(ActivityLogger& logger) const { - logger.handleRuntimeActivity(*this); + logger.handleGenericActivity(*this); } template @@ -75,8 +75,20 @@ inline void GpuActivity::log(ActivityLogger& logger) const { logger.handleGpuActivity(*this); } +inline bool RuntimeActivity::flowStart() const { + return activity_.cbid == CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000 || + (activity_.cbid >= CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020 && + activity_.cbid <= CUPTI_RUNTIME_TRACE_CBID_cudaMemset2DAsync_v3020) || + activity_.cbid == + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000 || + activity_.cbid == + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000; +} + inline const std::string RuntimeActivity::metadataJson() const { - return ""; + return fmt::format(R"JSON( + "cbid": {}, "correlation": {})JSON", + activity_.cbid, activity_.correlationId); } template diff --git a/libkineto/src/CuptiActivityBuffer.h b/libkineto/src/CuptiActivityBuffer.h index d444b9399..3fe229456 100644 --- a/libkineto/src/CuptiActivityBuffer.h +++ b/libkineto/src/CuptiActivityBuffer.h @@ -15,6 +15,8 @@ #include #include 
+#include "ITraceActivity.h" + namespace KINETO_NAMESPACE { class CuptiActivityBuffer { @@ -44,6 +46,8 @@ class CuptiActivityBuffer { std::vector buf_; size_t size_; + + std::vector> wrappers_; }; using CuptiActivityBufferMap = diff --git a/libkineto/src/CuptiActivityProfiler.cpp b/libkineto/src/CuptiActivityProfiler.cpp index 32317449d..c52b451b3 100644 --- a/libkineto/src/CuptiActivityProfiler.cpp +++ b/libkineto/src/CuptiActivityProfiler.cpp @@ -143,9 +143,8 @@ void CuptiActivityProfiler::processCpuTrace( if (config_->selectedActivityTypes().count(act.type())) { act.log(logger); } - // Stash event so we can look it up later when processing GPU trace - externalEvents_.insertEvent(&act); clientActivityTraceMap_[act.correlationId()] = &span_pair; + activityMap_[act.correlationId()] = &act; } logger.handleTraceSpan(cpu_span); } @@ -153,40 +152,10 @@ void CuptiActivityProfiler::processCpuTrace( #ifdef HAS_CUPTI inline void CuptiActivityProfiler::handleCorrelationActivity( const CUpti_ActivityExternalCorrelation* correlation) { - externalEvents_.addCorrelation( - correlation->externalId, correlation->correlationId); + cpuCorrelationMap_[correlation->correlationId] = correlation->externalId; } #endif // HAS_CUPTI -const libkineto::GenericTraceActivity& -CuptiActivityProfiler::ExternalEventMap::correlatedActivity(uint32_t id) { - static const libkineto::GenericTraceActivity nullOp_( - defaultTraceSpan().first, ActivityType::CPU_OP, "NULL"); - - auto* res = events_[correlationMap_[id]]; - if (res == nullptr) { - // Entry may be missing because cpu trace hasn't been processed yet - // Insert a dummy element so that we can check for this in insertEvent - events_[correlationMap_[id]] = &nullOp_; - res = &nullOp_; - } - return *res; -} - -void CuptiActivityProfiler::ExternalEventMap::insertEvent( - const libkineto::GenericTraceActivity* op) { - if (events_[op->correlationId()] != nullptr) { - LOG_EVERY_N(WARNING, 100) - << "Events processed out of order - link will be 
missing"; - } - events_[op->correlationId()] = op; -} - -void CuptiActivityProfiler::ExternalEventMap::addCorrelation( - uint64_t external_id, uint32_t cuda_id) { - correlationMap_[cuda_id] = external_id; -} - static GenericTraceActivity createUserGpuSpan( const libkineto::ITraceActivity& cpuTraceActivity, const libkineto::ITraceActivity& gpuTraceActivity) { @@ -206,6 +175,10 @@ static GenericTraceActivity createUserGpuSpan( void CuptiActivityProfiler::GpuUserEventMap::insertOrExtendEvent( const ITraceActivity&, const ITraceActivity& gpuActivity) { + if (!gpuActivity.linkedActivity()) { + VLOG(0) << "Missing linked activity"; + return; + } const ITraceActivity& cpuActivity = *gpuActivity.linkedActivity(); StreamKey key(gpuActivity.deviceId(), gpuActivity.resourceId()); CorrelationSpanMap& correlationSpanMap = streamSpanMap_[key]; @@ -252,7 +225,7 @@ inline bool CuptiActivityProfiler::outOfRange(const ITraceActivity& act) { return out_of_range; } -inline void CuptiActivityProfiler::handleRuntimeActivity( +void CuptiActivityProfiler::handleRuntimeActivity( const CUpti_ActivityAPI* activity, ActivityLogger* logger) { // Some CUDA calls that are very frequent and also not very interesting. 
@@ -266,23 +239,30 @@ inline void CuptiActivityProfiler::handleRuntimeActivity( VLOG(2) << activity->correlationId << ": CUPTI_ACTIVITY_KIND_RUNTIME, cbid=" << activity->cbid << " tid=" << activity->threadId; - const GenericTraceActivity& ext = - externalEvents_.correlatedActivity(activity->correlationId); int32_t tid = activity->threadId; const auto& it = resourceInfo_.find({processId(), tid}); if (it != resourceInfo_.end()) { tid = it->second.id; } - RuntimeActivity runtimeActivity(activity, ext, tid); - if (ext.correlationId() == 0 && outOfRange(runtimeActivity)) { + const ITraceActivity* linked = linkedActivity( + activity->correlationId, cpuCorrelationMap_); + const auto& runtime_activity = + traceBuffers_->addActivityWrapper(RuntimeActivity(activity, linked, tid)); + checkTimestampOrder(&runtime_activity); + if (outOfRange(runtime_activity)) { return; } - runtimeActivity.log(*logger); + runtime_activity.log(*logger); } -inline void CuptiActivityProfiler::updateGpuNetSpan(const ITraceActivity& gpuOp) { +inline void CuptiActivityProfiler::updateGpuNetSpan( + const ITraceActivity& gpuOp) { + if (!gpuOp.linkedActivity()) { + VLOG(0) << "Missing linked activity"; + return; + } const auto& it = clientActivityTraceMap_.find( - gpuOp.linkedActivity()->correlationId()); + gpuOp.linkedActivity()->correlationId()); if (it == clientActivityTraceMap_.end()) { // No correlation id mapping? return; @@ -297,50 +277,69 @@ inline void CuptiActivityProfiler::updateGpuNetSpan(const ITraceActivity& gpuOp) } // I've observed occasional broken timestamps attached to GPU events... 
-static bool timestampsInCorrectOrder( - const ITraceActivity& ext, - const ITraceActivity& gpuOp) { - if (ext.timestamp() > gpuOp.timestamp()) { - LOG(WARNING) << "GPU op timestamp (" << gpuOp.timestamp() - << ") < runtime timestamp (" << ext.timestamp() << ") by " - << ext.timestamp() - gpuOp.timestamp() << "us"; - LOG(WARNING) << "Name: " << gpuOp.name() - << " Device: " << gpuOp.deviceId() - << " Stream: " << gpuOp.resourceId(); - return false; - } - return true; +void CuptiActivityProfiler::checkTimestampOrder(const ITraceActivity* act1) { + // Correlated GPU runtime activity cannot + // have timestamp greater than the GPU activity's + const auto& it = correlatedCudaActivities_.find(act1->correlationId()); + if (it == correlatedCudaActivities_.end()) { + correlatedCudaActivities_.insert({act1->correlationId(), act1}); + return; + } + + // Activities may appear in the buffers out of order. + // If we have a runtime activity in the map, it should mean that we + // have a GPU activity passed in, and vice versa. + const ITraceActivity* act2 = it->second; + if (act2->type() == ActivityType::CUDA_RUNTIME) { + // Buffer is out-of-order. + // Swap so that runtime activity is first for the comparison below. 
+ std::swap(act1, act2); + } + if (act1->timestamp() > act2->timestamp()) { + LOG(WARNING) << "GPU op timestamp (" << act2->timestamp() + << ") < runtime timestamp (" << act1->timestamp() << ") by " + << act1->timestamp() - act2->timestamp() << "us"; + LOG(WARNING) << "Name: " << act2->name() + << " Device: " << act2->deviceId() + << " Stream: " << act2->resourceId(); + } } inline void CuptiActivityProfiler::handleGpuActivity( const ITraceActivity& act, ActivityLogger* logger) { - const ITraceActivity& ext = *act.linkedActivity(); - if (ext.timestamp() == 0 && outOfRange(act)) { + if (outOfRange(act)) { return; } - // Correlated GPU runtime activity cannot have timestamp greater than the GPU activity's - if (!timestampsInCorrectOrder(ext, act)) { - return; - } - - VLOG(2) << ext.correlationId() << "," << act.correlationId() << ": " + checkTimestampOrder(&act); + VLOG(2) << act.correlationId() << ": " << act.name(); recordStream(act.deviceId(), act.resourceId()); act.log(*logger); updateGpuNetSpan(act); - if (config_->selectedActivityTypes().count(ActivityType::GPU_USER_ANNOTATION) && - act.linkedActivity() && - act.linkedActivity()->type() == ActivityType::USER_ANNOTATION) { +} + +const ITraceActivity* CuptiActivityProfiler::linkedActivity( + int32_t correlationId, + const std::unordered_map& correlationMap) { + const auto& it = correlationMap.find(correlationId); + if (it != correlationMap.end()) { + const auto& it2 = activityMap_.find(it->second); + if (it2 != activityMap_.end()) { + return it2->second; + } } + return nullptr; } template inline void CuptiActivityProfiler::handleGpuActivity( const T* act, ActivityLogger* logger) { - const GenericTraceActivity& extDefault = - externalEvents_.correlatedActivity(act->correlationId); - handleGpuActivity(GpuActivity(act, extDefault), logger); + const ITraceActivity* linked = linkedActivity( + act->correlationId, cpuCorrelationMap_); + const auto& gpu_activity = + traceBuffers_->addActivityWrapper(GpuActivity(act, 
linked)); + handleGpuActivity(gpu_activity, logger); } void CuptiActivityProfiler::handleCuptiActivity(const CUpti_Activity* record, ActivityLogger* logger) { @@ -684,7 +683,9 @@ void CuptiActivityProfiler::resetTraceData() { cupti_.clearActivities(); } #endif // HAS_CUPTI || HAS_ROCTRACER - externalEvents_.clear(); + activityMap_.clear(); + cpuCorrelationMap_.clear(); + correlatedCudaActivities_.clear(); gpuUserEventMap_.clear(); traceSpans_.clear(); clientActivityTraceMap_.clear(); diff --git a/libkineto/src/CuptiActivityProfiler.h b/libkineto/src/CuptiActivityProfiler.h index 8f28b059f..4b4a4560c 100644 --- a/libkineto/src/CuptiActivityProfiler.h +++ b/libkineto/src/CuptiActivityProfiler.h @@ -133,41 +133,6 @@ class CuptiActivityProfiler { private: - class ExternalEventMap { - public: - - // The correlation id of the GPU activity - const libkineto::GenericTraceActivity& correlatedActivity( - uint32_t correlation_id); - void insertEvent(const libkineto::GenericTraceActivity* op); - - void addCorrelation(uint64_t external_id, uint32_t cuda_id); - - void clear() { - events_.clear(); - correlationMap_.clear(); - } - - private: - // Map extern correlation ID to Operator info. - // This is a map of regular pointers which is generally a bad idea, - // but this class also fully owns the objects it is pointing to so - // it's not so bad. This is done for performance reasons and is an - // implementation detail of this class that might change. - std::unordered_map - events_; - - // Cuda correlation id -> external correlation id for default events - // CUPTI provides a mechanism for correlating Cuda events to arbitrary - // external events, e.g.operator events from Caffe2. - // It also marks GPU activities with the Cuda event correlation ID. - // So by connecting the two, we get the complete picture. 
- std::unordered_map< - uint32_t, // Cuda correlation ID - uint64_t> // External correlation ID - correlationMap_; - }; - // Map of gpu activities to user defined events class GpuUserEventMap { public: @@ -194,6 +159,15 @@ class CuptiActivityProfiler { }; GpuUserEventMap gpuUserEventMap_; + // id -> activity* + std::unordered_map activityMap_; + // cuda runtime id -> pytorch op id + // CUPTI provides a mechanism for correlating Cuda events to arbitrary + // external events, e.g. operator activities from PyTorch. + std::unordered_map cpuCorrelationMap_; + // CUDA runtime <-> GPU Activity + std::unordered_map + correlatedCudaActivities_; // data structure to collect cuptiActivityFlushAll() latency overhead struct profilerOverhead { @@ -241,6 +215,10 @@ class CuptiActivityProfiler { // net name to id int netId(const std::string& netName); + const ITraceActivity* linkedActivity( + int32_t correlationId, + const std::unordered_map& correlationMap); + #ifdef HAS_CUPTI // Process generic CUPTI activity void handleCuptiActivity(const CUpti_Activity* record, ActivityLogger* logger); @@ -271,6 +249,8 @@ class CuptiActivityProfiler { return counter.overhead / counter.cntr; } + void checkTimestampOrder(const ITraceActivity* act1); + // On-demand request configuration std::unique_ptr config_; @@ -295,7 +275,6 @@ class CuptiActivityProfiler { std::chrono::time_point profileStartTime_; std::chrono::time_point profileEndTime_; - ExternalEventMap externalEvents_; // All recorded trace spans, both CPU and GPU // Trace Id -> list of iterations. 
diff --git a/libkineto/src/RoctracerActivityApi.cpp b/libkineto/src/RoctracerActivityApi.cpp index 2f21982ed..e2e6bd9f9 100644 --- a/libkineto/src/RoctracerActivityApi.cpp +++ b/libkineto/src/RoctracerActivityApi.cpp @@ -88,6 +88,9 @@ int RoctracerActivityApi::processActivities( a.resource = item.tid; a.activityType = ActivityType::CUDA_RUNTIME; a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; logger.handleGenericActivity(a); ++count; @@ -103,6 +106,9 @@ int RoctracerActivityApi::processActivities( a.resource = item.tid; a.activityType = ActivityType::CUDA_RUNTIME; a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; a.addMetadata("ptr", item.ptr); if (item.cid == HIP_API_ID_hipMalloc) { @@ -123,6 +129,9 @@ int RoctracerActivityApi::processActivities( a.resource = item.tid; a.activityType = ActivityType::CUDA_RUNTIME; a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; a.addMetadata("src", item.src); a.addMetadata("dst", item.dst); @@ -147,6 +156,9 @@ int RoctracerActivityApi::processActivities( a.resource = item.tid; a.activityType = ActivityType::CUDA_RUNTIME; a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; if (item.functionAddr != nullptr) { a.addMetadataQuoted( @@ -205,6 +217,9 @@ int RoctracerActivityApi::processActivities( a.activityType = ActivityType::CUDA_RUNTIME; a.activityName = std::string(name); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; logger.handleGenericActivity(a); ++count; @@ -226,6 +241,8 @@ int RoctracerActivityApi::processActivities( 
a.activityType = ActivityType::CONCURRENT_KERNEL; a.activityName = std::string(name); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; auto it = kernelNames_.find(record->correlation_id); if (it != kernelNames_.end()) { diff --git a/libkineto/src/output_base.h b/libkineto/src/output_base.h index 4bf1d4a10..12c193e40 100644 --- a/libkineto/src/output_base.h +++ b/libkineto/src/output_base.h @@ -65,8 +65,6 @@ class ActivityLogger { const libkineto::ITraceActivity& activity) = 0; #ifdef HAS_CUPTI - virtual void handleRuntimeActivity(const RuntimeActivity& activity) = 0; - virtual void handleGpuActivity( const GpuActivity& activity) = 0; virtual void handleGpuActivity( diff --git a/libkineto/src/output_json.cpp b/libkineto/src/output_json.cpp index 78efd8d74..a382d9524 100644 --- a/libkineto/src/output_json.cpp +++ b/libkineto/src/output_json.cpp @@ -189,9 +189,7 @@ void ChromeTraceLogger::handleTraceSpan(const TraceSpan& span) { 0x20000000ll); // clang-format on - if (span.tracked) { - addIterationMarker(span); - } + addIterationMarker(span); } void ChromeTraceLogger::addIterationMarker(const TraceSpan& span) { @@ -253,7 +251,14 @@ void ChromeTraceLogger::handleGenericActivity( const std::string op_metadata = op.metadataJson(); std::string separator = ""; if (op_metadata.find_first_not_of(" \t\n") != std::string::npos) { - separator = ","; + separator = ",\n "; + } + std::string span = ""; + if (op.traceSpan()) { + span = fmt::format(R"JSON( + "Trace name": "{}", "Trace iteration": {},)JSON", + op.traceSpan()->name, + op.traceSpan()->iteration); } const std::string tid = op.type() == ActivityType::GPU_USER_ANNOTATION ? 
@@ -261,67 +266,46 @@ void ChromeTraceLogger::handleGenericActivity( fmt::format("{}", op.resourceId()); // clang-format off - - switch (op.type()) { - case ActivityType::CUDA_RUNTIME: - { - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "Runtime", {}, - "args": {{ - {} - }} - }},)JSON", - traceActivityJson(op), - op_metadata); - handleLink(kFlowStart, op, op.correlationId(), "async_gpu", "async_gpu"); - } - break; - case ActivityType::CONCURRENT_KERNEL: - { - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "Kernel", {}, - "args": {{ - {} - }} - }},)JSON", - traceActivityJson(op), - op_metadata); - handleLink(kFlowEnd, op, op.correlationId(), "async_gpu", "async_gpu"); - } - break; - default: - { - traceOf_ << fmt::format(R"JSON( + traceOf_ << fmt::format(R"JSON( {{ "ph": "X", "cat": "{}", {}, - "args": {{ - "External id": {}, - "Trace name": "{}", "Trace iteration": {}{} - {} + "args": {{{} + "External id": {}{}{} }} }},)JSON", - toString(op.type()), traceActivityJson(op), - // args - op.correlationId(), - op.traceSpan()->name, op.traceSpan()->iteration, separator, - op_metadata); - } - break; - } + toString(op.type()), traceActivityJson(op), + // args + span, + op.correlationId(), separator, op_metadata); // clang-format on - if (op.linkedActivity() != nullptr) { + if (op.flowId() > 0) { handleGenericLink(op); } } void ChromeTraceLogger::handleGenericLink(const ITraceActivity& act) { - if (act.flowType() == kLinkFwdBwd) { - const auto& from_act = *act.linkedActivity(); - handleLink(kFlowStart, from_act, act.flowId(), "forward_backward", "fwd_bwd"); - handleLink(kFlowEnd, act, act.flowId(), "forward_backward", "fwd_bwd"); + static struct { + int type; + char longName[24]; + char shortName[16]; + } flow_names[] = { + {kLinkFwdBwd, "forward_backward", "fwd_bwd"}, + {kLinkAsyncCpuGpu, "async_cpu_to_gpu", "async_gpu"} + }; + for (auto& flow : flow_names) { + if (act.flowType() == flow.type) { + // Link the activities via flow ID in source and 
destination. + // The source node must return true from flowStart() + // and the destination node false. + if (act.flowStart()) { + handleLink(kFlowStart, act, act.flowId(), flow.longName, flow.shortName); + } else { + handleLink(kFlowEnd, act, act.flowId(), flow.longName, flow.shortName); + } + return; + } } + LOG(ERROR) << "Unknown flow type: " << act.flowType(); } void ChromeTraceLogger::handleLink( @@ -345,42 +329,6 @@ void ChromeTraceLogger::handleLink( } #ifdef HAS_CUPTI -void ChromeTraceLogger::handleRuntimeActivity( - const RuntimeActivity& activity) { - if (!traceOf_) { - return; - } - - const CUpti_CallbackId cbid = activity.raw().cbid; - const ITraceActivity& ext = *activity.linkedActivity(); - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "Runtime", {}, - "args": {{ - "cbid": {}, "correlation": {}, - "external id": {}, "external ts": {} - }} - }},)JSON", - traceActivityJson(activity), - // args - cbid, activity.raw().correlationId, - ext.correlationId(), ext.timestamp()); - // clang-format on - - // FIXME: This is pretty hacky and it's likely that we miss some links. - // May need to maintain a map instead. 
- if (cbid == CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000 || - (cbid >= CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020 && - cbid <= CUPTI_RUNTIME_TRACE_CBID_cudaMemset2DAsync_v3020) || - cbid == - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000 || - cbid == - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000) { - auto from_id = activity.correlationId(); - handleLink(kFlowStart, activity, from_id, "async_gpu", activity.name()); - } -} - // GPU side kernel activity void ChromeTraceLogger::handleGpuActivity( const GpuActivity& activity) { @@ -388,7 +336,6 @@ void ChromeTraceLogger::handleGpuActivity( return; } const CUpti_ActivityKernel4* kernel = &activity.raw(); - const ITraceActivity& ext = *activity.linkedActivity(); constexpr int threads_per_warp = 32; float blocks_per_sm = -1.0; float warps_per_sm = -1.0; @@ -418,7 +365,7 @@ void ChromeTraceLogger::handleGpuActivity( "ph": "X", "cat": "Kernel", {}, "args": {{ "queued": {}, "device": {}, "context": {}, - "stream": {}, "correlation": {}, "external id": {}, + "stream": {}, "correlation": {}, "registers per thread": {}, "shared memory": {}, "blocks per SM": {}, @@ -431,7 +378,7 @@ void ChromeTraceLogger::handleGpuActivity( traceActivityJson(activity), // args us(kernel->queued), kernel->deviceId, kernel->contextId, - kernel->streamId, kernel->correlationId, ext.correlationId(), + kernel->streamId, kernel->correlationId, kernel->registersPerThread, kernel->staticSharedMemory + kernel->dynamicSharedMemory, blocks_per_sm, @@ -442,7 +389,7 @@ void ChromeTraceLogger::handleGpuActivity( // clang-format on auto to_id = activity.correlationId(); - handleLink(kFlowEnd, activity, to_id, "async_gpu", "cudaLaunchKernel"); + handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); } static std::string bandwidth(uint64_t bytes, uint64_t duration) { @@ -456,7 +403,6 @@ void ChromeTraceLogger::handleGpuActivity( return; } const CUpti_ActivityMemcpy& memcpy = activity.raw(); - const 
ITraceActivity& ext = *activity.linkedActivity(); VLOG(2) << memcpy.correlationId << ": MEMCPY"; // clang-format off traceOf_ << fmt::format(R"JSON( @@ -464,19 +410,19 @@ void ChromeTraceLogger::handleGpuActivity( "ph": "X", "cat": "Memcpy", {}, "args": {{ "device": {}, "context": {}, - "stream": {}, "correlation": {}, "external id": {}, + "stream": {}, "correlation": {}, "bytes": {}, "memory bandwidth (GB/s)": {} }} }},)JSON", traceActivityJson(activity), // args memcpy.deviceId, memcpy.contextId, - memcpy.streamId, memcpy.correlationId, ext.correlationId(), + memcpy.streamId, memcpy.correlationId, memcpy.bytes, bandwidth(memcpy.bytes, memcpy.end - memcpy.start)); // clang-format on int64_t to_id = activity.correlationId(); - handleLink(kFlowEnd, activity, to_id, "async_gpu", "cudaMemcpyAsync"); + handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); } // GPU side memcpy activity @@ -486,7 +432,6 @@ void ChromeTraceLogger::handleGpuActivity( return; } const CUpti_ActivityMemcpy2& memcpy = activity.raw(); - const ITraceActivity& ext = *activity.linkedActivity(); // clang-format off traceOf_ << fmt::format(R"JSON( {{ @@ -494,7 +439,7 @@ void ChromeTraceLogger::handleGpuActivity( "args": {{ "fromDevice": {}, "inDevice": {}, "toDevice": {}, "fromContext": {}, "inContext": {}, "toContext": {}, - "stream": {}, "correlation": {}, "external id": {}, + "stream": {}, "correlation": {}, "bytes": {}, "memory bandwidth (GB/s)": {} }} }},)JSON", @@ -502,12 +447,12 @@ void ChromeTraceLogger::handleGpuActivity( // args memcpy.srcDeviceId, memcpy.deviceId, memcpy.dstDeviceId, memcpy.srcContextId, memcpy.contextId, memcpy.dstContextId, - memcpy.streamId, memcpy.correlationId, ext.correlationId(), + memcpy.streamId, memcpy.correlationId, memcpy.bytes, bandwidth(memcpy.bytes, memcpy.end - memcpy.start)); // clang-format on int64_t to_id = activity.correlationId(); - handleLink(kFlowEnd, activity, to_id, "async_gpu", "cudaMemcpyAsync"); + handleLink(kFlowEnd, 
activity, to_id, "async_cpu_to_gpu", "async_gpu"); } void ChromeTraceLogger::handleGpuActivity( @@ -516,26 +461,25 @@ void ChromeTraceLogger::handleGpuActivity( return; } const CUpti_ActivityMemset& memset = activity.raw(); - const ITraceActivity& ext = *activity.linkedActivity(); // clang-format off traceOf_ << fmt::format(R"JSON( {{ "ph": "X", "cat": "Memset", {}, "args": {{ "device": {}, "context": {}, - "stream": {}, "correlation": {}, "external id": {}, + "stream": {}, "correlation": {}, "bytes": {}, "memory bandwidth (GB/s)": {} }} }},)JSON", traceActivityJson(activity), // args memset.deviceId, memset.contextId, - memset.streamId, memset.correlationId, ext.correlationId(), + memset.streamId, memset.correlationId, memset.bytes, bandwidth(memset.bytes, memset.end - memset.start)); // clang-format on int64_t to_id = activity.correlationId(); - handleLink(kFlowEnd, activity, to_id, "async_gpu", "cudaMemsetAsync"); + handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); } #endif // HAS_CUPTI diff --git a/libkineto/src/output_json.h b/libkineto/src/output_json.h index 880c5e005..29f1e3416 100644 --- a/libkineto/src/output_json.h +++ b/libkineto/src/output_json.h @@ -45,9 +45,6 @@ class ChromeTraceLogger : public libkineto::ActivityLogger { void handleGenericActivity(const ITraceActivity& activity) override; #ifdef HAS_CUPTI - void handleRuntimeActivity( - const RuntimeActivity& activity) override; - void handleGpuActivity(const GpuActivity& activity) override; void handleGpuActivity(const GpuActivity& activity) override; void handleGpuActivity(const GpuActivity& activity) override; diff --git a/libkineto/src/output_membuf.h b/libkineto/src/output_membuf.h index bb4ab36e3..f4cd550da 100644 --- a/libkineto/src/output_membuf.h +++ b/libkineto/src/output_membuf.h @@ -63,11 +63,6 @@ class MemoryTraceLogger : public ActivityLogger { activities_.push_back(wrappers_.back().get()); } - void handleRuntimeActivity( - const RuntimeActivity& activity) 
override { - addActivityWrapper(activity); - } - void handleGpuActivity(const GpuActivity& activity) override { addActivityWrapper(activity); } diff --git a/libkineto/test/CuptiActivityProfilerTest.cpp b/libkineto/test/CuptiActivityProfilerTest.cpp index 43b1048df..c2c65679d 100644 --- a/libkineto/test/CuptiActivityProfilerTest.cpp +++ b/libkineto/test/CuptiActivityProfilerTest.cpp @@ -348,54 +348,6 @@ TEST_F(CuptiActivityProfilerTest, SyncTrace) { #endif } -TEST_F(CuptiActivityProfilerTest, CorrelatedTimestampTest) { - // Verbose logging is useful for debugging - std::vector log_modules( - {"CuptiActivityProfiler.cpp"}); - SET_LOG_VERBOSITY_LEVEL(2, log_modules); - - // Start and stop profiling - CuptiActivityProfiler profiler(cuptiActivities_, /*cpu only*/ false); - int64_t start_time_us = 100; - int64_t duration_us = 300; - auto start_time = time_point(microseconds(start_time_us)); - profiler.configure(*cfg_, start_time); - profiler.startTrace(start_time); - profiler.stopTrace(start_time + microseconds(duration_us)); - - // Scenario 1: Test mismatch in CPU and GPU events. - // When launching kernel, the CPU event should always precede the GPU event. 
- int64_t kernelLaunchTime = 120; - - profiler.recordThreadInfo(); - - // set up CPU event - auto cpuOps = std::make_unique( - start_time_us, start_time_us + duration_us); - cpuOps->addOp("launchKernel", kernelLaunchTime, kernelLaunchTime + 10, 1); - profiler.transferCpuTrace(std::move(cpuOps)); - - // set up GPU event - auto gpuOps = std::make_unique(); - gpuOps->addCorrelationActivity(1, CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, 1); - gpuOps->addKernelActivity(kernelLaunchTime - 1, kernelLaunchTime + 10, 1); - cuptiActivities_.activityBuffer = std::move(gpuOps); - - // process trace - auto logger = std::make_unique(*cfg_); - profiler.processTrace(*logger); - - ActivityTrace trace(std::move(logger), loggerFactory); - std::map counts; - for (auto& activity : *trace.activities()) { - counts[activity->name()]++; - } - - // The GPU launch kernel activities should have been dropped due to invalid timestamps - EXPECT_EQ(counts["cudaLaunchKernel"], 0); - EXPECT_EQ(counts["launchKernel"], 1); -} - TEST_F(CuptiActivityProfilerTest, SubActivityProfilers) { using ::testing::Return; using ::testing::ByMove;