Logging SetLogCallback + Debugging cleanup #1471

Merged · 12 commits · Jun 11, 2025
Changes from all commits
44 changes: 40 additions & 4 deletions src/logging.cpp
@@ -12,6 +12,33 @@ namespace Generators {
LogItems g_log;
static std::ostream* gp_stream{&std::cerr};
static std::unique_ptr<std::ofstream> gp_logfile;
+static CallbackFn gp_callback{};

+// Custom stream that calls gp_callback on every line of output
+struct CallbackStream : std::ostream {
+  CallbackStream() : std::ostream{&m_buffer} {}
+
+  struct CustomBuffer : std::stringbuf {
+    int sync() override {
+      auto string = str();
+      if (gp_callback)
+        gp_callback(string.c_str(), string.size());
+      str("");
+      return 0;
+    }
+  };
+
+  CustomBuffer m_buffer;
+} gp_callback_stream;

+void SetLogStream() {
+  if (gp_callback)
+    gp_stream = &gp_callback_stream;
+  else if (gp_logfile)
+    gp_stream = gp_logfile.get();
+  else
+    gp_stream = &std::cerr;
+}

void SetLogBool(std::string_view name, bool value) {
if (name == "enabled")
@@ -51,16 +78,25 @@ void SetLogString(std::string_view name, std::string_view value) {
else {
fs::path filename{std::string(value)};
gp_logfile = std::make_unique<std::ofstream>(filename.open_for_write());
+// If a filename was provided, log callback will be disabled
+gp_callback = nullptr;
}

-if (gp_logfile)
-  gp_stream = gp_logfile.get();
-else
-  gp_stream = &std::cerr;
+SetLogStream();
} else
throw JSON::unknown_value_error{};
}

+void SetLogCallback(CallbackFn fn) {
+  gp_callback = fn;
+  // If a callback was provided, file logging will be disabled
+  if (gp_callback) {
+    gp_logfile.reset();
+  }
+
+  SetLogStream();
+}

std::ostream& operator<<(std::ostream& stream, SGR sgr_code) {
if (g_log.ansi_tags) {
stream << "\x1b[" << static_cast<int>(sgr_code) << 'm';
3 changes: 3 additions & 0 deletions src/logging.h
@@ -24,8 +24,11 @@
*/
namespace Generators {

+using CallbackFn = void (*)(const char* string, size_t length);

void SetLogBool(std::string_view name, bool value);
void SetLogString(std::string_view name, std::string_view value);
+void SetLogCallback(CallbackFn callback);

struct LogItems {
// Special log related entries
13 changes: 1 addition & 12 deletions src/models/debugging.cpp
@@ -71,17 +71,6 @@ struct Stats {
}
};

-template <typename... Types>
-const char* TypeToString(ONNXTensorElementDataType type, Ort::TypeList<Types...>) {
-  const char* name = "(please add type to list)";
-  (void)((type == Ort::TypeToTensorType<Types> ? name = typeid(Types).name(), true : false) || ...);
-  return name;
-}
-
-const char* TypeToString(ONNXTensorElementDataType type) {
-  return TypeToString(type, Ort::TensorTypes{});
-}

std::ostream& operator<<(std::ostream& stream, Ort::Float16_t v) {
stream << Float16ToFloat32(v);
return stream;
@@ -178,7 +167,7 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
case OrtMemoryInfoDeviceType_GPU: {
stream << "GPU\r\n";
auto type = type_info->GetElementType();
-auto tensor_span = std::span<uint8_t>{const_cast<OrtValue*>(value)->GetTensorMutableData<uint8_t>(), SizeOf(type) * element_count};
+auto tensor_span = std::span<uint8_t>{const_cast<OrtValue*>(value)->GetTensorMutableData<uint8_t>(), Ort::SizeOf(type) * element_count};
auto device_span = model.p_device_->WrapMemory<uint8_t>(tensor_span);
DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);
break;
2 changes: 1 addition & 1 deletion src/models/logits.cpp
@@ -46,7 +46,7 @@ DeviceSpan<float> Logits::Get() {

logits_of_last_token = output_last_tokens_.get();

-size_t element_size = SizeOf(type_);
+size_t element_size = Ort::SizeOf(type_);
size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process

auto logits_raw = output_raw_->GetByteSpan();
2 changes: 1 addition & 1 deletion src/models/model.cpp
@@ -880,7 +880,7 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
auto element_type = input_type_info->GetElementType();
auto input_shape = input_type_info->GetShape();
const int64_t batch_size = input_shape[0];
-const int64_t data_size_bytes = input_type_info->GetElementCount() * SizeOf(element_type) / batch_size;
+const int64_t data_size_bytes = input_type_info->GetElementCount() * Ort::SizeOf(element_type) / batch_size;

input_shape[0] *= num_beams;

32 changes: 16 additions & 16 deletions src/models/utils.cpp
@@ -7,39 +7,39 @@ namespace Generators {

DeviceSpan<uint8_t> ByteWrapTensor(DeviceInterface& device, OrtValue& value) {
auto info = value.GetTensorTypeAndShapeInfo();
-return device.WrapMemory(std::span<uint8_t>{value.GetTensorMutableData<uint8_t>(), info->GetElementCount() * SizeOf(info->GetElementType())});
+return device.WrapMemory(std::span<uint8_t>{value.GetTensorMutableData<uint8_t>(), info->GetElementCount() * Ort::SizeOf(info->GetElementType())});
}

-size_t SizeOf(ONNXTensorElementDataType type) {
+const char* TypeToString(ONNXTensorElementDataType type) {
switch (type) {
case Ort::TypeToTensorType<uint8_t>:
-return sizeof(uint8_t);
+return "uint8";
case Ort::TypeToTensorType<int8_t>:
-return sizeof(int8_t);
+return "int8";
case Ort::TypeToTensorType<uint16_t>:
-return sizeof(uint16_t);
+return "uint16";
case Ort::TypeToTensorType<int16_t>:
-return sizeof(int16_t);
+return "int16";
case Ort::TypeToTensorType<uint32_t>:
-return sizeof(uint32_t);
+return "uint32";
case Ort::TypeToTensorType<int32_t>:
-return sizeof(int32_t);
+return "int32";
case Ort::TypeToTensorType<uint64_t>:
-return sizeof(int64_t);
+return "uint64";
case Ort::TypeToTensorType<int64_t>:
-return sizeof(int64_t);
+return "int64";
case Ort::TypeToTensorType<bool>:
-return sizeof(bool);
+return "bool";
case Ort::TypeToTensorType<float>:
-return sizeof(float);
+return "float32";
case Ort::TypeToTensorType<double>:
-return sizeof(double);
+return "float64";
case Ort::TypeToTensorType<Ort::Float16_t>:
-return sizeof(Ort::Float16_t);
+return "float16";
case Ort::TypeToTensorType<Ort::BFloat16_t>:
-return sizeof(Ort::BFloat16_t);
+return "bfloat16";
default:
-throw std::runtime_error("Unsupported ONNXTensorElementDataType in GetTypeSize");
+return "(unsupported type, please add)";
}
}
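Taken together with the Ort::SizeOf migration above, a short hypothetical snippet (not part of this diff) shows how the two helpers complement each other when describing a tensor; it assumes utils.h and the ONNX Runtime headers are included and that TypeToString lives in namespace Generators:

// Hypothetical helper: TypeToString gives a readable name, Ort::SizeOf the per-element size.
void DescribeElementType(std::ostream& stream, ONNXTensorElementDataType type, size_t element_count) {
  stream << Generators::TypeToString(type) << ", "
         << element_count * Ort::SizeOf(type) << " bytes" << std::endl;
}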

2 changes: 1 addition & 1 deletion src/models/utils.h
@@ -20,7 +20,7 @@ struct OrtxPtr {
T* p_{};
};

-size_t SizeOf(ONNXTensorElementDataType type);
+const char* TypeToString(ONNXTensorElementDataType type);

int64_t ElementCountFromShape(std::span<const int64_t> shape);

4 changes: 4 additions & 0 deletions src/ort_genai.h
@@ -676,6 +676,10 @@ inline void SetLogString(const char* name, const char* value) {
OgaCheckResult(OgaSetLogString(name, value));
}

+inline void SetLogCallback(void (*callback)(const char* string, size_t length)) {
+  OgaCheckResult(OgaSetLogCallback(callback));
+}

inline void SetCurrentGpuDeviceId(int device_id) {
OgaCheckResult(OgaSetCurrentGpuDeviceId(device_id));
}
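As a usage note for the new C++ wrapper: this diff does not show the enclosing namespace of these inline helpers, but python.cpp below calls them as Oga::SetLogCallback, so that spelling is assumed in the hedged sketch here (error checking happens inside the wrapper via OgaCheckResult):

#include <cstdio>
#include "ort_genai.h"

// Route library logging to stderr via the new callback; a capture-less lambda
// converts implicitly to the plain function pointer the API expects.
void EnableCallbackLogging() {
  Oga::SetLogCallback([](const char* message, size_t length) {
    std::fwrite(message, 1, length, stderr);
  });
}

// Passing nullptr disables the callback and reverts to the default destination (std::cerr).
void DisableCallbackLogging() {
  Oga::SetLogCallback(nullptr);
}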
7 changes: 7 additions & 0 deletions src/ort_genai_c.cpp
@@ -107,6 +107,13 @@ OgaResult* OGA_API_CALL OgaSetLogString(const char* name, const char* value) {
OGA_CATCH
}

+OgaResult* OGA_API_CALL OgaSetLogCallback(void (*callback)(const char* string, size_t length)) {
+  OGA_TRY
+  Generators::SetLogCallback(callback);
+  return nullptr;
+  OGA_CATCH
+}

OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out) {
OGA_TRY
*out = ReturnUnique<OgaSequences>(std::make_unique<Generators::TokenSequences>());
14 changes: 14 additions & 0 deletions src/ort_genai_c.h
@@ -97,12 +97,26 @@ OGA_EXPORT void OGA_API_CALL OgaShutdown();
OGA_EXPORT const char* OGA_API_CALL OgaResultGetError(const OgaResult* result);

/**
+ * \brief Control the logging behavior of the library.
+ * If OgaSetLogString is called with name "filename" and the value is a valid file path,
+ * the library will log to that file. This will override any previously set logging destination.
+ * If OgaSetLogString is called with name "filename" and the value provided is an empty string,
+ * the library will log to the default destination (i.e. std::cerr) thereafter.
* \param[in] name logging option name, see logging.h 'struct LogItems' for the list of available options
* \param[in] value logging option value.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaSetLogBool(const char* name, bool value);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSetLogString(const char* name, const char* value);

+/**
+ * \brief Register a callback function to receive log messages from the library. If invoked, the callback will override
+ * the previously set logging destination (e.g. a file or std::cerr).
+ * \param[in] callback function pointer to the logging callback function (use nullptr to disable callback and revert to
+ * the default logging destination - std::cerr).
+ * \return OgaResult containing the error message when the callback could not be set, else nullptr.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaSetLogCallback(void (*callback)(const char* string, size_t length));

/**
* \param[in] result OgaResult to be destroyed.
*/
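And the equivalent through the C API, using only the declarations documented above ("enabled" is the option name handled by SetLogBool in logging.cpp); this is a minimal sketch, with error handling via the returned OgaResult* omitted for brevity:

#include <cstdio>
#include "ort_genai_c.h"

// The chunk passed to the callback carries an explicit length, so write exactly `length` bytes.
static void MyLogSink(const char* string, size_t length) {
  std::fwrite(string, 1, length, stderr);
}

void ConfigureLogging() {
  OgaSetLogBool("enabled", true);  // turn logging on
  OgaSetLogCallback(MyLogSink);    // overrides any previously configured file / std::cerr destination
}

void ResetLogging() {
  OgaSetLogCallback(nullptr);      // revert to the default destination (std::cerr)
}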
15 changes: 15 additions & 0 deletions src/python/python.cpp
@@ -277,6 +277,20 @@ void SetLogOptions(const pybind11::kwargs& dict) {
}
}

+void SetLogCallback(std::optional<const pybind11::function> callback) {
+  static std::optional<pybind11::function> log_callback;
+  log_callback = callback;
+
+  if (log_callback.has_value()) {
+    Oga::SetLogCallback([](const char* message, size_t length) {
+      pybind11::gil_scoped_acquire gil;
+      (*log_callback)(std::string_view(message, length));
+    });
+  } else {
+    Oga::SetLogCallback(nullptr);
+  }
+}

PYBIND11_MODULE(onnxruntime_genai, m) {
m.doc() = R"pbdoc(
Ort Generators library
@@ -482,6 +496,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
.def("load", &OgaAdapters::LoadAdapter);

m.def("set_log_options", &SetLogOptions);
m.def("set_log_callback", &SetLogCallback);

m.def("is_cuda_available", []() { return USE_CUDA != 0; });
m.def("is_dml_available", []() { return USE_DML != 0; });
4 changes: 2 additions & 2 deletions src/tensor.cpp
@@ -21,7 +21,7 @@ Tensor::~Tensor() {

void Tensor::CreateTensor(std::span<const int64_t> shape, bool make_static) {
if (make_static) {
-size_t new_bytes = SizeOf(type_) * ElementCountFromShape(shape);
+size_t new_bytes = Ort::SizeOf(type_) * ElementCountFromShape(shape);
if (buffer_ == nullptr) {
bytes_ = new_bytes;
buffer_ = p_device_->GetAllocator().Alloc(bytes_);
@@ -40,7 +40,7 @@ void Tensor::MakeStatic() {
if (ort_tensor_ == nullptr) {
throw std::runtime_error("Tensor: MakeStatic called before CreateTensor");
}
-size_t new_bytes = GetElementCount() * SizeOf(type_);
+size_t new_bytes = GetElementCount() * Ort::SizeOf(type_);
if (buffer_ == nullptr) {
buffer_ = p_device_->GetAllocator().Alloc(new_bytes);
bytes_ = new_bytes;