192 changes: 192 additions & 0 deletions backends/cuda/runtime/memory_tracker.h
@@ -0,0 +1,192 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <limits>

#include <executorch/runtime/platform/log.h>

namespace executorch::backends::cuda {

/**
* @class CudaMemoryTracker
* @brief Tracks CUDA memory usage and logs memory state at key points
*
* This class provides utilities to query and track CUDA memory usage,
* including peak memory usage and detailed memory state logging.
*/
class CudaMemoryTracker {
 public:
  /**
   * @brief Constructor - initializes tracker and logs startup memory state
   */
  CudaMemoryTracker() {
    if (!query(&last_free_bytes_, &total_bytes_)) {
      return;
    }
    available_ = true;
    // Record the initial free bytes observed at startup. We'll use this as a
    // baseline so reported "peak usage" reflects additional memory used
    // since the tracker was created (instead of the absolute device usage,
    // which may include other processes).
    initial_free_bytes_ = last_free_bytes_;
    min_free_bytes_ = last_free_bytes_;
    log_state("startup", last_free_bytes_, total_bytes_);
  }

  /**
   * @brief Logs current memory state at a tagged checkpoint
   * @param tag Descriptive tag for this memory sample (e.g., "after_load")
   */
  void log_sample(const char* tag) {
    if (!available_) {
      return;
    }
    size_t free_bytes = 0;
    size_t total_bytes = 0;
    if (!query(&free_bytes, &total_bytes)) {
      return;
    }
    min_free_bytes_ = std::min(min_free_bytes_, free_bytes);
    total_bytes_ = total_bytes;
    last_free_bytes_ = free_bytes;
    log_state(tag, free_bytes, total_bytes);
  }

  /**
   * @brief Destructor - logs final memory state and peak usage summary
   */
  ~CudaMemoryTracker() {
    if (!available_) {
      return;
    }
    size_t free_bytes = 0;
    size_t total_bytes = 0;
    if (!query(&free_bytes, &total_bytes)) {
      return;
    }
    min_free_bytes_ = std::min(min_free_bytes_, free_bytes);
    total_bytes_ = total_bytes;
    last_free_bytes_ = free_bytes;
    // Compute peak usage relative to the initial free baseline so that
    // allocations by other processes present at startup are not attributed
    // to this process. If for some reason initial_free_bytes_ was not set,
    // fall back to absolute device usage.
    double peak_mb = 0.0;
    if (initial_free_bytes_ != std::numeric_limits<size_t>::max()) {
      size_t used_delta = 0;
      if (initial_free_bytes_ > min_free_bytes_) {
        used_delta = initial_free_bytes_ - min_free_bytes_;
      }
      peak_mb = static_cast<double>(used_delta) / (1024.0 * 1024.0);
    } else {
      peak_mb = static_cast<double>(total_bytes_ - min_free_bytes_) /
          (1024.0 * 1024.0);
    }
    const double total_mb =
        static_cast<double>(total_bytes_) / (1024.0 * 1024.0);
    ET_LOG(
        Info,
        "CUDA memory peak usage (since startup): %.2f MB, device total: %.2f MB",
        peak_mb,
        total_mb);
  }

 private:
  /**
   * @brief Queries current CUDA memory info
   * @param free_bytes Output parameter for free memory in bytes
   * @param total_bytes Output parameter for total memory in bytes
   * @return true if query succeeded, false otherwise
   */
  bool query(size_t* free_bytes, size_t* total_bytes) {
    cudaError_t err = cudaMemGetInfo(free_bytes, total_bytes);
    if (err != cudaSuccess) {
      if (!error_logged_) {
        error_logged_ = true;
        ET_LOG(
            Error,
            "cudaMemGetInfo failed with error: %s",
            cudaGetErrorString(err));
      }
      available_ = false;
      return false;
    }
    return true;
  }

  /**
   * @brief Logs the current memory state
   * @param tag Tag describing this log point
   * @param free_bytes Current free memory in bytes
   * @param total_bytes Current total memory in bytes
   */
  void log_state(const char* tag, size_t free_bytes, size_t total_bytes) const {
    const double used_mb =
        static_cast<double>(total_bytes - free_bytes) / (1024.0 * 1024.0);
    const double free_mb = static_cast<double>(free_bytes) / (1024.0 * 1024.0);
    const double total_mb =
        static_cast<double>(total_bytes) / (1024.0 * 1024.0);
    ET_LOG(
        Info,
        "CUDA memory (%s): used %.2f MB, free %.2f MB, total %.2f MB",
        tag,
        used_mb,
        free_mb,
        total_mb);
  }

  bool available_{false};
  bool error_logged_{false};
  size_t last_free_bytes_{0};
  size_t total_bytes_{0};
  size_t min_free_bytes_{std::numeric_limits<size_t>::max()};
  // Baseline free bytes observed at tracker construction. Used to compute
  // peak usage attributable to this process since the tracker started.
  size_t initial_free_bytes_{std::numeric_limits<size_t>::max()};

 public:
  // Simple accessors to allow other components to read last-sampled values.
  // These are safe to call after a successful log_sample() invocation.
  uint64_t last_free_bytes() const {
    return static_cast<uint64_t>(last_free_bytes_);
  }
  uint64_t total_bytes() const {
    return static_cast<uint64_t>(total_bytes_);
  }
  uint64_t min_free_bytes() const {
    return static_cast<uint64_t>(min_free_bytes_);
  }
  uint64_t initial_free_bytes() const {
    return static_cast<uint64_t>(initial_free_bytes_);
  }
  double peak_usage_mb() const {
    // Prefer peak relative to the initial free baseline; fall back to
    // absolute device peak if baseline isn't available.
    if (min_free_bytes_ == std::numeric_limits<size_t>::max()) {
      return 0.0;
    }
    if (initial_free_bytes_ != std::numeric_limits<size_t>::max()) {
      size_t used_delta = 0;
      if (initial_free_bytes_ > min_free_bytes_) {
        used_delta = initial_free_bytes_ - min_free_bytes_;
      }
      return static_cast<double>(used_delta) / (1024.0 * 1024.0);
    }
    if (total_bytes_ == 0) {
      return 0.0;
    }
    return static_cast<double>(total_bytes_ - min_free_bytes_) /
        (1024.0 * 1024.0);
  }
};

} // namespace executorch::backends::cuda
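To make the baseline-relative accounting concrete: if cudaMemGetInfo() reports 10,000 MB free when the tracker is constructed and the lowest free value observed afterwards is 7,500 MB, peak_usage_mb() reports 2,500 MB, regardless of how much device memory other processes already held at startup. A minimal usage sketch follows (not part of the diff; the 64 MB cudaMalloc is a stand-in for real model loading):

#include <cuda_runtime.h>

#include <executorch/backends/cuda/runtime/memory_tracker.h>

using executorch::backends::cuda::CudaMemoryTracker;

int main() {
  // Construction queries cudaMemGetInfo() and logs the "startup" state; the
  // free bytes seen here become the baseline for peak reporting.
  CudaMemoryTracker tracker;

  // Stand-in for model loading: allocate 64 MB of device memory.
  void* buf = nullptr;
  if (cudaMalloc(&buf, 64 * 1024 * 1024) == cudaSuccess) {
    tracker.log_sample("after_alloc"); // min free drops by roughly 64 MB
    cudaFree(buf);
  }
  tracker.log_sample("after_free");

  // peak_usage_mb() = initial free minus minimum free seen, in MB (~64 here).
  ET_LOG(Info, "peak: %.2f MB", tracker.peak_usage_mb());
  return 0;
} // The destructor logs the final state and a peak-usage summary.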
19 changes: 19 additions & 0 deletions extension/llm/runner/CMakeLists.txt
@@ -55,6 +55,25 @@ target_include_directories(
  extension_llm_runner INTERFACE ${_common_include_directories}
)

# If the project is configured to build with CUDA support, try to find a CUDA
# runtime (prefer the CUDAToolkit package). If found, expose a compile-time
# macro so sources can conditionally compile CUDA-aware code.
if(EXECUTORCH_BUILD_CUDA)
  # Prefer the modern CMake CUDAToolkit module, fall back to searching for the
  # CUDA runtime library (cudart) if the package isn't available.
  find_package(CUDAToolkit QUIET)
  if(CUDAToolkit_FOUND)
    target_compile_definitions(extension_llm_runner PUBLIC CUDA_AVAILABLE)
    target_link_libraries(extension_llm_runner PUBLIC CUDA::cudart)
    message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE")
  else()
    message(
      STATUS
        "CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found"
    )
  endif()
endif()

install(
  TARGETS extension_llm_runner
  EXPORT ExecuTorchTargets
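The inline comment above mentions falling back to a direct search for cudart when find_package(CUDAToolkit) fails, but the else() branch in this change only logs a status message. A sketch of what such a fallback could look like, using standard find_library()/find_path(); the CUDART_LIBRARY and CUDA_RUNTIME_INCLUDE_DIR variable names are illustrative assumptions, not part of the diff:

# Sketch only: search for cudart directly when the CUDAToolkit package is
# unavailable. Variable names here are illustrative assumptions.
find_library(CUDART_LIBRARY cudart
  HINTS ENV CUDA_HOME ENV CUDA_PATH
  PATH_SUFFIXES lib64 lib lib/x64
)
find_path(CUDA_RUNTIME_INCLUDE_DIR cuda_runtime.h
  HINTS ENV CUDA_HOME ENV CUDA_PATH
  PATH_SUFFIXES include
)
if(CUDART_LIBRARY AND CUDA_RUNTIME_INCLUDE_DIR)
  target_compile_definitions(extension_llm_runner PUBLIC CUDA_AVAILABLE)
  target_include_directories(
    extension_llm_runner PUBLIC ${CUDA_RUNTIME_INCLUDE_DIR}
  )
  target_link_libraries(extension_llm_runner PUBLIC ${CUDART_LIBRARY})
  message(STATUS "Found cudart via find_library; defining CUDA_AVAILABLE")
endif()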
36 changes: 33 additions & 3 deletions extension/llm/runner/multimodal_runner.cpp
@@ -15,6 +15,10 @@
#include <pytorch/tokenizers/hf_tokenizer.h>
#include <pytorch/tokenizers/sentencepiece.h>

#ifdef CUDA_AVAILABLE
#include <executorch/backends/cuda/runtime/memory_tracker.h>
#endif

namespace executorch::extension::llm {

using ::executorch::extension::Module;
@@ -38,7 +42,16 @@ MultimodalRunner::MultimodalRunner(
      io_manager_(std::move(io_manager)),
      text_token_generator_(std::move(text_token_generator)),
      stats_(std::move(stats)),
      pos_(0) {
#ifdef CUDA_AVAILABLE
  cuda_memory_tracker_ =
      std::make_unique<::executorch::backends::cuda::CudaMemoryTracker>();
  // Probe immediately after creating the tracker to capture GPU state before
  // any model loading happens.
  stats_->gpu_total_bytes = cuda_memory_tracker_->total_bytes();
  stats_->gpu_free_before_load_bytes = cuda_memory_tracker_->last_free_bytes();
#endif
}

bool MultimodalRunner::is_loaded() {
  return multimodal_prefiller_->is_method_loaded() &&
@@ -49,8 +62,18 @@ Error MultimodalRunner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  stats_->model_load_start_ms = time_in_ms();
  ET_CHECK_OK_OR_RETURN_ERROR(multimodal_prefiller_->load());
  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
  stats_->model_load_end_ms = time_in_ms();

#ifdef CUDA_AVAILABLE
  cuda_memory_tracker_->log_sample("after_load");
  stats_->gpu_total_bytes = cuda_memory_tracker_->total_bytes();
  stats_->gpu_free_after_load_bytes = cuda_memory_tracker_->last_free_bytes();
  stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb();
#endif

  return Error::Ok;
}

@@ -86,9 +109,7 @@ Error MultimodalRunner::generate(
  }

  if (!is_loaded()) {
    stats_->model_load_start_ms = time_in_ms();
Contributor: why do we want to delete the loading time recording?
Contributor (Author): Moved it inside load()
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    stats_->model_load_end_ms = time_in_ms();
  }

  if (config.warming) {
@@ -192,6 +213,15 @@ Error MultimodalRunner::generate(
  stats_->num_generated_tokens = num_generated_tokens;
  // Finalize stats and call callback
  stats_->inference_end_ms = time_in_ms();

#ifdef CUDA_AVAILABLE
  cuda_memory_tracker_->log_sample("after_generate");
  stats_->gpu_free_after_generate_bytes =
      cuda_memory_tracker_->last_free_bytes();
  // update peak in case it changed after generation
  stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb();
#endif

  if (!config.warming) {
    printf("\n");
  }
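The runner writes several GPU fields onto Stats (gpu_total_bytes, gpu_free_before_load_bytes, gpu_free_after_load_bytes, gpu_free_after_generate_bytes, gpu_peak_usage_mb) whose declarations are not shown in this section of the diff. Inferred from the call sites above, the Stats additions presumably look something like the following sketch; field names match the usages, while types and defaults are assumptions:

// Sketch of the GPU-related members implied by the call sites above; not the
// actual stats.h diff. Types and zero defaults are assumptions.
struct Stats {
  // ... existing timing fields (model_load_start_ms, inference_end_ms, ...)
  uint64_t gpu_total_bytes = 0; // device total from cudaMemGetInfo
  uint64_t gpu_free_before_load_bytes = 0; // sampled in the constructor probe
  uint64_t gpu_free_after_load_bytes = 0; // sampled after load()
  uint64_t gpu_free_after_generate_bytes = 0; // sampled after generate()
  double gpu_peak_usage_mb = 0.0; // baseline-relative peak, in MB
};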
9 changes: 9 additions & 0 deletions extension/llm/runner/multimodal_runner.h
@@ -36,6 +36,10 @@
// These are provided for backward compatibility
#include <executorch/extension/llm/runner/llm_runner_helper.h>

#ifdef CUDA_AVAILABLE
#include <executorch/backends/cuda/runtime/memory_tracker.h>
#endif

namespace executorch {
namespace extension {
namespace llm {
@@ -150,6 +154,11 @@ class ET_EXPERIMENTAL MultimodalRunner {
  std::unique_ptr<TextTokenGenerator> text_token_generator_;
  std::unique_ptr<Stats> stats_;

#ifdef CUDA_AVAILABLE
  std::unique_ptr<::executorch::backends::cuda::CudaMemoryTracker>
      cuda_memory_tracker_;
#endif

  // Internal state
  int64_t pos_;
};