Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libkineto/libkineto_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def get_libkineto_srcs(with_api = True):
"src/init.cpp",
"src/output_csv.cpp",
"src/output_json.cpp",
"src/CudaDeviceProperties.cpp",
] + (get_libkineto_api_srcs() if with_api else [])

def get_libkineto_cpu_only_srcs(with_api = True):
Expand Down
73 changes: 73 additions & 0 deletions libkineto/src/CudaDeviceProperties.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me check whether we add a different copyright header...

* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "CudaDeviceProperties.h"

#include <vector>

#include <cuda_runtime.h>
#include <cuda_occupancy.h>

namespace KINETO_NAMESPACE {

std::vector<cudaOccDeviceProp> createOccDeviceProps() {
std::vector<cudaOccDeviceProp> occProps;
int device_count;
cudaError_t error_id = cudaGetDeviceCount(&device_count);
// Return empty vector if error.
if (error_id != cudaSuccess) {
return occProps;
}
for (int i = 0; i < device_count; ++i) {
cudaDeviceProp prop;
error_id = cudaGetDeviceProperties(&prop, i);
// Return empty vector if any device property fail to get.
if (error_id != cudaSuccess) {
return occProps;
}
cudaOccDeviceProp occProp;
occProp = prop;
occProps.push_back(occProp);
}
return occProps;
}

const std::vector<cudaOccDeviceProp>& occDeviceProps() {
static std::vector<cudaOccDeviceProp> occProps = createOccDeviceProps();
return occProps;
}

float getKernelOccupancy(uint32_t deviceId, uint16_t registersPerThread,
int32_t staticSharedMemory, int32_t dynamicSharedMemory,
int32_t blockX, int32_t blockY, int32_t blockZ) {
// Calculate occupancy
float occupancy = -1.0;
const std::vector<cudaOccDeviceProp> &occProps = occDeviceProps();
if (deviceId < occProps.size()) {
cudaOccFuncAttributes occFuncAttr;
occFuncAttr.maxThreadsPerBlock = INT_MAX;
occFuncAttr.numRegs = registersPerThread;
occFuncAttr.sharedSizeBytes = staticSharedMemory;
occFuncAttr.partitionedGCConfig = PARTITIONED_GC_OFF;
occFuncAttr.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT;
occFuncAttr.maxDynamicSharedSizeBytes = 0;
const cudaOccDeviceState occDeviceState = {};
int blockSize = blockX * blockY * blockZ;
size_t dynamicSmemSize = dynamicSharedMemory;
cudaOccResult occ_result;
cudaOccError status = cudaOccMaxActiveBlocksPerMultiprocessor(
&occ_result, &occProps[deviceId], &occFuncAttr, &occDeviceState,
blockSize, dynamicSmemSize);
if (status == CUDA_OCC_SUCCESS) {
occupancy = occ_result.activeBlocksPerMultiprocessor * blockSize /
(float) occProps[deviceId].maxThreadsPerMultiprocessor;
}
}
return occupancy;
}

} // namespace KINETO_NAMESPACE
18 changes: 18 additions & 0 deletions libkineto/src/CudaDeviceProperties.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <stdint.h>

namespace KINETO_NAMESPACE {

float getKernelOccupancy(uint32_t deviceId, uint16_t registersPerThread,
int32_t staticSharedMemory, int32_t dynamicSharedMemory,
int32_t blockX, int32_t blockY, int32_t blockZ);

} // namespace KINETO_NAMESPACE
23 changes: 15 additions & 8 deletions libkineto/src/output_json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "CuptiActivity.h"
#include "CuptiActivity.tpp"
#include "CuptiActivityInterface.h"
#include "CudaDeviceProperties.h"
#endif // HAS_CUPTI
#include "Demangle.h"
#include "TraceSpan.h"
Expand Down Expand Up @@ -311,12 +312,16 @@ void ChromeTraceLogger::handleGpuActivity(
const CUpti_ActivityKernel4* kernel = &activity.raw();
const TraceActivity& ext = *activity.linkedActivity();
constexpr int threads_per_warp = 32;
float warps_per_sm = -1.0;
float blocks_per_sm = -1.0;
if (smCount_) {
warps_per_sm = (kernel->gridX * kernel->gridY * kernel->gridZ) *
(kernel->blockX * kernel->blockY * kernel->blockZ) /
(float) threads_per_warp / smCount_;
blocks_per_sm = (kernel->gridX * kernel->gridY * kernel->gridZ) / (float) smCount_;
}

// Calculate occupancy
float occupancy = KINETO_NAMESPACE::getKernelOccupancy(kernel->deviceId, kernel->registersPerThread,
kernel->staticSharedMemory, kernel->dynamicSharedMemory,
kernel->blockX, kernel->blockY, kernel->blockZ);

// clang-format off
traceOf_ << fmt::format(R"JSON(
{{
Expand All @@ -326,9 +331,10 @@ void ChromeTraceLogger::handleGpuActivity(
"stream": {}, "correlation": {}, "external id": {},
"registers per thread": {},
"shared memory": {},
"warps per SM": {},
"blocks per SM": {},
"grid": [{}, {}, {}],
"block": [{}, {}, {}]
"block": [{}, {}, {}],
"occupancy": {}
}}
}},)JSON",
traceActivityJson(activity, "stream "),
Expand All @@ -337,9 +343,10 @@ void ChromeTraceLogger::handleGpuActivity(
kernel->streamId, kernel->correlationId, ext.correlationId(),
kernel->registersPerThread,
kernel->staticSharedMemory + kernel->dynamicSharedMemory,
warps_per_sm,
blocks_per_sm,
kernel->gridX, kernel->gridY, kernel->gridZ,
kernel->blockX, kernel->blockY, kernel->blockZ);
kernel->blockX, kernel->blockY, kernel->blockZ,
occupancy);
// clang-format on

handleLinkEnd(activity);
Expand Down