diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.h b/lldb/include/lldb/Protocol/MCP/MCPError.h similarity index 78% rename from lldb/source/Plugins/Protocol/MCP/MCPError.h rename to lldb/include/lldb/Protocol/MCP/MCPError.h index f4db13d6deade..55dd40f124a15 100644 --- a/lldb/source/Plugins/Protocol/MCP/MCPError.h +++ b/lldb/include/lldb/Protocol/MCP/MCPError.h @@ -1,4 +1,4 @@ -//===-- MCPError.h --------------------------------------------------------===// +//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,14 @@ // //===----------------------------------------------------------------------===// -#include "Protocol.h" +#ifndef LLDB_PROTOCOL_MCP_MCPERROR_H +#define LLDB_PROTOCOL_MCP_MCPERROR_H + +#include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Support/Error.h" -#include "llvm/Support/FormatVariadic.h" #include -namespace lldb_private::mcp { +namespace lldb_protocol::mcp { class MCPError : public llvm::ErrorInfo { public: @@ -24,7 +26,7 @@ class MCPError : public llvm::ErrorInfo { const std::string &getMessage() const { return m_message; } - protocol::Error toProtcolError() const; + lldb_protocol::mcp::Error toProtocolError() const; static constexpr int64_t kResourceNotFound = -32002; static constexpr int64_t kInternalError = -32603; @@ -47,4 +49,6 @@ class UnsupportedURI : public llvm::ErrorInfo { std::string m_uri; }; -} // namespace lldb_private::mcp +} // namespace lldb_protocol::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h similarity index 73% rename from lldb/source/Plugins/Protocol/MCP/Protocol.h rename to lldb/include/lldb/Protocol/MCP/Protocol.h index ce74836e62541..49f9490221755 100644 --- a/lldb/source/Plugins/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -11,62 +11,88 @@ // //===----------------------------------------------------------------------===// -#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H -#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H +#ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H +#define LLDB_PROTOCOL_MCP_PROTOCOL_H #include "llvm/Support/JSON.h" #include #include #include -namespace lldb_private::mcp::protocol { +namespace lldb_protocol::mcp { -static llvm::StringLiteral kVersion = "2024-11-05"; +static llvm::StringLiteral kProtocolVersion = "2024-11-05"; + +/// A Request or Response 'id'. +/// +/// NOTE: This differs from the JSON-RPC 2.0 spec. The MCP spec says this must +/// be a string or number, excluding a json 'null' as a valid id. +using Id = std::variant; /// A request that expects a response. struct Request { - uint64_t id = 0; + /// The request id. + Id id = 0; + /// The method to be invoked. std::string method; + /// The method's params. std::optional params; }; llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); +bool operator==(const Request &, const Request &); -struct ErrorInfo { +struct Error { + /// The error type that occurred. int64_t code = 0; + /// A short description of the error. The message SHOULD be limited to a + /// concise single sentence. std::string message; - std::string data; -}; - -llvm::json::Value toJSON(const ErrorInfo &); -bool fromJSON(const llvm::json::Value &, ErrorInfo &, llvm::json::Path); - -struct Error { - uint64_t id = 0; - ErrorInfo error; + /// Additional information about the error. The value of this member is + /// defined by the sender (e.g. detailed error information, nested errors + /// etc.). + std::optional data; }; llvm::json::Value toJSON(const Error &); bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path); +bool operator==(const Error &, const Error &); +/// A response to a request, either an error or a result. struct Response { - uint64_t id = 0; - std::optional result; - std::optional error; + /// The request id. + Id id = 0; + /// The result of the request, either an Error or the JSON value of the + /// response. + std::variant result; }; llvm::json::Value toJSON(const Response &); bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); +bool operator==(const Response &, const Response &); /// A notification which does not expect a response. struct Notification { + /// The method to be invoked. std::string method; + /// The notification's params. std::optional params; }; llvm::json::Value toJSON(const Notification &); bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path); +bool operator==(const Notification &, const Notification &); + +/// A general message as defined by the JSON-RPC 2.0 spec. +using Message = std::variant; +// With clang-cl and MSVC STL 202208, convertible can be false later if we do +// not force it to be checked early here. +static_assert(std::is_convertible_v, + "Message is not convertible to itself"); + +bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); +llvm::json::Value toJSON(const Message &); struct ToolCapability { /// Whether this server supports notifications for changes to the tool list. @@ -176,13 +202,8 @@ struct ToolDefinition { llvm::json::Value toJSON(const ToolDefinition &); bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path); -using Message = std::variant; - -bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); -llvm::json::Value toJSON(const Message &); - using ToolArguments = std::variant; -} // namespace lldb_private::mcp::protocol +} // namespace lldb_protocol::mcp #endif diff --git a/lldb/include/lldb/Protocol/MCP/Resource.h b/lldb/include/lldb/Protocol/MCP/Resource.h new file mode 100644 index 0000000000000..4835d340cd4c6 --- /dev/null +++ b/lldb/include/lldb/Protocol/MCP/Resource.h @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PROTOCOL_MCP_RESOURCE_H +#define LLDB_PROTOCOL_MCP_RESOURCE_H + +#include "lldb/Protocol/MCP/Protocol.h" +#include + +namespace lldb_protocol::mcp { + +class ResourceProvider { +public: + ResourceProvider() = default; + virtual ~ResourceProvider() = default; + + virtual std::vector GetResources() const = 0; + virtual llvm::Expected + ReadResource(llvm::StringRef uri) const = 0; +}; + +} // namespace lldb_protocol::mcp + +#endif diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h new file mode 100644 index 0000000000000..2ac05880de86b --- /dev/null +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PROTOCOL_MCP_SERVER_H +#define LLDB_PROTOCOL_MCP_SERVER_H + +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Resource.h" +#include "lldb/Protocol/MCP/Tool.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/Error.h" +#include + +namespace lldb_protocol::mcp { + +class Server { +public: + Server(std::string name, std::string version); + virtual ~Server() = default; + + void AddTool(std::unique_ptr tool); + void AddResourceProvider(std::unique_ptr resource_provider); + +protected: + virtual Capabilities GetCapabilities() = 0; + + using RequestHandler = + std::function(const Request &)>; + using NotificationHandler = std::function; + + void AddRequestHandlers(); + + void AddRequestHandler(llvm::StringRef method, RequestHandler handler); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); + + llvm::Expected> HandleData(llvm::StringRef data); + + llvm::Expected Handle(Request request); + void Handle(Notification notification); + + llvm::Expected InitializeHandler(const Request &); + + llvm::Expected ToolsListHandler(const Request &); + llvm::Expected ToolsCallHandler(const Request &); + + llvm::Expected ResourcesListHandler(const Request &); + llvm::Expected ResourcesReadHandler(const Request &); + + std::mutex m_mutex; + +private: + const std::string m_name; + const std::string m_version; + + llvm::StringMap> m_tools; + std::vector> m_resource_providers; + + llvm::StringMap m_request_handlers; + llvm::StringMap m_notification_handlers; +}; + +} // namespace lldb_protocol::mcp + +#endif diff --git a/lldb/include/lldb/Protocol/MCP/Tool.h b/lldb/include/lldb/Protocol/MCP/Tool.h new file mode 100644 index 0000000000000..96669d1357166 --- /dev/null +++ b/lldb/include/lldb/Protocol/MCP/Tool.h @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PROTOCOL_MCP_TOOL_H +#define LLDB_PROTOCOL_MCP_TOOL_H + +#include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/JSON.h" +#include + +namespace lldb_protocol::mcp { + +class Tool { +public: + Tool(std::string name, std::string description); + virtual ~Tool() = default; + + virtual llvm::Expected + Call(const lldb_protocol::mcp::ToolArguments &args) = 0; + + virtual std::optional GetSchema() const { + return llvm::json::Object{{"type", "object"}}; + } + + lldb_protocol::mcp::ToolDefinition GetDefinition() const; + + const std::string &GetName() { return m_name; } + +private: + std::string m_name; + std::string m_description; +}; + +} // namespace lldb_protocol::mcp + +#endif diff --git a/lldb/source/CMakeLists.txt b/lldb/source/CMakeLists.txt index 51c9f9c90826e..ae02227ca3578 100644 --- a/lldb/source/CMakeLists.txt +++ b/lldb/source/CMakeLists.txt @@ -10,6 +10,7 @@ add_subdirectory(Host) add_subdirectory(Initialization) add_subdirectory(Interpreter) add_subdirectory(Plugins) +add_subdirectory(Protocol) add_subdirectory(Symbol) add_subdirectory(Target) add_subdirectory(Utility) diff --git a/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt index e104fb527e57a..87565e693158a 100644 --- a/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt +++ b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt @@ -1,6 +1,4 @@ add_lldb_library(lldbPluginProtocolServerMCP PLUGIN - MCPError.cpp - Protocol.cpp ProtocolServerMCP.cpp Resource.cpp Tool.cpp @@ -10,5 +8,6 @@ add_lldb_library(lldbPluginProtocolServerMCP PLUGIN LINK_LIBS lldbHost + lldbProtocolMCP lldbUtility ) diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 0e5a3631e6387..c359663239dcc 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -7,8 +7,11 @@ //===----------------------------------------------------------------------===// #include "ProtocolServerMCP.h" -#include "MCPError.h" +#include "Resource.h" +#include "Tool.h" #include "lldb/Core/PluginManager.h" +#include "lldb/Protocol/MCP/MCPError.h" +#include "lldb/Protocol/MCP/Tool.h" #include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" #include "llvm/ADT/StringExtras.h" @@ -18,34 +21,23 @@ using namespace lldb_private; using namespace lldb_private::mcp; +using namespace lldb_protocol::mcp; using namespace llvm; LLDB_PLUGIN_DEFINE(ProtocolServerMCP) static constexpr size_t kChunkSize = 1024; - -ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() { - AddRequestHandler("initialize", - std::bind(&ProtocolServerMCP::InitializeHandler, this, - std::placeholders::_1)); - - AddRequestHandler("tools/list", - std::bind(&ProtocolServerMCP::ToolsListHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/call", - std::bind(&ProtocolServerMCP::ToolsCallHandler, this, - std::placeholders::_1)); - - AddRequestHandler("resources/list", - std::bind(&ProtocolServerMCP::ResourcesListHandler, this, - std::placeholders::_1)); - AddRequestHandler("resources/read", - std::bind(&ProtocolServerMCP::ResourcesReadHandler, this, - std::placeholders::_1)); - AddNotificationHandler( - "notifications/initialized", [](const protocol::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); - }); +static constexpr llvm::StringLiteral kName = "lldb-mcp"; +static constexpr llvm::StringLiteral kVersion = "0.1.0"; + +ProtocolServerMCP::ProtocolServerMCP() + : ProtocolServer(), + lldb_protocol::mcp::Server(std::string(kName), std::string(kVersion)) { + AddNotificationHandler("notifications/initialized", + [](const lldb_protocol::mcp::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), + "MCP initialization complete"); + }); AddTool( std::make_unique("lldb_command", "Run an lldb command.")); @@ -72,32 +64,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } -llvm::Expected -ProtocolServerMCP::Handle(protocol::Request request) { - auto it = m_request_handlers.find(request.method); - if (it != m_request_handlers.end()) { - llvm::Expected response = it->second(request); - if (!response) - return response; - response->id = request.id; - return *response; - } - - return make_error( - llvm::formatv("no handler for request: {0}", request.method).str()); -} - -void ProtocolServerMCP::Handle(protocol::Notification notification) { - auto it = m_notification_handlers.find(notification.method); - if (it != m_notification_handlers.end()) { - it->second(notification); - return; - } - - LLDB_LOG(GetLog(LLDBLog::Host), "MPC notification: {0} ({1})", - notification.method, notification.params); -} - void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", m_clients.size() + 1); @@ -111,7 +77,7 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { auto read_handle_up = m_loop.RegisterReadObject( io_sp, [this, client](MainLoopBase &loop) { - if (Error error = ReadCallback(*client)) { + if (llvm::Error error = ReadCallback(*client)) { LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); client->read_handle_up.reset(); } @@ -133,7 +99,7 @@ llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { for (std::string::size_type pos; (pos = client.buffer.find('\n')) != std::string::npos;) { - llvm::Expected> message = + llvm::Expected> message = HandleData(StringRef(client.buffer.data(), pos)); client.buffer = client.buffer.erase(0, pos + 1); if (!message) @@ -152,7 +118,7 @@ llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { - std::lock_guard guard(m_server_mutex); + std::lock_guard guard(m_mutex); if (m_running) return llvm::createStringError("the MCP server is already running"); @@ -184,7 +150,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { llvm::Error ProtocolServerMCP::Stop() { { - std::lock_guard guard(m_server_mutex); + std::lock_guard guard(m_mutex); if (!m_running) return createStringError("the MCP sever is not running"); m_running = false; @@ -199,7 +165,7 @@ llvm::Error ProtocolServerMCP::Stop() { m_loop_thread.join(); { - std::lock_guard guard(m_server_mutex); + std::lock_guard guard(m_mutex); m_listener.reset(); m_listen_handlers.clear(); m_clients.clear(); @@ -208,203 +174,11 @@ llvm::Error ProtocolServerMCP::Stop() { return llvm::Error::success(); } -llvm::Expected> -ProtocolServerMCP::HandleData(llvm::StringRef data) { - auto message = llvm::json::parse(/*JSON=*/data); - if (!message) - return message.takeError(); - - if (const protocol::Request *request = - std::get_if(&(*message))) { - llvm::Expected response = Handle(*request); - - // Handle failures by converting them into an Error message. - if (!response) { - protocol::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.error.code = MCPError::kInternalError; - protocol_error.error.message = err.message(); - }); - protocol_error.id = request->id; - return protocol_error; - } - - return *response; - } - - if (const protocol::Notification *notification = - std::get_if(&(*message))) { - Handle(*notification); - return std::nullopt; - } - - if (std::get_if(&(*message))) - return llvm::createStringError("unexpected MCP message: error"); - - if (std::get_if(&(*message))) - return llvm::createStringError("unexpected MCP message: response"); - - llvm_unreachable("all message types handled"); -} - -protocol::Capabilities ProtocolServerMCP::GetCapabilities() { - protocol::Capabilities capabilities; +lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() { + lldb_protocol::mcp::Capabilities capabilities; capabilities.tools.listChanged = true; // FIXME: Support sending notifications when a debugger/target are // added/removed. capabilities.resources.listChanged = false; return capabilities; } - -void ProtocolServerMCP::AddTool(std::unique_ptr tool) { - std::lock_guard guard(m_server_mutex); - - if (!tool) - return; - m_tools[tool->GetName()] = std::move(tool); -} - -void ProtocolServerMCP::AddResourceProvider( - std::unique_ptr resource_provider) { - std::lock_guard guard(m_server_mutex); - - if (!resource_provider) - return; - m_resource_providers.push_back(std::move(resource_provider)); -} - -void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method, - RequestHandler handler) { - std::lock_guard guard(m_server_mutex); - m_request_handlers[method] = std::move(handler); -} - -void ProtocolServerMCP::AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler) { - std::lock_guard guard(m_server_mutex); - m_notification_handlers[method] = std::move(handler); -} - -llvm::Expected -ProtocolServerMCP::InitializeHandler(const protocol::Request &request) { - protocol::Response response; - response.result.emplace(llvm::json::Object{ - {"protocolVersion", protocol::kVersion}, - {"capabilities", GetCapabilities()}, - {"serverInfo", - llvm::json::Object{{"name", kName}, {"version", kVersion}}}}); - return response; -} - -llvm::Expected -ProtocolServerMCP::ToolsListHandler(const protocol::Request &request) { - protocol::Response response; - - llvm::json::Array tools; - for (const auto &tool : m_tools) - tools.emplace_back(toJSON(tool.second->GetDefinition())); - - response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); - - return response; -} - -llvm::Expected -ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { - protocol::Response response; - - if (!request.params) - return llvm::createStringError("no tool parameters"); - - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no tool parameters"); - - const json::Value *name = param_obj->get("name"); - if (!name) - return llvm::createStringError("no tool name"); - - llvm::StringRef tool_name = name->getAsString().value_or(""); - if (tool_name.empty()) - return llvm::createStringError("no tool name"); - - auto it = m_tools.find(tool_name); - if (it == m_tools.end()) - return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); - - protocol::ToolArguments tool_args; - if (const json::Value *args = param_obj->get("arguments")) - tool_args = *args; - - llvm::Expected text_result = - it->second->Call(tool_args); - if (!text_result) - return text_result.takeError(); - - response.result.emplace(toJSON(*text_result)); - - return response; -} - -llvm::Expected -ProtocolServerMCP::ResourcesListHandler(const protocol::Request &request) { - protocol::Response response; - - llvm::json::Array resources; - - std::lock_guard guard(m_server_mutex); - for (std::unique_ptr &resource_provider_up : - m_resource_providers) { - for (const protocol::Resource &resource : - resource_provider_up->GetResources()) - resources.push_back(resource); - } - response.result.emplace( - llvm::json::Object{{"resources", std::move(resources)}}); - - return response; -} - -llvm::Expected -ProtocolServerMCP::ResourcesReadHandler(const protocol::Request &request) { - protocol::Response response; - - if (!request.params) - return llvm::createStringError("no resource parameters"); - - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no resource parameters"); - - const json::Value *uri = param_obj->get("uri"); - if (!uri) - return llvm::createStringError("no resource uri"); - - llvm::StringRef uri_str = uri->getAsString().value_or(""); - if (uri_str.empty()) - return llvm::createStringError("no resource uri"); - - std::lock_guard guard(m_server_mutex); - for (std::unique_ptr &resource_provider_up : - m_resource_providers) { - llvm::Expected result = - resource_provider_up->ReadResource(uri_str); - if (result.errorIsA()) { - llvm::consumeError(result.takeError()); - continue; - } - if (!result) - return result.takeError(); - - protocol::Response response; - response.result.emplace(std::move(*result)); - return response; - } - - return make_error( - llvm::formatv("no resource handler for uri: {0}", uri_str).str(), - MCPError::kResourceNotFound); -} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index e273f6e2a8d37..7fe909a728b85 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -9,18 +9,17 @@ #ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H #define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H -#include "Protocol.h" -#include "Resource.h" -#include "Tool.h" #include "lldb/Core/ProtocolServer.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/Socket.h" -#include "llvm/ADT/StringMap.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Server.h" #include namespace lldb_private::mcp { -class ProtocolServerMCP : public ProtocolServer { +class ProtocolServerMCP : public ProtocolServer, + public lldb_protocol::mcp::Server { public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -40,45 +39,10 @@ class ProtocolServerMCP : public ProtocolServer { Socket *GetSocket() const override { return m_listener.get(); } -protected: - using RequestHandler = std::function( - const protocol::Request &)>; - using NotificationHandler = - std::function; - - void AddTool(std::unique_ptr tool); - void AddResourceProvider(std::unique_ptr resource_provider); - - void AddRequestHandler(llvm::StringRef method, RequestHandler handler); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); - private: void AcceptCallback(std::unique_ptr socket); - llvm::Expected> - HandleData(llvm::StringRef data); - - llvm::Expected Handle(protocol::Request request); - void Handle(protocol::Notification notification); - - llvm::Expected - InitializeHandler(const protocol::Request &); - - llvm::Expected - ToolsListHandler(const protocol::Request &); - llvm::Expected - ToolsCallHandler(const protocol::Request &); - - llvm::Expected - ResourcesListHandler(const protocol::Request &); - llvm::Expected - ResourcesReadHandler(const protocol::Request &); - - protocol::Capabilities GetCapabilities(); - - llvm::StringLiteral kName = "lldb-mcp"; - llvm::StringLiteral kVersion = "0.1.0"; + lldb_protocol::mcp::Capabilities GetCapabilities() override; bool m_running = false; @@ -95,13 +59,6 @@ class ProtocolServerMCP : public ProtocolServer { }; llvm::Error ReadCallback(Client &client); std::vector> m_clients; - - std::mutex m_server_mutex; - llvm::StringMap> m_tools; - std::vector> m_resource_providers; - - llvm::StringMap m_request_handlers; - llvm::StringMap m_notification_handlers; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.cpp b/lldb/source/Plugins/Protocol/MCP/Resource.cpp index d75d5b6dd6a41..e94d2cdd65e07 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Resource.cpp @@ -5,12 +5,14 @@ //===----------------------------------------------------------------------===// #include "Resource.h" -#include "MCPError.h" #include "lldb/Core/Debugger.h" #include "lldb/Core/Module.h" +#include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Target/Platform.h" +using namespace lldb_private; using namespace lldb_private::mcp; +using namespace lldb_protocol::mcp; namespace { struct DebuggerResource { @@ -64,11 +66,11 @@ static llvm::Error createUnsupportedURIError(llvm::StringRef uri) { return llvm::make_error(uri.str()); } -protocol::Resource +lldb_protocol::mcp::Resource DebuggerResourceProvider::GetDebuggerResource(Debugger &debugger) { const lldb::user_id_t debugger_id = debugger.GetID(); - protocol::Resource resource; + lldb_protocol::mcp::Resource resource; resource.uri = llvm::formatv("lldb://debugger/{0}", debugger_id); resource.name = debugger.GetInstanceName(); resource.description = @@ -78,7 +80,7 @@ DebuggerResourceProvider::GetDebuggerResource(Debugger &debugger) { return resource; } -protocol::Resource +lldb_protocol::mcp::Resource DebuggerResourceProvider::GetTargetResource(size_t target_idx, Target &target) { const size_t debugger_id = target.GetDebugger().GetID(); @@ -87,7 +89,7 @@ DebuggerResourceProvider::GetTargetResource(size_t target_idx, Target &target) { if (Module *exe_module = target.GetExecutableModulePointer()) target_name = exe_module->GetFileSpec().GetFilename().GetString(); - protocol::Resource resource; + lldb_protocol::mcp::Resource resource; resource.uri = llvm::formatv("lldb://debugger/{0}/target/{1}", debugger_id, target_idx); resource.name = target_name; @@ -98,8 +100,9 @@ DebuggerResourceProvider::GetTargetResource(size_t target_idx, Target &target) { return resource; } -std::vector DebuggerResourceProvider::GetResources() const { - std::vector resources; +std::vector +DebuggerResourceProvider::GetResources() const { + std::vector resources; const size_t num_debuggers = Debugger::GetNumDebuggers(); for (size_t i = 0; i < num_debuggers; ++i) { @@ -121,7 +124,7 @@ std::vector DebuggerResourceProvider::GetResources() const { return resources; } -llvm::Expected +llvm::Expected DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { auto [protocol, path] = uri.split("://"); @@ -158,7 +161,7 @@ DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { return ReadDebuggerResource(uri, debugger_idx); } -llvm::Expected +llvm::Expected DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id) { lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); @@ -170,17 +173,17 @@ DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, debugger_resource.name = debugger_sp->GetInstanceName(); debugger_resource.num_targets = debugger_sp->GetTargetList().GetNumTargets(); - protocol::ResourceContents contents; + lldb_protocol::mcp::ResourceContents contents; contents.uri = uri; contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(debugger_resource)); - protocol::ResourceResult result; + lldb_protocol::mcp::ResourceResult result; result.contents.push_back(contents); return result; } -llvm::Expected +llvm::Expected DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx) { @@ -206,12 +209,12 @@ DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, if (lldb::PlatformSP platform_sp = target_sp->GetPlatform()) target_resource.platform = platform_sp->GetName(); - protocol::ResourceContents contents; + lldb_protocol::mcp::ResourceContents contents; contents.uri = uri; contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(target_resource)); - protocol::ResourceResult result; + lldb_protocol::mcp::ResourceResult result; result.contents.push_back(contents); return result; } diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.h b/lldb/source/Plugins/Protocol/MCP/Resource.h index 5ac38e7e878ff..e2382a74f796b 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.h +++ b/lldb/source/Plugins/Protocol/MCP/Resource.h @@ -9,39 +9,31 @@ #ifndef LLDB_PLUGINS_PROTOCOL_MCP_RESOURCE_H #define LLDB_PLUGINS_PROTOCOL_MCP_RESOURCE_H -#include "Protocol.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Resource.h" #include "lldb/lldb-private.h" #include namespace lldb_private::mcp { -class ResourceProvider { -public: - ResourceProvider() = default; - virtual ~ResourceProvider() = default; - - virtual std::vector GetResources() const = 0; - virtual llvm::Expected - ReadResource(llvm::StringRef uri) const = 0; -}; - -class DebuggerResourceProvider : public ResourceProvider { +class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider { public: using ResourceProvider::ResourceProvider; virtual ~DebuggerResourceProvider() = default; - virtual std::vector GetResources() const override; - virtual llvm::Expected + virtual std::vector + GetResources() const override; + virtual llvm::Expected ReadResource(llvm::StringRef uri) const override; private: - static protocol::Resource GetDebuggerResource(Debugger &debugger); - static protocol::Resource GetTargetResource(size_t target_idx, - Target &target); + static lldb_protocol::mcp::Resource GetDebuggerResource(Debugger &debugger); + static lldb_protocol::mcp::Resource GetTargetResource(size_t target_idx, + Target &target); - static llvm::Expected + static llvm::Expected ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id); - static llvm::Expected + static llvm::Expected ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx); }; diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index bbc19a1e51942..143470702a6fd 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -11,6 +11,8 @@ #include "lldb/Interpreter/CommandInterpreter.h" #include "lldb/Interpreter/CommandReturnObject.h" +using namespace lldb_private; +using namespace lldb_protocol; using namespace lldb_private::mcp; using namespace llvm; @@ -28,33 +30,19 @@ bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, } /// Helper function to create a TextResult from a string output. -static lldb_private::mcp::protocol::TextResult -createTextResult(std::string output, bool is_error = false) { - lldb_private::mcp::protocol::TextResult text_result; +static lldb_protocol::mcp::TextResult createTextResult(std::string output, + bool is_error = false) { + lldb_protocol::mcp::TextResult text_result; text_result.content.emplace_back( - lldb_private::mcp::protocol::TextContent{{std::move(output)}}); + lldb_protocol::mcp::TextContent{{std::move(output)}}); text_result.isError = is_error; return text_result; } } // namespace -Tool::Tool(std::string name, std::string description) - : m_name(std::move(name)), m_description(std::move(description)) {} - -protocol::ToolDefinition Tool::GetDefinition() const { - protocol::ToolDefinition definition; - definition.name = m_name; - definition.description = m_description; - - if (std::optional input_schema = GetSchema()) - definition.inputSchema = *input_schema; - - return definition; -} - -llvm::Expected -CommandTool::Call(const protocol::ToolArguments &args) { +llvm::Expected +CommandTool::Call(const lldb_protocol::mcp::ToolArguments &args) { if (!std::holds_alternative(args)) return createStringError("CommandTool requires arguments"); diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index d0f639adad24e..b7b1756eb38d7 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -9,41 +9,21 @@ #ifndef LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H #define LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H -#include "Protocol.h" #include "lldb/Core/Debugger.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Tool.h" #include "llvm/Support/JSON.h" #include namespace lldb_private::mcp { -class Tool { +class CommandTool : public lldb_protocol::mcp::Tool { public: - Tool(std::string name, std::string description); - virtual ~Tool() = default; - - virtual llvm::Expected - Call(const protocol::ToolArguments &args) = 0; - - virtual std::optional GetSchema() const { - return llvm::json::Object{{"type", "object"}}; - } - - protocol::ToolDefinition GetDefinition() const; - - const std::string &GetName() { return m_name; } - -private: - std::string m_name; - std::string m_description; -}; - -class CommandTool : public mcp::Tool { -public: - using mcp::Tool::Tool; + using lldb_protocol::mcp::Tool::Tool; ~CommandTool() = default; - virtual llvm::Expected - Call(const protocol::ToolArguments &args) override; + virtual llvm::Expected + Call(const lldb_protocol::mcp::ToolArguments &args) override; virtual std::optional GetSchema() const override; }; diff --git a/lldb/source/Protocol/CMakeLists.txt b/lldb/source/Protocol/CMakeLists.txt new file mode 100644 index 0000000000000..93b347d4cc9d8 --- /dev/null +++ b/lldb/source/Protocol/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(MCP) diff --git a/lldb/source/Protocol/MCP/CMakeLists.txt b/lldb/source/Protocol/MCP/CMakeLists.txt new file mode 100644 index 0000000000000..a73e7e6a7cab1 --- /dev/null +++ b/lldb/source/Protocol/MCP/CMakeLists.txt @@ -0,0 +1,12 @@ +add_lldb_library(lldbProtocolMCP NO_PLUGIN_DEPENDENCIES + MCPError.cpp + Protocol.cpp + Server.cpp + Tool.cpp + + LINK_COMPONENTS + Support + LINK_LIBS + lldbUtility +) + diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp similarity index 77% rename from lldb/source/Plugins/Protocol/MCP/MCPError.cpp rename to lldb/source/Protocol/MCP/MCPError.cpp index 659b53a14fe23..e140d11e12cfe 100644 --- a/lldb/source/Plugins/Protocol/MCP/MCPError.cpp +++ b/lldb/source/Protocol/MCP/MCPError.cpp @@ -1,4 +1,4 @@ -//===-- MCPError.cpp ------------------------------------------------------===// +//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#include "MCPError.h" +#include "lldb/Protocol/MCP/MCPError.h" #include "llvm/Support/Error.h" #include "llvm/Support/raw_ostream.h" #include -namespace lldb_private::mcp { +using namespace lldb_protocol::mcp; char MCPError::ID; char UnsupportedURI::ID; @@ -25,10 +25,10 @@ std::error_code MCPError::convertToErrorCode() const { return llvm::inconvertibleErrorCode(); } -protocol::Error MCPError::toProtcolError() const { - protocol::Error error; - error.error.code = m_error_code; - error.error.message = m_message; +lldb_protocol::mcp::Error MCPError::toProtocolError() const { + lldb_protocol::mcp::Error error; + error.code = m_error_code; + error.message = m_message; return error; } @@ -41,5 +41,3 @@ void UnsupportedURI::log(llvm::raw_ostream &OS) const { std::error_code UnsupportedURI::convertToErrorCode() const { return llvm::inconvertibleErrorCode(); } - -} // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp similarity index 69% rename from lldb/source/Plugins/Protocol/MCP/Protocol.cpp rename to lldb/source/Protocol/MCP/Protocol.cpp index 274ba6fac01ec..d9b11bd766686 100644 --- a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Protocol/MCP/Protocol.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#include "Protocol.h" +#include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Support/JSON.h" using namespace llvm; -namespace lldb_private::mcp::protocol { +namespace lldb_protocol::mcp { static bool mapRaw(const json::Value &Params, StringLiteral Prop, std::optional &V, json::Path P) { @@ -26,8 +26,45 @@ static bool mapRaw(const json::Value &Params, StringLiteral Prop, return true; } +static llvm::json::Value toJSON(const Id &Id) { + if (const int64_t *I = std::get_if(&Id)) + return json::Value(*I); + if (const std::string *S = std::get_if(&Id)) + return json::Value(*S); + llvm_unreachable("unexpected type in protocol::Id"); +} + +static bool mapId(const llvm::json::Value &V, StringLiteral Prop, Id &Id, + llvm::json::Path P) { + const auto *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + const auto *E = O->get(Prop); + if (!E) { + P.field(Prop).report("not found"); + return false; + } + + if (auto S = E->getAsString()) { + Id = S->str(); + return true; + } + + if (auto I = E->getAsInteger()) { + Id = *I; + return true; + } + + P.report("expected string or number"); + return false; +} + llvm::json::Value toJSON(const Request &R) { - json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}}; + json::Object Result{ + {"jsonrpc", "2.0"}, {"id", toJSON(R.id)}, {"method", R.method}}; if (R.params) Result.insert({"params", R.params}); return Result; @@ -35,47 +72,75 @@ llvm::json::Value toJSON(const Request &R) { bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); - if (!O || !O.map("id", R.id) || !O.map("method", R.method)) - return false; - return mapRaw(V, "params", R.params, P); + return O && mapId(V, "id", R.id, P) && O.map("method", R.method) && + mapRaw(V, "params", R.params, P); } -llvm::json::Value toJSON(const ErrorInfo &EI) { - llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}}; - if (!EI.data.empty()) - Result.insert({"data", EI.data}); - return Result; -} - -bool fromJSON(const llvm::json::Value &V, ErrorInfo &EI, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("code", EI.code) && O.map("message", EI.message) && - O.mapOptional("data", EI.data); +bool operator==(const Request &a, const Request &b) { + return a.id == b.id && a.method == b.method && a.params == b.params; } llvm::json::Value toJSON(const Error &E) { - return json::Object{{"jsonrpc", "2.0"}, {"id", E.id}, {"error", E.error}}; + llvm::json::Object Result{{"code", E.code}, {"message", E.message}}; + if (E.data) + Result.insert({"data", *E.data}); + return Result; } bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); - return O && O.map("id", E.id) && O.map("error", E.error); + return O && O.map("code", E.code) && O.map("message", E.message) && + mapRaw(V, "data", E.data, P); +} + +bool operator==(const Error &a, const Error &b) { + return a.code == b.code && a.message == b.message && a.data == b.data; } llvm::json::Value toJSON(const Response &R) { - llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}}; - if (R.result) - Result.insert({"result", R.result}); - if (R.error) - Result.insert({"error", R.error}); + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", toJSON(R.id)}}; + + if (const Error *error = std::get_if(&R.result)) + Result.insert({"error", *error}); + if (const json::Value *result = std::get_if(&R.result)) + Result.insert({"result", *result}); return Result; } bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - if (!O || !O.map("id", R.id) || !O.map("error", R.error)) + const json::Object *E = V.getAsObject(); + if (!E) { + P.report("expected object"); + return false; + } + + const json::Value *result = E->get("result"); + const json::Value *raw_error = E->get("error"); + + if (result && raw_error) { + P.report("'result' and 'error' fields are mutually exclusive"); + return false; + } + + if (!result && !raw_error) { + P.report("'result' or 'error' fields are required'"); return false; - return mapRaw(V, "result", R.result, P); + } + + if (result) { + R.result = std::move(*result); + } else { + Error error; + if (!fromJSON(*raw_error, error, P)) + return false; + R.result = std::move(error); + } + + return mapId(V, "id", R.id, P); +} + +bool operator==(const Response &a, const Response &b) { + return a.id == b.id && a.result == b.result; } llvm::json::Value toJSON(const Notification &N) { @@ -97,6 +162,10 @@ bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) { return true; } +bool operator==(const Notification &a, const Notification &b) { + return a.method == b.method && a.params == b.params; +} + llvm::json::Value toJSON(const ToolCapability &TC) { return llvm::json::Object{{"listChanged", TC.listChanged}}; } @@ -228,31 +297,23 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { // A message without an ID is a Notification. if (!O->get("id")) { - protocol::Notification N; + Notification N; if (!fromJSON(V, N, P)) return false; M = std::move(N); return true; } - if (O->get("error")) { - protocol::Error E; - if (!fromJSON(V, E, P)) - return false; - M = std::move(E); - return true; - } - - if (O->get("result")) { - protocol::Response R; + if (O->get("method")) { + Request R; if (!fromJSON(V, R, P)) return false; M = std::move(R); return true; } - if (O->get("method")) { - protocol::Request R; + if (O->get("result") || O->get("error")) { + Response R; if (!fromJSON(V, R, P)) return false; M = std::move(R); @@ -263,4 +324,4 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { return false; } -} // namespace lldb_private::mcp::protocol +} // namespace lldb_protocol::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp new file mode 100644 index 0000000000000..a9c1482e3e378 --- /dev/null +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -0,0 +1,234 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Protocol/MCP/Server.h" +#include "lldb/Protocol/MCP/MCPError.h" + +using namespace lldb_protocol::mcp; +using namespace llvm; + +Server::Server(std::string name, std::string version) + : m_name(std::move(name)), m_version(std::move(version)) { + AddRequestHandlers(); +} + +void Server::AddRequestHandlers() { + AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this, + std::placeholders::_1)); + AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler, + this, std::placeholders::_1)); + AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler, + this, std::placeholders::_1)); +} + +llvm::Expected Server::Handle(Request request) { + auto it = m_request_handlers.find(request.method); + if (it != m_request_handlers.end()) { + llvm::Expected response = it->second(request); + if (!response) + return response; + response->id = request.id; + return *response; + } + + return llvm::make_error( + llvm::formatv("no handler for request: {0}", request.method).str()); +} + +void Server::Handle(Notification notification) { + auto it = m_notification_handlers.find(notification.method); + if (it != m_notification_handlers.end()) { + it->second(notification); + return; + } +} + +llvm::Expected> +Server::HandleData(llvm::StringRef data) { + auto message = llvm::json::parse(/*JSON=*/data); + if (!message) + return message.takeError(); + + if (const Request *request = std::get_if(&(*message))) { + llvm::Expected response = Handle(*request); + + // Handle failures by converting them into an Error message. + if (!response) { + Error protocol_error; + llvm::handleAllErrors( + response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); + }); + Response error_response; + error_response.id = request->id; + error_response.result = std::move(protocol_error); + return error_response; + } + + return *response; + } + + if (const Notification *notification = + std::get_if(&(*message))) { + Handle(*notification); + return std::nullopt; + } + + if (std::get_if(&(*message))) + return llvm::createStringError("unexpected MCP message: response"); + + llvm_unreachable("all message types handled"); +} + +void Server::AddTool(std::unique_ptr tool) { + std::lock_guard guard(m_mutex); + + if (!tool) + return; + m_tools[tool->GetName()] = std::move(tool); +} + +void Server::AddResourceProvider( + std::unique_ptr resource_provider) { + std::lock_guard guard(m_mutex); + + if (!resource_provider) + return; + m_resource_providers.push_back(std::move(resource_provider)); +} + +void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { + std::lock_guard guard(m_mutex); + m_request_handlers[method] = std::move(handler); +} + +void Server::AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler) { + std::lock_guard guard(m_mutex); + m_notification_handlers[method] = std::move(handler); +} + +llvm::Expected Server::InitializeHandler(const Request &request) { + Response response; + response.result = llvm::json::Object{ + {"protocolVersion", mcp::kProtocolVersion}, + {"capabilities", GetCapabilities()}, + {"serverInfo", + llvm::json::Object{{"name", m_name}, {"version", m_version}}}}; + return response; +} + +llvm::Expected Server::ToolsListHandler(const Request &request) { + Response response; + + llvm::json::Array tools; + for (const auto &tool : m_tools) + tools.emplace_back(toJSON(tool.second->GetDefinition())); + + response.result = llvm::json::Object{{"tools", std::move(tools)}}; + + return response; +} + +llvm::Expected Server::ToolsCallHandler(const Request &request) { + Response response; + + if (!request.params) + return llvm::createStringError("no tool parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no tool parameters"); + + const json::Value *name = param_obj->get("name"); + if (!name) + return llvm::createStringError("no tool name"); + + llvm::StringRef tool_name = name->getAsString().value_or(""); + if (tool_name.empty()) + return llvm::createStringError("no tool name"); + + auto it = m_tools.find(tool_name); + if (it == m_tools.end()) + return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); + + ToolArguments tool_args; + if (const json::Value *args = param_obj->get("arguments")) + tool_args = *args; + + llvm::Expected text_result = it->second->Call(tool_args); + if (!text_result) + return text_result.takeError(); + + response.result = toJSON(*text_result); + + return response; +} + +llvm::Expected Server::ResourcesListHandler(const Request &request) { + Response response; + + llvm::json::Array resources; + + std::lock_guard guard(m_mutex); + for (std::unique_ptr &resource_provider_up : + m_resource_providers) { + for (const Resource &resource : resource_provider_up->GetResources()) + resources.push_back(resource); + } + response.result = llvm::json::Object{{"resources", std::move(resources)}}; + + return response; +} + +llvm::Expected Server::ResourcesReadHandler(const Request &request) { + Response response; + + if (!request.params) + return llvm::createStringError("no resource parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no resource parameters"); + + const json::Value *uri = param_obj->get("uri"); + if (!uri) + return llvm::createStringError("no resource uri"); + + llvm::StringRef uri_str = uri->getAsString().value_or(""); + if (uri_str.empty()) + return llvm::createStringError("no resource uri"); + + std::lock_guard guard(m_mutex); + for (std::unique_ptr &resource_provider_up : + m_resource_providers) { + llvm::Expected result = + resource_provider_up->ReadResource(uri_str); + if (result.errorIsA()) { + llvm::consumeError(result.takeError()); + continue; + } + if (!result) + return result.takeError(); + + Response response; + response.result = std::move(*result); + return response; + } + + return make_error( + llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + MCPError::kResourceNotFound); +} diff --git a/lldb/source/Protocol/MCP/Tool.cpp b/lldb/source/Protocol/MCP/Tool.cpp new file mode 100644 index 0000000000000..8e01f2bd5908b --- /dev/null +++ b/lldb/source/Protocol/MCP/Tool.cpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Protocol/MCP/Tool.h" + +using namespace lldb_protocol::mcp; + +Tool::Tool(std::string name, std::string description) + : m_name(std::move(name)), m_description(std::move(description)) {} + +lldb_protocol::mcp::ToolDefinition Tool::GetDefinition() const { + lldb_protocol::mcp::ToolDefinition definition; + definition.name = m_name; + definition.description = m_description; + + if (std::optional input_schema = GetSchema()) + definition.inputSchema = *input_schema; + + return definition; +} diff --git a/lldb/unittests/CMakeLists.txt b/lldb/unittests/CMakeLists.txt index 48fcfc40b73ab..8a20839a37469 100644 --- a/lldb/unittests/CMakeLists.txt +++ b/lldb/unittests/CMakeLists.txt @@ -67,26 +67,27 @@ add_subdirectory(Disassembler) add_subdirectory(Editline) add_subdirectory(Expression) add_subdirectory(Host) -add_subdirectory(Interpreter) add_subdirectory(Instruction) +add_subdirectory(Interpreter) add_subdirectory(Language) add_subdirectory(ObjectFile) add_subdirectory(Platform) add_subdirectory(Process) +add_subdirectory(Protocol) add_subdirectory(ScriptInterpreter) add_subdirectory(Signals) add_subdirectory(StackID) add_subdirectory(Symbol) add_subdirectory(SymbolFile) add_subdirectory(Target) -add_subdirectory(tools) +add_subdirectory(Thread) add_subdirectory(UnwindAssembly) add_subdirectory(Utility) -add_subdirectory(Thread) add_subdirectory(ValueObject) +add_subdirectory(tools) if(LLDB_ENABLE_PROTOCOL_SERVERS) - add_subdirectory(Protocol) + add_subdirectory(ProtocolServer) endif() if(LLDB_CAN_USE_DEBUGSERVER AND LLDB_TOOL_DEBUGSERVER_BUILD AND NOT LLDB_USE_SYSTEM_DEBUGSERVER) diff --git a/lldb/unittests/Protocol/CMakeLists.txt b/lldb/unittests/Protocol/CMakeLists.txt index 801662b0544d8..bbac69611e011 100644 --- a/lldb/unittests/Protocol/CMakeLists.txt +++ b/lldb/unittests/Protocol/CMakeLists.txt @@ -1,12 +1,9 @@ add_lldb_unittest(ProtocolTests ProtocolMCPTest.cpp - ProtocolMCPServerTest.cpp LINK_LIBS - lldbCore - lldbUtility lldbHost - lldbPluginPlatformMacOSX - lldbPluginProtocolServerMCP + lldbProtocolMCP + lldbUtility LLVMTestingSupport ) diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index ce8120cbfe9b9..ea19922522ffe 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -6,14 +6,14 @@ // //===----------------------------------------------------------------------===// -#include "Plugins/Protocol/MCP/Protocol.h" #include "TestingSupport/TestUtilities.h" +#include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" using namespace lldb; using namespace lldb_private; -using namespace lldb_private::mcp::protocol; +using namespace lldb_protocol::mcp; TEST(ProtocolMCPTest, Request) { Request request; @@ -149,9 +149,7 @@ TEST(ProtocolMCPTest, MessageWithRequest) { const Request &deserialized_request = std::get(*deserialized_message); - EXPECT_EQ(request.id, deserialized_request.id); - EXPECT_EQ(request.method, deserialized_request.method); - EXPECT_EQ(request.params, deserialized_request.params); + EXPECT_EQ(request, deserialized_request); } TEST(ProtocolMCPTest, MessageWithResponse) { @@ -168,8 +166,7 @@ TEST(ProtocolMCPTest, MessageWithResponse) { const Response &deserialized_response = std::get(*deserialized_message); - EXPECT_EQ(response.id, deserialized_response.id); - EXPECT_EQ(response.result, deserialized_response.result); + EXPECT_EQ(response, deserialized_response); } TEST(ProtocolMCPTest, MessageWithNotification) { @@ -186,49 +183,28 @@ TEST(ProtocolMCPTest, MessageWithNotification) { const Notification &deserialized_notification = std::get(*deserialized_message); - EXPECT_EQ(notification.method, deserialized_notification.method); - EXPECT_EQ(notification.params, deserialized_notification.params); + EXPECT_EQ(notification, deserialized_notification); } -TEST(ProtocolMCPTest, MessageWithError) { - ErrorInfo error_info; - error_info.code = -32603; - error_info.message = "Internal error"; - +TEST(ProtocolMCPTest, MessageWithErrorResponse) { Error error; - error.id = 3; - error.error = error_info; + error.code = -32603; + error.message = "Internal error"; + + Response error_response; + error_response.id = 3; + error_response.result = error; - Message message = error; + Message message = error_response; llvm::Expected deserialized_message = roundtripJSON(message); ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); - ASSERT_TRUE(std::holds_alternative(*deserialized_message)); - const Error &deserialized_error = std::get(*deserialized_message); - - EXPECT_EQ(error.id, deserialized_error.id); - EXPECT_EQ(error.error.code, deserialized_error.error.code); - EXPECT_EQ(error.error.message, deserialized_error.error.message); -} - -TEST(ProtocolMCPTest, ResponseWithError) { - ErrorInfo error_info; - error_info.code = -32700; - error_info.message = "Parse error"; - - Response response; - response.id = 4; - response.error = error_info; - - llvm::Expected deserialized_response = roundtripJSON(response); - ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded()); + ASSERT_TRUE(std::holds_alternative(*deserialized_message)); + const Response &deserialized_error = + std::get(*deserialized_message); - EXPECT_EQ(response.id, deserialized_response->id); - EXPECT_FALSE(deserialized_response->result.has_value()); - ASSERT_TRUE(deserialized_response->error.has_value()); - EXPECT_EQ(response.error->code, deserialized_response->error->code); - EXPECT_EQ(response.error->message, deserialized_response->error->message); + EXPECT_EQ(error_response, deserialized_error); } TEST(ProtocolMCPTest, Resource) { diff --git a/lldb/unittests/ProtocolServer/CMakeLists.txt b/lldb/unittests/ProtocolServer/CMakeLists.txt new file mode 100644 index 0000000000000..6117430b35bf0 --- /dev/null +++ b/lldb/unittests/ProtocolServer/CMakeLists.txt @@ -0,0 +1,11 @@ +add_lldb_unittest(ProtocolServerTests + ProtocolMCPServerTest.cpp + + LINK_LIBS + lldbCore + lldbUtility + lldbHost + lldbPluginPlatformMacOSX + lldbPluginProtocolServerMCP + LLVMTestingSupport + ) diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp similarity index 90% rename from lldb/unittests/Protocol/ProtocolMCPServerTest.cpp rename to lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index 51eb6275e811a..7890d3f69b9e1 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -7,22 +7,24 @@ //===----------------------------------------------------------------------===// #include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" -#include "Plugins/Protocol/MCP/MCPError.h" #include "Plugins/Protocol/MCP/ProtocolServerMCP.h" #include "TestingSupport/Host/SocketTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" +#include "lldb/Core/Debugger.h" #include "lldb/Core/ProtocolServer.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" #include "lldb/Host/JSONTransport.h" #include "lldb/Host/Socket.h" +#include "lldb/Protocol/MCP/MCPError.h" +#include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" using namespace llvm; using namespace lldb; using namespace lldb_private; -using namespace lldb_private::mcp::protocol; +using namespace lldb_protocol::mcp; namespace { class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { @@ -43,12 +45,11 @@ class TestJSONTransport : public lldb_private::JSONRPCTransport { }; /// Test tool that returns it argument as text. -class TestTool : public mcp::Tool { +class TestTool : public Tool { public: - using mcp::Tool::Tool; + using Tool::Tool; - virtual llvm::Expected - Call(const ToolArguments &args) override { + virtual llvm::Expected Call(const ToolArguments &args) override { std::string argument; if (const json::Object *args_obj = std::get(args).getAsObject()) { @@ -57,14 +58,14 @@ class TestTool : public mcp::Tool { } } - mcp::protocol::TextResult text_result; - text_result.content.emplace_back(mcp::protocol::TextContent{{argument}}); + TextResult text_result; + text_result.content.emplace_back(TextContent{{argument}}); return text_result; } }; -class TestResourceProvider : public mcp::ResourceProvider { - using mcp::ResourceProvider::ResourceProvider; +class TestResourceProvider : public ResourceProvider { + using ResourceProvider::ResourceProvider; virtual std::vector GetResources() const override { std::vector resources; @@ -82,7 +83,7 @@ class TestResourceProvider : public mcp::ResourceProvider { virtual llvm::Expected ReadResource(llvm::StringRef uri) const override { if (uri != "lldb://foo/bar") - return llvm::make_error(uri.str()); + return llvm::make_error(uri.str()); ResourceContents contents; contents.uri = "lldb://foo/bar"; @@ -96,25 +97,23 @@ class TestResourceProvider : public mcp::ResourceProvider { }; /// Test tool that returns an error. -class ErrorTool : public mcp::Tool { +class ErrorTool : public Tool { public: - using mcp::Tool::Tool; + using Tool::Tool; - virtual llvm::Expected - Call(const ToolArguments &args) override { + virtual llvm::Expected Call(const ToolArguments &args) override { return llvm::createStringError("error"); } }; /// Test tool that fails but doesn't return an error. -class FailTool : public mcp::Tool { +class FailTool : public Tool { public: - using mcp::Tool::Tool; + using Tool::Tool; - virtual llvm::Expected - Call(const ToolArguments &args) override { - mcp::protocol::TextResult text_result; - text_result.content.emplace_back(mcp::protocol::TextContent{{"failed"}}); + virtual llvm::Expected Call(const ToolArguments &args) override { + TextResult text_result; + text_result.content.emplace_back(TextContent{{"failed"}}); text_result.isError = true; return text_result; } @@ -179,7 +178,7 @@ class ProtocolServerMCPTest : public ::testing::Test { } // namespace -TEST_F(ProtocolServerMCPTest, Intialization) { +TEST_F(ProtocolServerMCPTest, Initialization) { llvm::StringLiteral request = R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; llvm::StringLiteral response = @@ -309,8 +308,7 @@ TEST_F(ProtocolServerMCPTest, NotificationInitialized) { std::mutex mutex; m_server_up->AddNotificationHandler( - "notifications/initialized", - [&](const mcp::protocol::Notification ¬ification) { + "notifications/initialized", [&](const Notification ¬ification) { { std::lock_guard lock(mutex); handler_called = true; diff --git a/lldb/unittests/TestingSupport/TestUtilities.h b/lldb/unittests/TestingSupport/TestUtilities.h index db62881872fef..cc93a68a6a431 100644 --- a/lldb/unittests/TestingSupport/TestUtilities.h +++ b/lldb/unittests/TestingSupport/TestUtilities.h @@ -11,11 +11,11 @@ #include "lldb/Core/ModuleSpec.h" #include "lldb/Utility/DataBuffer.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" -#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" #include #define ASSERT_NO_ERROR(x) \ @@ -61,12 +61,10 @@ class TestFile { }; template static llvm::Expected roundtripJSON(const T &input) { - llvm::json::Value value = toJSON(input); - llvm::json::Path::Root root; - T output; - if (!fromJSON(value, output, root)) - return root.getError(); - return output; + std::string encoded; + llvm::raw_string_ostream OS(encoded); + OS << toJSON(input); + return llvm::json::parse(encoded); } } // namespace lldb_private