diff --git a/lldb/include/lldb/API/SBSymbol.h b/lldb/include/lldb/API/SBSymbol.h index 94521881f82f9..a93bc7a7ae074 100644 --- a/lldb/include/lldb/API/SBSymbol.h +++ b/lldb/include/lldb/API/SBSymbol.h @@ -85,6 +85,12 @@ class LLDB_API SBSymbol { SymbolType GetType(); + /// Get the ID of this symbol, usually the original symbol table index. + /// + /// \returns + /// Returns the ID of this symbol. + uint32_t GetID(); + bool operator==(const lldb::SBSymbol &rhs) const; bool operator!=(const lldb::SBSymbol &rhs) const; @@ -99,6 +105,15 @@ class LLDB_API SBSymbol { // other than the actual symbol table itself in the object file. bool IsSynthetic(); + /// Returns true if the symbol is a debug symbol. + bool IsDebug(); + + /// Get the string representation of a symbol type. + static const char *GetTypeAsString(lldb::SymbolType symbol_type); + + /// Get the symbol type from a string representation. + static lldb::SymbolType GetTypeFromString(const char *str); + protected: lldb_private::Symbol *get(); diff --git a/lldb/include/lldb/API/SBTarget.h b/lldb/include/lldb/API/SBTarget.h index 4381781383075..f9e0d2681cdee 100644 --- a/lldb/include/lldb/API/SBTarget.h +++ b/lldb/include/lldb/API/SBTarget.h @@ -324,6 +324,16 @@ class LLDB_API SBTarget { lldb::SBModule FindModule(const lldb::SBFileSpec &file_spec); + /// Find a module with the given module specification. + /// + /// \param[in] module_spec + /// A lldb::SBModuleSpec object that contains module specification. + /// + /// \return + /// A lldb::SBModule object that represents the found module, or an + /// invalid SBModule object if no module was found. + lldb::SBModule FindModule(const lldb::SBModuleSpec &module_spec); + /// Find compile units related to *this target and passed source /// file. /// diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 4087cdf2b42f7..0be60a8f3f96a 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -13,125 +13,284 @@ #ifndef LLDB_HOST_JSONTRANSPORT_H #define LLDB_HOST_JSONTRANSPORT_H +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" +#include "lldb/Utility/IOObject.h" +#include "lldb/Utility/Status.h" #include "lldb/lldb-forward.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" -#include +#include "llvm/Support/raw_ostream.h" +#include #include +#include +#include namespace lldb_private { -class TransportEOFError : public llvm::ErrorInfo { +class TransportUnhandledContentsError + : public llvm::ErrorInfo { public: static char ID; - TransportEOFError() = default; + explicit TransportUnhandledContentsError(std::string unhandled_contents); - void log(llvm::raw_ostream &OS) const override { - OS << "transport end of file reached"; - } - std::error_code convertToErrorCode() const override { - return llvm::inconvertibleErrorCode(); + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + + const std::string &getUnhandledContents() const { + return m_unhandled_contents; } + +private: + std::string m_unhandled_contents; }; -class TransportTimeoutError : public llvm::ErrorInfo { +/// A transport is responsible for maintaining the connection to a client +/// application, and reading/writing structured messages to it. +/// +/// Transports have limited thread safety requirements: +/// - Messages will not be sent concurrently. +/// - Messages MAY be sent while Run() is reading, or its callback is active. +template class Transport { public: - static char ID; + using Message = std::variant; - TransportTimeoutError() = default; + virtual ~Transport() = default; - void log(llvm::raw_ostream &OS) const override { - OS << "transport operation timed out"; - } - std::error_code convertToErrorCode() const override { - return std::make_error_code(std::errc::timed_out); - } -}; + /// Sends an event, a message that does not require a response. + virtual llvm::Error Send(const Evt &) = 0; + /// Sends a request, a message that expects a response. + virtual llvm::Error Send(const Req &) = 0; + /// Sends a response to a specific request. + virtual llvm::Error Send(const Resp &) = 0; -class TransportInvalidError : public llvm::ErrorInfo { -public: - static char ID; + /// Implemented to handle incoming messages. (See Run() below). + class MessageHandler { + public: + virtual ~MessageHandler() = default; + /// Called when an event is received. + virtual void Received(const Evt &) = 0; + /// Called when a request is received. + virtual void Received(const Req &) = 0; + /// Called when a response is received. + virtual void Received(const Resp &) = 0; - TransportInvalidError() = default; + /// Called when an error occurs while reading from the transport. + /// + /// NOTE: This does *NOT* indicate that a specific request failed, but that + /// there was an error in the underlying transport. + virtual void OnError(llvm::Error) = 0; - void log(llvm::raw_ostream &OS) const override { - OS << "transport IO object invalid"; - } - std::error_code convertToErrorCode() const override { - return std::make_error_code(std::errc::not_connected); + /// Called on EOF or client disconnect. + virtual void OnClosed() = 0; + }; + + using MessageHandlerSP = std::shared_ptr; + + /// RegisterMessageHandler registers the Transport with the given MainLoop and + /// handles any incoming messages using the given MessageHandler. + /// + /// If an unexpected error occurs, the MainLoop will be terminated and a log + /// message will include additional information about the termination reason. + virtual llvm::Expected + RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0; + +protected: + template inline auto Logv(const char *Fmt, Ts &&...Vals) { + Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); } + virtual void Log(llvm::StringRef message) = 0; }; -/// A transport class that uses JSON for communication. -class JSONTransport { +/// A JSONTransport will encode and decode messages using JSON. +template +class JSONTransport : public Transport { public: - JSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output); - virtual ~JSONTransport() = default; - - /// Transport is not copyable. - /// @{ - JSONTransport(const JSONTransport &rhs) = delete; - void operator=(const JSONTransport &rhs) = delete; - /// @} - - /// Writes a message to the output stream. - template llvm::Error Write(const T &t) { - const std::string message = llvm::formatv("{0}", toJSON(t)).str(); - return WriteImpl(message); - } + using Transport::Transport; + using MessageHandler = typename Transport::MessageHandler; - /// Reads the next message from the input stream. - template - llvm::Expected Read(const std::chrono::microseconds &timeout) { - llvm::Expected message = ReadImpl(timeout); - if (!message) - return message.takeError(); - return llvm::json::parse(/*JSON=*/*message); + JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + : m_in(in), m_out(out) {} + + llvm::Error Send(const Evt &evt) override { return Write(evt); } + llvm::Error Send(const Req &req) override { return Write(req); } + llvm::Error Send(const Resp &resp) override { return Write(resp); } + + llvm::Expected + RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override { + Status status; + MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject( + m_in, + std::bind(&JSONTransport::OnRead, this, std::placeholders::_1, + std::ref(handler)), + status); + if (status.Fail()) { + return status.takeError(); + } + return read_handle; } + /// Public for testing purposes, otherwise this should be an implementation + /// detail. + static constexpr size_t kReadBufferSize = 1024; + protected: - virtual void Log(llvm::StringRef message); + virtual llvm::Expected> Parse() = 0; + virtual std::string Encode(const llvm::json::Value &message) = 0; + llvm::Error Write(const llvm::json::Value &message) { + this->Logv("<-- {0}", message); + std::string output = Encode(message); + size_t bytes_written = output.size(); + return m_out->Write(output.data(), bytes_written).takeError(); + } + + llvm::SmallString m_buffer; - virtual llvm::Error WriteImpl(const std::string &message) = 0; - virtual llvm::Expected - ReadImpl(const std::chrono::microseconds &timeout) = 0; +private: + void OnRead(MainLoopBase &loop, MessageHandler &handler) { + char buf[kReadBufferSize]; + size_t num_bytes = sizeof(buf); + if (Status status = m_in->Read(buf, num_bytes); status.Fail()) { + handler.OnError(status.takeError()); + return; + } - lldb::IOObjectSP m_input; - lldb::IOObjectSP m_output; + if (num_bytes) + m_buffer.append(llvm::StringRef(buf, num_bytes)); + + // If the buffer has contents, try parsing any pending messages. + if (!m_buffer.empty()) { + llvm::Expected> raw_messages = Parse(); + if (llvm::Error error = raw_messages.takeError()) { + handler.OnError(std::move(error)); + return; + } + + for (const std::string &raw_message : *raw_messages) { + llvm::Expected::Message> message = + llvm::json::parse::Message>( + raw_message); + if (!message) { + handler.OnError(message.takeError()); + return; + } + + std::visit([&handler](auto &&msg) { handler.Received(msg); }, *message); + } + } + + // Check if we reached EOF. + if (num_bytes == 0) { + // EOF reached, but there may still be unhandled contents in the buffer. + if (!m_buffer.empty()) + handler.OnError(llvm::make_error( + std::string(m_buffer.str()))); + handler.OnClosed(); + } + } + + lldb::IOObjectSP m_in; + lldb::IOObjectSP m_out; }; /// A transport class for JSON with a HTTP header. -class HTTPDelimitedJSONTransport : public JSONTransport { +template +class HTTPDelimitedJSONTransport : public JSONTransport { public: - HTTPDelimitedJSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) - : JSONTransport(input, output) {} - virtual ~HTTPDelimitedJSONTransport() = default; + using JSONTransport::JSONTransport; protected: - virtual llvm::Error WriteImpl(const std::string &message) override; - virtual llvm::Expected - ReadImpl(const std::chrono::microseconds &timeout) override; - - // FIXME: Support any header. - static constexpr llvm::StringLiteral kHeaderContentLength = - "Content-Length: "; - static constexpr llvm::StringLiteral kHeaderSeparator = "\r\n\r\n"; + /// Encodes messages based on + /// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol + std::string Encode(const llvm::json::Value &message) override { + std::string output; + std::string raw_message = llvm::formatv("{0}", message).str(); + llvm::raw_string_ostream OS(output); + OS << kHeaderContentLength << kHeaderFieldSeparator << ' ' + << std::to_string(raw_message.size()) << kEndOfHeader << raw_message; + return output; + } + + /// Parses messages based on + /// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol + llvm::Expected> Parse() override { + std::vector messages; + llvm::StringRef buffer = this->m_buffer; + while (buffer.contains(kEndOfHeader)) { + auto [headers, rest] = buffer.split(kEndOfHeader); + size_t content_length = 0; + // HTTP Headers are formatted like ` ':' []`. + for (const llvm::StringRef &header : + llvm::split(headers, kHeaderSeparator)) { + auto [key, value] = header.split(kHeaderFieldSeparator); + // 'Content-Length' is the only meaningful key at the moment. Others are + // ignored. + if (!key.equals_insensitive(kHeaderContentLength)) + continue; + + value = value.trim(); + if (!llvm::to_integer(value, content_length, 10)) { + // Clear the buffer to avoid re-parsing this malformed message. + this->m_buffer.clear(); + return llvm::createStringError(std::errc::invalid_argument, + "invalid content length: %s", + value.str().c_str()); + } + } + + // Check if we have enough data. + if (content_length > rest.size()) + break; + + llvm::StringRef body = rest.take_front(content_length); + buffer = rest.drop_front(content_length); + messages.emplace_back(body.str()); + this->Logv("--> {0}", body); + } + + // Store the remainder of the buffer for the next read callback. + this->m_buffer = buffer.str(); + + return std::move(messages); + } + + static constexpr llvm::StringLiteral kHeaderContentLength = "Content-Length"; + static constexpr llvm::StringLiteral kHeaderFieldSeparator = ":"; + static constexpr llvm::StringLiteral kHeaderSeparator = "\r\n"; + static constexpr llvm::StringLiteral kEndOfHeader = "\r\n\r\n"; }; /// A transport class for JSON RPC. -class JSONRPCTransport : public JSONTransport { +template +class JSONRPCTransport : public JSONTransport { public: - JSONRPCTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) - : JSONTransport(input, output) {} - virtual ~JSONRPCTransport() = default; + using JSONTransport::JSONTransport; protected: - virtual llvm::Error WriteImpl(const std::string &message) override; - virtual llvm::Expected - ReadImpl(const std::chrono::microseconds &timeout) override; + std::string Encode(const llvm::json::Value &message) override { + return llvm::formatv("{0}{1}", message, kMessageSeparator).str(); + } + + llvm::Expected> Parse() override { + std::vector messages; + llvm::StringRef buf = this->m_buffer; + while (buf.contains(kMessageSeparator)) { + auto [raw_json, rest] = buf.split(kMessageSeparator); + buf = rest; + messages.emplace_back(raw_json.str()); + this->Logv("--> {0}", raw_json); + } + + // Store the remainder of the buffer for the next read callback. + this->m_buffer = buf.str(); + + return messages; + } static constexpr llvm::StringLiteral kMessageSeparator = "\n"; }; diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 49f9490221755..6e1ffcbe1f3e3 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -18,6 +18,7 @@ #include #include #include +#include namespace lldb_protocol::mcp { @@ -38,11 +39,24 @@ struct Request { /// The method's params. std::optional params; }; - llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); bool operator==(const Request &, const Request &); +enum ErrorCode : signed { + /// Invalid JSON was received by the server. An error occurred on the server + /// while parsing the JSON text. + eErrorCodeParseError = -32700, + /// The JSON sent is not a valid Request object. + eErrorCodeInvalidRequest = -32600, + /// The method does not exist / is not available. + eErrorCodeMethodNotFound = -32601, + /// Invalid method parameter(s). + eErrorCodeInvalidParams = -32602, + /// Internal JSON-RPC error. + eErrorCodeInternalError = -32603, +}; + struct Error { /// The error type that occurred. int64_t code = 0; @@ -52,9 +66,8 @@ struct Error { /// Additional information about the error. The value of this member is /// defined by the sender (e.g. detailed error information, nested errors /// etc.). - std::optional data; + std::optional data = std::nullopt; }; - llvm::json::Value toJSON(const Error &); bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path); bool operator==(const Error &, const Error &); @@ -67,7 +80,6 @@ struct Response { /// response. std::variant result; }; - llvm::json::Value toJSON(const Response &); bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); bool operator==(const Response &, const Response &); @@ -79,7 +91,6 @@ struct Notification { /// The notification's params. std::optional params; }; - llvm::json::Value toJSON(const Notification &); bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path); bool operator==(const Notification &, const Notification &); @@ -90,45 +101,9 @@ using Message = std::variant; // not force it to be checked early here. static_assert(std::is_convertible_v, "Message is not convertible to itself"); - bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); llvm::json::Value toJSON(const Message &); -struct ToolCapability { - /// Whether this server supports notifications for changes to the tool list. - bool listChanged = false; -}; - -llvm::json::Value toJSON(const ToolCapability &); -bool fromJSON(const llvm::json::Value &, ToolCapability &, llvm::json::Path); - -struct ResourceCapability { - /// Whether this server supports notifications for changes to the resources - /// list. - bool listChanged = false; - - /// Whether subscriptions are supported. - bool subscribe = false; -}; - -llvm::json::Value toJSON(const ResourceCapability &); -bool fromJSON(const llvm::json::Value &, ResourceCapability &, - llvm::json::Path); - -/// Capabilities that a server may support. Known capabilities are defined here, -/// in this schema, but this is not a closed set: any server can define its own, -/// additional capabilities. -struct Capabilities { - /// Tool capabilities of the server. - ToolCapability tools; - - /// Resource capabilities of the server. - ResourceCapability resources; -}; - -llvm::json::Value toJSON(const Capabilities &); -bool fromJSON(const llvm::json::Value &, Capabilities &, llvm::json::Path); - /// A known resource that the server is capable of reading. struct Resource { /// The URI of this resource. @@ -138,17 +113,25 @@ struct Resource { std::string name; /// A description of what this resource represents. - std::string description; + std::string description = ""; /// The MIME type of this resource, if known. - std::string mimeType; + std::string mimeType = ""; }; llvm::json::Value toJSON(const Resource &); bool fromJSON(const llvm::json::Value &, Resource &, llvm::json::Path); +/// The server’s response to a resources/list request from the client. +struct ListResourcesResult { + std::vector resources; +}; +llvm::json::Value toJSON(const ListResourcesResult &); +bool fromJSON(const llvm::json::Value &, ListResourcesResult &, + llvm::json::Path); + /// The contents of a specific resource or sub-resource. -struct ResourceContents { +struct TextResourceContents { /// The URI of this resource. std::string uri; @@ -160,34 +143,37 @@ struct ResourceContents { std::string mimeType; }; -llvm::json::Value toJSON(const ResourceContents &); -bool fromJSON(const llvm::json::Value &, ResourceContents &, llvm::json::Path); +llvm::json::Value toJSON(const TextResourceContents &); +bool fromJSON(const llvm::json::Value &, TextResourceContents &, + llvm::json::Path); -/// The server's response to a resources/read request from the client. -struct ResourceResult { - std::vector contents; +/// Sent from the client to the server, to read a specific resource URI. +struct ReadResourceParams { + /// The URI of the resource to read. The URI can use any protocol; it is up to + /// the server how to interpret it. + std::string uri; }; +llvm::json::Value toJSON(const ReadResourceParams &); +bool fromJSON(const llvm::json::Value &, ReadResourceParams &, + llvm::json::Path); -llvm::json::Value toJSON(const ResourceResult &); -bool fromJSON(const llvm::json::Value &, ResourceResult &, llvm::json::Path); +/// The server's response to a resources/read request from the client. +struct ReadResourceResult { + std::vector contents; +}; +llvm::json::Value toJSON(const ReadResourceResult &); +bool fromJSON(const llvm::json::Value &, ReadResourceResult &, + llvm::json::Path); /// Text provided to or from an LLM. struct TextContent { /// The text content of the message. std::string text; }; - llvm::json::Value toJSON(const TextContent &); bool fromJSON(const llvm::json::Value &, TextContent &, llvm::json::Path); -struct TextResult { - std::vector content; - bool isError = false; -}; - -llvm::json::Value toJSON(const TextResult &); -bool fromJSON(const llvm::json::Value &, TextResult &, llvm::json::Path); - +/// Definition for a tool the client can call. struct ToolDefinition { /// Unique identifier for the tool. std::string name; @@ -198,12 +184,144 @@ struct ToolDefinition { // JSON Schema for the tool's parameters. std::optional inputSchema; }; - llvm::json::Value toJSON(const ToolDefinition &); bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path); using ToolArguments = std::variant; +/// Describes the name and version of an MCP implementation, with an optional +/// title for UI representation. +struct Implementation { + /// Intended for programmatic or logical use, but used as a display name in + /// past specs or fallback (if title isn’t present). + std::string name; + + std::string version; + + /// Intended for UI and end-user contexts — optimized to be human-readable and + /// easily understood, even by those unfamiliar with domain-specific + /// terminology. + /// + /// If not provided, the name should be used for display (except for Tool, + /// where annotations.title should be given precedence over using name, if + /// present). + std::string title = ""; +}; +llvm::json::Value toJSON(const Implementation &); +bool fromJSON(const llvm::json::Value &, Implementation &, llvm::json::Path); + +/// Capabilities a client may support. Known capabilities are defined here, in +/// this schema, but this is not a closed set: any client can define its own, +/// additional capabilities. +struct ClientCapabilities {}; +llvm::json::Value toJSON(const ClientCapabilities &); +bool fromJSON(const llvm::json::Value &, ClientCapabilities &, + llvm::json::Path); + +/// Capabilities that a server may support. Known capabilities are defined here, +/// in this schema, but this is not a closed set: any server can define its own, +/// additional capabilities. +struct ServerCapabilities { + bool supportsToolsList = false; + bool supportsResourcesList = false; + bool supportsResourcesSubscribe = false; + + /// Utilities. + bool supportsCompletions = false; + bool supportsLogging = false; +}; +llvm::json::Value toJSON(const ServerCapabilities &); +bool fromJSON(const llvm::json::Value &, ServerCapabilities &, + llvm::json::Path); + +/// Initialization + +/// This request is sent from the client to the server when it first connects, +/// asking it to begin initialization. +struct InitializeParams { + /// The latest version of the Model Context Protocol that the client supports. + /// The client MAY decide to support older versions as well. + std::string protocolVersion; + + ClientCapabilities capabilities; + + Implementation clientInfo; +}; +llvm::json::Value toJSON(const InitializeParams &); +bool fromJSON(const llvm::json::Value &, InitializeParams &, llvm::json::Path); + +/// After receiving an initialize request from the client, the server sends this +/// response. +struct InitializeResult { + /// The version of the Model Context Protocol that the server wants to use. + /// This may not match the version that the client requested. If the client + /// cannot support this version, it MUST disconnect. + std::string protocolVersion; + + ServerCapabilities capabilities; + Implementation serverInfo; + + /// Instructions describing how to use the server and its features. + /// + /// This can be used by clients to improve the LLM's understanding of + /// available tools, resources, etc. It can be thought of like a "hint" to the + /// model. For example, this information MAY be added to the system prompt. + std::string instructions = ""; +}; +llvm::json::Value toJSON(const InitializeResult &); +bool fromJSON(const llvm::json::Value &, InitializeResult &, llvm::json::Path); + +/// Special case parameter or result that has no value. +using Void = std::monostate; +llvm::json::Value toJSON(const Void &); +bool fromJSON(const llvm::json::Value &, Void &, llvm::json::Path); + +/// The server's response to a `tools/list` request from the client. +struct ListToolsResult { + std::vector tools; +}; +llvm::json::Value toJSON(const ListToolsResult &); +bool fromJSON(const llvm::json::Value &, ListToolsResult &, llvm::json::Path); + +/// Supported content types, currently only TextContent, but the spec includes +/// additional content types. +using ContentBlock = TextContent; + +/// Used by the client to invoke a tool provided by the server. +struct CallToolParams { + std::string name; + std::optional arguments; +}; +llvm::json::Value toJSON(const CallToolParams &); +bool fromJSON(const llvm::json::Value &, CallToolParams &, llvm::json::Path); + +/// The server’s response to a tool call. +struct CallToolResult { + /// A list of content objects that represent the unstructured result of the + /// tool call. + std::vector content; + + /// Whether the tool call ended in an error. + /// + /// If not set, this is assumed to be false (the call was successful). + /// + /// Any errors that originate from the tool SHOULD be reported inside the + /// result object, with `isError` set to true, not as an MCP protocol-level + /// error response. Otherwise, the LLM would not be able to see that an error + /// occurred and self-correct. + /// + /// However, any errors in finding the tool, an error indicating that the + /// server does not support tool calls, or any other exceptional conditions, + /// should be reported as an MCP error response. + bool isError = false; + + /// An optional JSON object that represents the structured result of the tool + /// call. + std::optional structuredContent = std::nullopt; +}; +llvm::json::Value toJSON(const CallToolResult &); +bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path); + } // namespace lldb_protocol::mcp #endif diff --git a/lldb/include/lldb/Protocol/MCP/Resource.h b/lldb/include/lldb/Protocol/MCP/Resource.h index 4835d340cd4c6..158cffc71ea10 100644 --- a/lldb/include/lldb/Protocol/MCP/Resource.h +++ b/lldb/include/lldb/Protocol/MCP/Resource.h @@ -20,7 +20,7 @@ class ResourceProvider { virtual ~ResourceProvider() = default; virtual std::vector GetResources() const = 0; - virtual llvm::Expected + virtual llvm::Expected ReadResource(llvm::StringRef uri) const = 0; }; diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 2ac05880de86b..aa5714e45755e 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -9,6 +9,8 @@ #ifndef LLDB_PROTOCOL_MCP_SERVER_H #define LLDB_PROTOCOL_MCP_SERVER_H +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" #include "lldb/Protocol/MCP/Tool.h" @@ -18,31 +20,57 @@ namespace lldb_protocol::mcp { -class Server { +class MCPTransport + : public lldb_private::JSONRPCTransport { public: - Server(std::string name, std::string version); - virtual ~Server() = default; + using LogCallback = std::function; + + MCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out, + std::string client_name, LogCallback log_callback = {}) + : JSONRPCTransport(in, out), m_client_name(std::move(client_name)), + m_log_callback(log_callback) {} + virtual ~MCPTransport() = default; + + void Log(llvm::StringRef message) override { + if (m_log_callback) + m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str()); + } + +private: + std::string m_client_name; + LogCallback m_log_callback; +}; + +class Server : public MCPTransport::MessageHandler { +public: + Server(std::string name, std::string version, + std::unique_ptr transport_up, + lldb_private::MainLoop &loop); + ~Server() = default; + + using NotificationHandler = std::function; void AddTool(std::unique_ptr tool); void AddResourceProvider(std::unique_ptr resource_provider); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); + + llvm::Error Run(); protected: - virtual Capabilities GetCapabilities() = 0; + ServerCapabilities GetCapabilities(); using RequestHandler = std::function(const Request &)>; - using NotificationHandler = std::function; void AddRequestHandlers(); void AddRequestHandler(llvm::StringRef method, RequestHandler handler); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); llvm::Expected> HandleData(llvm::StringRef data); - llvm::Expected Handle(Request request); - void Handle(Notification notification); + llvm::Expected Handle(const Request &request); + void Handle(const Notification ¬ification); llvm::Expected InitializeHandler(const Request &); @@ -52,12 +80,21 @@ class Server { llvm::Expected ResourcesListHandler(const Request &); llvm::Expected ResourcesReadHandler(const Request &); - std::mutex m_mutex; + void Received(const Request &) override; + void Received(const Response &) override; + void Received(const Notification &) override; + void OnError(llvm::Error) override; + void OnClosed() override; + + void TerminateLoop(); private: const std::string m_name; const std::string m_version; + std::unique_ptr m_transport_up; + lldb_private::MainLoop &m_loop; + llvm::StringMap> m_tools; std::vector> m_resource_providers; diff --git a/lldb/include/lldb/Protocol/MCP/Tool.h b/lldb/include/lldb/Protocol/MCP/Tool.h index 96669d1357166..6c9f05161f8e7 100644 --- a/lldb/include/lldb/Protocol/MCP/Tool.h +++ b/lldb/include/lldb/Protocol/MCP/Tool.h @@ -10,6 +10,7 @@ #define LLDB_PROTOCOL_MCP_TOOL_H #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" #include @@ -20,7 +21,7 @@ class Tool { Tool(std::string name, std::string description); virtual ~Tool() = default; - virtual llvm::Expected + virtual llvm::Expected Call(const lldb_protocol::mcp::ToolArguments &args) = 0; virtual std::optional GetSchema() const { diff --git a/lldb/include/lldb/Symbol/Symbol.h b/lldb/include/lldb/Symbol/Symbol.h index 2d97a64d52b31..11d91418d3971 100644 --- a/lldb/include/lldb/Symbol/Symbol.h +++ b/lldb/include/lldb/Symbol/Symbol.h @@ -15,6 +15,7 @@ #include "lldb/Symbol/SymbolContextScope.h" #include "lldb/Utility/Stream.h" #include "lldb/Utility/UserID.h" +#include "lldb/lldb-enumerations.h" #include "lldb/lldb-private.h" #include "llvm/Support/JSON.h" @@ -301,6 +302,10 @@ class Symbol : public SymbolContextScope { bool operator==(const Symbol &rhs) const; + static const char *GetTypeAsString(lldb::SymbolType symbol_type); + + static lldb::SymbolType GetTypeFromString(const char *str); + protected: // This is the internal guts of ResolveReExportedSymbol, it assumes // reexport_name is not null, and that module_spec is valid. We track the diff --git a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py index 939be9941a49d..0608ac3fd83be 100644 --- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py +++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py @@ -12,15 +12,91 @@ import sys import threading import time -from typing import Any, Optional, Union, BinaryIO, TextIO +from typing import ( + Any, + Optional, + Dict, + cast, + List, + Callable, + IO, + Union, + BinaryIO, + TextIO, + TypedDict, + Literal, +) ## DAP type references -Event = dict[str, Any] -Request = dict[str, Any] -Response = dict[str, Any] + + +class Event(TypedDict): + type: Literal["event"] + seq: int + event: str + body: Any + + +class Request(TypedDict, total=False): + type: Literal["request"] + seq: int + command: str + arguments: Any + + +class Response(TypedDict): + type: Literal["response"] + seq: int + request_seq: int + success: bool + command: str + message: Optional[str] + body: Any + + ProtocolMessage = Union[Event, Request, Response] +class Source(TypedDict, total=False): + name: str + path: str + sourceReference: int + + @staticmethod + def build( + *, + name: Optional[str] = None, + path: Optional[str] = None, + source_reference: Optional[int] = None, + ) -> "Source": + """Builds a source from the given name, path or source_reference.""" + if not name and not path and not source_reference: + raise ValueError( + "Source.build requires either name, path, or source_reference" + ) + + s = Source() + if name: + s["name"] = name + if path: + if not name: + s["name"] = os.path.basename(path) + s["path"] = path + if source_reference is not None: + s["sourceReference"] = source_reference + return s + + +class Breakpoint(TypedDict, total=False): + id: int + verified: bool + source: Source + + @staticmethod + def is_verified(src: "Breakpoint") -> bool: + return src.get("verified", False) + + def dump_memory(base_addr, data, num_per_line, outfile): data_len = len(data) hex_string = binascii.hexlify(data) @@ -58,7 +134,9 @@ def dump_memory(base_addr, data, num_per_line, outfile): outfile.write("\n") -def read_packet(f, verbose=False, trace_file=None): +def read_packet( + f: IO[bytes], trace_file: Optional[IO[str]] = None +) -> Optional[ProtocolMessage]: """Decode a JSON packet that starts with the content length and is followed by the JSON bytes from a file 'f'. Returns None on EOF. """ @@ -70,19 +148,13 @@ def read_packet(f, verbose=False, trace_file=None): prefix = "Content-Length: " if line.startswith(prefix): # Decode length of JSON bytes - if verbose: - print('content: "%s"' % (line)) length = int(line[len(prefix) :]) - if verbose: - print('length: "%u"' % (length)) # Skip empty line - line = f.readline() - if verbose: - print('empty: "%s"' % (line)) + separator = f.readline().decode() + if separator != "": + Exception("malformed DAP content header, unexpected line: " + separator) # Read JSON bytes - json_str = f.read(length) - if verbose: - print('json: "%s"' % (json_str)) + json_str = f.read(length).decode() if trace_file: trace_file.write("from adapter:\n%s\n" % (json_str)) # Decode the JSON bytes into a python dictionary @@ -95,7 +167,7 @@ def packet_type_is(packet, packet_type): return "type" in packet and packet["type"] == packet_type -def dump_dap_log(log_file): +def dump_dap_log(log_file: Optional[str]) -> None: print("========= DEBUG ADAPTER PROTOCOL LOGS =========", file=sys.stderr) if log_file is None: print("no log file available", file=sys.stderr) @@ -105,58 +177,6 @@ def dump_dap_log(log_file): print("========= END =========", file=sys.stderr) -class Source(object): - def __init__( - self, - path: Optional[str] = None, - source_reference: Optional[int] = None, - raw_dict: Optional[dict[str, Any]] = None, - ): - self._name = None - self._path = None - self._source_reference = None - self._raw_dict = None - - if path is not None: - self._name = os.path.basename(path) - self._path = path - elif source_reference is not None: - self._source_reference = source_reference - elif raw_dict is not None: - self._raw_dict = raw_dict - else: - raise ValueError("Either path or source_reference must be provided") - - def __str__(self): - return f"Source(name={self.name}, path={self.path}), source_reference={self.source_reference})" - - def as_dict(self): - if self._raw_dict is not None: - return self._raw_dict - - source_dict = {} - if self._name is not None: - source_dict["name"] = self._name - if self._path is not None: - source_dict["path"] = self._path - if self._source_reference is not None: - source_dict["sourceReference"] = self._source_reference - return source_dict - - -class Breakpoint(object): - def __init__(self, obj): - self._breakpoint = obj - - def is_verified(self): - """Check if the breakpoint is verified.""" - return self._breakpoint.get("verified", False) - - def source(self): - """Get the source of the breakpoint.""" - return self._breakpoint.get("source", {}) - - class NotSupportedError(KeyError): """Raised if a feature is not supported due to its capabilities.""" @@ -174,26 +194,42 @@ def __init__( self.log_file = log_file self.send = send self.recv = recv - self.recv_packets: list[Optional[ProtocolMessage]] = [] - self.recv_condition = threading.Condition() - self.recv_thread = threading.Thread(target=self._read_packet_thread) - self.process_event_body = None - self.exit_status: Optional[int] = None - self.capabilities: dict[str, Any] = {} - self.progress_events: list[Event] = [] - self.reverse_requests = [] - self.sequence = 1 - self.threads = None - self.thread_stop_reasons = {} - self.recv_thread.start() - self.output_condition = threading.Condition() - self.output: dict[str, list[str]] = {} - self.configuration_done_sent = False - self.initialized = False - self.frame_scopes = {} + + # Packets that have been received and processed but have not yet been + # requested by a test case. + self._pending_packets: List[Optional[ProtocolMessage]] = [] + # Received packets that have not yet been processed. + self._recv_packets: List[Optional[ProtocolMessage]] = [] + # Used as a mutex for _recv_packets and for notify when _recv_packets + # changes. + self._recv_condition = threading.Condition() + self._recv_thread = threading.Thread(target=self._read_packet_thread) + + # session state self.init_commands = init_commands + self.exit_status: Optional[int] = None + self.capabilities: Dict = {} + self.initialized: bool = False + self.configuration_done_sent: bool = False + self.process_event_body: Optional[Dict] = None + self.terminated: bool = False + self.events: List[Event] = [] + self.progress_events: List[Event] = [] + self.reverse_requests: List[Request] = [] + self.module_events: List[Dict] = [] + self.sequence: int = 1 + self.output: Dict[str, str] = {} + + # debuggee state + self.threads: Optional[dict] = None + self.thread_stop_reasons: Dict[str, Any] = {} + self.frame_scopes: Dict[str, Any] = {} + # keyed by breakpoint id self.resolved_breakpoints: dict[str, Breakpoint] = {} + # trigger enqueue thread + self._recv_thread.start() + @classmethod def encode_content(cls, s: str) -> bytes: return ("Content-Length: %u\r\n\r\n%s" % (len(s), s)).encode("utf-8") @@ -210,267 +246,324 @@ def validate_response(cls, command, response): ) def _read_packet_thread(self): - done = False try: - while not done: + while True: packet = read_packet(self.recv, trace_file=self.trace_file) # `packet` will be `None` on EOF. We want to pass it down to # handle_recv_packet anyway so the main thread can handle unexpected # termination of lldb-dap and stop waiting for new packets. - done = not self._handle_recv_packet(packet) + if not self._handle_recv_packet(packet): + break finally: dump_dap_log(self.log_file) - def get_modules(self, startModule: int = 0, moduleCount: int = 0): - module_list = self.request_modules(startModule, moduleCount)["body"]["modules"] + def get_modules( + self, start_module: Optional[int] = None, module_count: Optional[int] = None + ) -> Dict: + resp = self.request_modules(start_module, module_count) + if not resp["success"]: + raise ValueError(f"request_modules failed: {resp!r}") modules = {} + module_list = resp["body"]["modules"] for module in module_list: modules[module["name"]] = module return modules - def get_output(self, category, timeout=0.0, clear=True): - self.output_condition.acquire() - output = None + def get_output(self, category: str, clear=True) -> str: + output = "" if category in self.output: - output = self.output[category] + output = self.output.get(category, "") if clear: del self.output[category] - elif timeout != 0.0: - self.output_condition.wait(timeout) - if category in self.output: - output = self.output[category] - if clear: - del self.output[category] - self.output_condition.release() return output - def collect_output(self, category, timeout_secs, pattern, clear=True): - end_time = time.time() + timeout_secs - collected_output = "" - while end_time > time.time(): - output = self.get_output(category, timeout=0.25, clear=clear) - if output: - collected_output += output - if pattern is not None and pattern in output: - break - return collected_output if collected_output else None + def collect_output( + self, + category: str, + timeout: float, + pattern: Optional[str] = None, + clear=True, + ) -> str: + """Collect output from 'output' events. + Args: + category: The category to collect. + timeout: The max duration for collecting output. + pattern: + Optional, if set, return once this pattern is detected in the + collected output. + Returns: + The collected output. + """ + deadline = time.monotonic() + timeout + output = self.get_output(category, clear) + while deadline >= time.monotonic() and ( + pattern is None or pattern not in output + ): + event = self.wait_for_event(["output"], timeout=deadline - time.monotonic()) + if not event: # Timeout or EOF + break + output += self.get_output(category, clear=clear) + return output def _enqueue_recv_packet(self, packet: Optional[ProtocolMessage]): - self.recv_condition.acquire() - self.recv_packets.append(packet) - self.recv_condition.notify() - self.recv_condition.release() + with self.recv_condition: + self.recv_packets.append(packet) + self.recv_condition.notify() def _handle_recv_packet(self, packet: Optional[ProtocolMessage]) -> bool: - """Called by the read thread that is waiting for all incoming packets - to store the incoming packet in "self.recv_packets" in a thread safe - way. This function will then signal the "self.recv_condition" to - indicate a new packet is available. Returns True if the caller - should keep calling this function for more packets. + """Handles an incoming packet. + + Called by the read thread that is waiting for all incoming packets + to store the incoming packet in "self._recv_packets" in a thread safe + way. This function will then signal the "self._recv_condition" to + indicate a new packet is available. + + Args: + packet: A new packet to store. + + Returns: + True if the caller should keep calling this function for more + packets. """ - # If EOF, notify the read thread by enqueuing a None. - if not packet: - self._enqueue_recv_packet(None) - return False - - # Check the packet to see if is an event packet - keepGoing = True - packet_type = packet["type"] - if packet_type == "event": - event = packet["event"] - body = None - if "body" in packet: - body = packet["body"] - # Handle the event packet and cache information from these packets - # as they come in - if event == "output": - # Store any output we receive so clients can retrieve it later. - category = body["category"] - output = body["output"] - self.output_condition.acquire() - if category in self.output: - self.output[category] += output - else: - self.output[category] = output - self.output_condition.notify() - self.output_condition.release() - # no need to add 'output' event packets to our packets list - return keepGoing - elif event == "initialized": - self.initialized = True - elif event == "process": - # When a new process is attached or launched, remember the - # details that are available in the body of the event - self.process_event_body = body - elif event == "exited": - # Process exited, mark the status to indicate the process is not - # alive. - self.exit_status = body["exitCode"] - elif event == "continued": - # When the process continues, clear the known threads and - # thread_stop_reasons. - all_threads_continued = body.get("allThreadsContinued", True) - tid = body["threadId"] - if tid in self.thread_stop_reasons: - del self.thread_stop_reasons[tid] - self._process_continued(all_threads_continued) - elif event == "stopped": - # Each thread that stops with a reason will send a - # 'stopped' event. We need to remember the thread stop - # reasons since the 'threads' command doesn't return - # that information. - self._process_stopped() - tid = body["threadId"] - self.thread_stop_reasons[tid] = body - elif event.startswith("progress"): - # Progress events come in as 'progressStart', 'progressUpdate', - # and 'progressEnd' events. Keep these around in case test - # cases want to verify them. - self.progress_events.append(packet) - elif event == "breakpoint": - # Breakpoint events are sent when a breakpoint is resolved - self._update_verified_breakpoints([body["breakpoint"]]) - elif event == "capabilities": - # Update the capabilities with new ones from the event. - self.capabilities.update(body["capabilities"]) - - elif packet_type == "response": - if packet["command"] == "disconnect": - keepGoing = False - self._enqueue_recv_packet(packet) - return keepGoing + with self._recv_condition: + self._recv_packets.append(packet) + self._recv_condition.notify() + # packet is None on EOF + return packet is not None and not ( + packet["type"] == "response" and packet["command"] == "disconnect" + ) + + def _recv_packet( + self, + *, + predicate: Optional[Callable[[ProtocolMessage], bool]] = None, + timeout: Optional[float] = None, + ) -> Optional[ProtocolMessage]: + """Processes received packets from the adapter. + Updates the DebugCommunication stateful properties based on the received + packets in the order they are received. + NOTE: The only time the session state properties should be updated is + during this call to ensure consistency during tests. + Args: + predicate: + Optional, if specified, returns the first packet that matches + the given predicate. + timeout: + Optional, if specified, processes packets until either the + timeout occurs or the predicate matches a packet, whichever + occurs first. + Returns: + The first matching packet for the given predicate, if specified, + otherwise None. + """ + assert ( + threading.current_thread != self._recv_thread + ), "Must not be called from the _recv_thread" + + def process_until_match(): + self._process_recv_packets() + for i, packet in enumerate(self._pending_packets): + if packet is None: + # We need to return a truthy value to break out of the + # wait_for, use `EOFError` as an indicator of EOF. + return EOFError() + if predicate and predicate(packet): + self._pending_packets.pop(i) + return packet + + with self._recv_condition: + packet = self._recv_condition.wait_for(process_until_match, timeout) + return None if isinstance(packet, EOFError) else packet + + def _process_recv_packets(self) -> None: + """Process received packets, updating the session state.""" + with self._recv_condition: + for packet in self._recv_packets: + # Handle events that may modify any stateful properties of + # the DAP session. + if packet and packet["type"] == "event": + self._handle_event(packet) + elif packet and packet["type"] == "request": + # Handle reverse requests and keep processing. + self._handle_reverse_request(packet) + # Move the packet to the pending queue. + self._pending_packets.append(packet) + self._recv_packets.clear() + + def _handle_event(self, packet: Event) -> None: + """Handle any events that modify debug session state we track.""" + event = packet["event"] + body: Optional[Dict] = packet.get("body", None) + + if event == "output" and body: + # Store any output we receive so clients can retrieve it later. + category = body["category"] + output = body["output"] + if category in self.output: + self.output[category] += output + else: + self.output[category] = output + elif event == "initialized": + self.initialized = True + elif event == "process": + # When a new process is attached or launched, remember the + # details that are available in the body of the event + self.process_event_body = body + elif event == "exited" and body: + # Process exited, mark the status to indicate the process is not + # alive. + self.exit_status = body["exitCode"] + elif event == "continued" and body: + # When the process continues, clear the known threads and + # thread_stop_reasons. + all_threads_continued = body.get("allThreadsContinued", True) + tid = body["threadId"] + if tid in self.thread_stop_reasons: + del self.thread_stop_reasons[tid] + self._process_continued(all_threads_continued) + elif event == "stopped" and body: + # Each thread that stops with a reason will send a + # 'stopped' event. We need to remember the thread stop + # reasons since the 'threads' command doesn't return + # that information. + self._process_stopped() + tid = body["threadId"] + self.thread_stop_reasons[tid] = body + elif event.startswith("progress"): + # Progress events come in as 'progressStart', 'progressUpdate', + # and 'progressEnd' events. Keep these around in case test + # cases want to verify them. + self.progress_events.append(packet) + elif event == "breakpoint" and body: + # Breakpoint events are sent when a breakpoint is resolved + self._update_verified_breakpoints([body["breakpoint"]]) + elif event == "capabilities" and body: + # Update the capabilities with new ones from the event. + self.capabilities.update(body["capabilities"]) + + def _handle_reverse_request(self, request: Request) -> None: + if request in self.reverse_requests: + return + self.reverse_requests.append(request) + arguments = request.get("arguments") + if request["command"] == "runInTerminal" and arguments is not None: + in_shell = arguments.get("argsCanBeInterpretedByShell", False) + print("spawning...", arguments["args"]) + proc = subprocess.Popen( + arguments["args"], + env=arguments.get("env", {}), + cwd=arguments.get("cwd", None), + stdin=subprocess.DEVNULL, + stdout=sys.stderr, + stderr=sys.stderr, + shell=in_shell, + ) + body = {} + if in_shell: + body["shellProcessId"] = proc.pid + else: + body["processId"] = proc.pid + self.send_packet( + { + "type": "response", + "seq": 0, + "request_seq": request["seq"], + "success": True, + "command": "runInTerminal", + "body": body, + } + ) + elif request["command"] == "startDebugging": + self.send_packet( + { + "type": "response", + "seq": 0, + "request_seq": request["seq"], + "success": True, + "message": None, + "command": "startDebugging", + "body": {}, + } + ) + else: + desc = 'unknown reverse request "%s"' % (request["command"]) + raise ValueError(desc) def _process_continued(self, all_threads_continued: bool): self.frame_scopes = {} if all_threads_continued: self.thread_stop_reasons = {} - def _update_verified_breakpoints(self, breakpoints: list[Event]): - for breakpoint in breakpoints: - if "id" in breakpoint: - self.resolved_breakpoints[str(breakpoint["id"])] = Breakpoint( - breakpoint - ) + def _update_verified_breakpoints(self, breakpoints: list[Breakpoint]): + for bp in breakpoints: + # If no id is set, we cannot correlate the given breakpoint across + # requests, ignore it. + if "id" not in bp: + continue - def send_packet(self, command_dict: Request, set_sequence=True): - """Take the "command_dict" python dictionary and encode it as a JSON - string and send the contents as a packet to the VSCode debug - adapter""" - # Set the sequence ID for this command automatically - if set_sequence: - command_dict["seq"] = self.sequence + self.resolved_breakpoints[str(bp["id"])] = bp + + def send_packet(self, packet: ProtocolMessage) -> int: + """Takes a dictionary representation of a DAP request and send the request to the debug adapter. + + Returns the seq number of the request. + """ + # Set the seq for requests. + if packet["type"] == "request": + packet["seq"] = self.sequence self.sequence += 1 + else: + packet["seq"] = 0 + # Encode our command dictionary as a JSON string - json_str = json.dumps(command_dict, separators=(",", ":")) + json_str = json.dumps(packet, separators=(",", ":")) + if self.trace_file: self.trace_file.write("to adapter:\n%s\n" % (json_str)) + length = len(json_str) if length > 0: # Send the encoded JSON packet and flush the 'send' file self.send.write(self.encode_content(json_str)) self.send.flush() - def recv_packet( - self, - filter_type: Optional[str] = None, - filter_event: Optional[Union[str, list[str]]] = None, - timeout: Optional[float] = None, - ) -> Optional[ProtocolMessage]: - """Get a JSON packet from the VSCode debug adapter. This function - assumes a thread that reads packets is running and will deliver - any received packets by calling handle_recv_packet(...). This - function will wait for the packet to arrive and return it when - it does.""" - while True: - try: - self.recv_condition.acquire() - packet = None - while True: - for i, curr_packet in enumerate(self.recv_packets): - if not curr_packet: - raise EOFError - packet_type = curr_packet["type"] - if filter_type is None or packet_type in filter_type: - if filter_event is None or ( - packet_type == "event" - and curr_packet["event"] in filter_event - ): - packet = self.recv_packets.pop(i) - break - if packet: - break - # Sleep until packet is received - len_before = len(self.recv_packets) - self.recv_condition.wait(timeout) - len_after = len(self.recv_packets) - if len_before == len_after: - return None # Timed out - return packet - except EOFError: - return None - finally: - self.recv_condition.release() - - def send_recv(self, command): + return packet["seq"] + + def _send_recv(self, request: Request) -> Optional[Response]: """Send a command python dictionary as JSON and receive the JSON response. Validates that the response is the correct sequence and command in the reply. Any events that are received are added to the events list in this object""" - self.send_packet(command) - done = False - while not done: - response_or_request = self.recv_packet(filter_type=["response", "request"]) - if response_or_request is None: - desc = 'no response for "%s"' % (command["command"]) - raise ValueError(desc) - if response_or_request["type"] == "response": - self.validate_response(command, response_or_request) - return response_or_request - else: - self.reverse_requests.append(response_or_request) - if response_or_request["command"] == "runInTerminal": - subprocess.Popen( - response_or_request["arguments"].get("args"), - env=response_or_request["arguments"].get("env", {}), - ) - self.send_packet( - { - "type": "response", - "request_seq": response_or_request["seq"], - "success": True, - "command": "runInTerminal", - "body": {}, - }, - ) - elif response_or_request["command"] == "startDebugging": - self.send_packet( - { - "type": "response", - "request_seq": response_or_request["seq"], - "success": True, - "command": "startDebugging", - "body": {}, - }, - ) - else: - desc = 'unknown reverse request "%s"' % ( - response_or_request["command"] - ) - raise ValueError(desc) + seq = self.send_packet(request) + response = self.receive_response(seq) + if response is None: + raise ValueError(f"no response for {request!r}") + self.validate_response(request, response) + return response - return None + def receive_response(self, seq: int) -> Optional[Response]: + """Waits for a response with the associated request_sec.""" + + def predicate(p: ProtocolMessage): + return p["type"] == "response" and p["request_seq"] == seq + + return cast(Optional[Response], self._recv_packet(predicate=predicate)) def wait_for_event( - self, filter: Union[str, list[str]], timeout: Optional[float] = None + self, filter: List[str] = [], timeout: Optional[float] = None ) -> Optional[Event]: """Wait for the first event that matches the filter.""" - return self.recv_packet( - filter_type="event", filter_event=filter, timeout=timeout + + def predicate(p: ProtocolMessage): + return p["type"] == "event" and p["event"] in filter + + return cast( + Optional[Event], self._recv_packet(predicate=predicate, timeout=timeout) ) def wait_for_stopped( self, timeout: Optional[float] = None - ) -> Optional[list[Event]]: + ) -> Optional[List[Event]]: stopped_events = [] stopped_event = self.wait_for_event( filter=["stopped", "exited"], timeout=timeout @@ -491,7 +584,7 @@ def wait_for_stopped( def wait_for_breakpoint_events(self, timeout: Optional[float] = None): breakpoint_events: list[Event] = [] while True: - event = self.wait_for_event("breakpoint", timeout=timeout) + event = self.wait_for_event(["breakpoint"], timeout=timeout) if not event: break breakpoint_events.append(event) @@ -502,7 +595,7 @@ def wait_for_breakpoints_to_be_verified( ): """Wait for all breakpoints to be verified. Return all unverified breakpoints.""" while any(id not in self.resolved_breakpoints for id in breakpoint_ids): - breakpoint_event = self.wait_for_event("breakpoint", timeout=timeout) + breakpoint_event = self.wait_for_event(["breakpoint"], timeout=timeout) if breakpoint_event is None: break @@ -511,18 +604,18 @@ def wait_for_breakpoints_to_be_verified( for id in breakpoint_ids if ( id not in self.resolved_breakpoints - or not self.resolved_breakpoints[id].is_verified() + or not Breakpoint.is_verified(self.resolved_breakpoints[id]) ) ] def wait_for_exited(self, timeout: Optional[float] = None): - event_dict = self.wait_for_event("exited", timeout=timeout) + event_dict = self.wait_for_event(["exited"], timeout=timeout) if event_dict is None: raise ValueError("didn't get exited event") return event_dict def wait_for_terminated(self, timeout: Optional[float] = None): - event_dict = self.wait_for_event("terminated", timeout) + event_dict = self.wait_for_event(["terminated"], timeout) if event_dict is None: raise ValueError("didn't get terminated event") return event_dict @@ -733,7 +826,7 @@ def request_attach( if gdbRemoteHostname is not None: args_dict["gdb-remote-hostname"] = gdbRemoteHostname command_dict = {"command": "attach", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_breakpointLocations( self, file_path, line, end_line=None, column=None, end_column=None @@ -755,7 +848,7 @@ def request_breakpointLocations( "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_configurationDone(self): command_dict = { @@ -763,7 +856,7 @@ def request_configurationDone(self): "type": "request", "arguments": {}, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response: self.configuration_done_sent = True self.request_threads() @@ -792,7 +885,7 @@ def request_continue(self, threadId=None, singleThread=False): "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response["success"]: self._process_continued(response["body"]["allThreadsContinued"]) # Caller must still call wait_for_stopped. @@ -809,7 +902,7 @@ def request_restart(self, restartArguments=None): if restartArguments: command_dict["arguments"] = restartArguments - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) # Caller must still call wait_for_stopped. return response @@ -825,7 +918,7 @@ def request_disconnect(self, terminateDebuggee=None): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_disassemble( self, @@ -845,7 +938,7 @@ def request_disassemble( "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict)["body"]["instructions"] + return self._send_recv(command_dict)["body"]["instructions"] def request_readMemory(self, memoryReference, offset, count): args_dict = { @@ -858,7 +951,7 @@ def request_readMemory(self, memoryReference, offset, count): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_writeMemory(self, memoryReference, data, offset=0, allowPartial=False): args_dict = { @@ -876,7 +969,7 @@ def request_writeMemory(self, memoryReference, data, offset=0, allowPartial=Fals "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_evaluate(self, expression, frameIndex=0, threadId=None, context=None): stackFrame = self.get_stackFrame(frameIndex=frameIndex, threadId=threadId) @@ -892,7 +985,7 @@ def request_evaluate(self, expression, frameIndex=0, threadId=None, context=None "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_exceptionInfo(self, threadId=None): if threadId is None: @@ -903,7 +996,7 @@ def request_exceptionInfo(self, threadId=None): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_initialize(self, sourceInitFile=False): command_dict = { @@ -924,10 +1017,10 @@ def request_initialize(self, sourceInitFile=False): "$__lldb_sourceInitFile": sourceInitFile, }, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response: if "body" in response: - self.capabilities = response["body"] + self.capabilities.update(response.get("body", {})) return response def request_launch( @@ -1007,14 +1100,14 @@ def request_launch( if commandEscapePrefix is not None: args_dict["commandEscapePrefix"] = commandEscapePrefix command_dict = {"command": "launch", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_next(self, threadId, granularity="statement"): if self.exit_status is not None: raise ValueError("request_continue called after process exited") args_dict = {"threadId": threadId, "granularity": granularity} command_dict = {"command": "next", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_stepIn(self, threadId, targetId, granularity="statement"): if self.exit_status is not None: @@ -1027,7 +1120,7 @@ def request_stepIn(self, threadId, targetId, granularity="statement"): "granularity": granularity, } command_dict = {"command": "stepIn", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_stepInTargets(self, frameId): if self.exit_status is not None: @@ -1039,14 +1132,14 @@ def request_stepInTargets(self, frameId): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_stepOut(self, threadId): if self.exit_status is not None: raise ValueError("request_stepOut called after process exited") args_dict = {"threadId": threadId} command_dict = {"command": "stepOut", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_pause(self, threadId=None): if self.exit_status is not None: @@ -1055,12 +1148,12 @@ def request_pause(self, threadId=None): threadId = self.get_thread_id() args_dict = {"threadId": threadId} command_dict = {"command": "pause", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_scopes(self, frameId): args_dict = {"frameId": frameId} command_dict = {"command": "scopes", "type": "request", "arguments": args_dict} - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_setBreakpoints(self, source: Source, line_array, data=None): """data is array of parameters for breakpoints in line_array. @@ -1068,7 +1161,7 @@ def request_setBreakpoints(self, source: Source, line_array, data=None): It contains optional location/hitCondition/logMessage parameters. """ args_dict = { - "source": source.as_dict(), + "source": source, "sourceModified": False, } if line_array is not None: @@ -1096,7 +1189,7 @@ def request_setBreakpoints(self, source: Source, line_array, data=None): "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response["success"]: self._update_verified_breakpoints(response["body"]["breakpoints"]) return response @@ -1112,7 +1205,7 @@ def request_setExceptionBreakpoints( "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_setFunctionBreakpoints(self, names, condition=None, hitCondition=None): breakpoints = [] @@ -1129,7 +1222,7 @@ def request_setFunctionBreakpoints(self, names, condition=None, hitCondition=Non "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if response["success"]: self._update_verified_breakpoints(response["body"]["breakpoints"]) return response @@ -1150,7 +1243,7 @@ def request_dataBreakpointInfo( "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_setDataBreakpoint(self, dataBreakpoints): """dataBreakpoints is a list of dictionary with following fields: @@ -1167,7 +1260,7 @@ def request_setDataBreakpoint(self, dataBreakpoints): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_compileUnits(self, moduleId): args_dict = {"moduleId": moduleId} @@ -1176,7 +1269,7 @@ def request_compileUnits(self, moduleId): "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) return response def request_completions(self, text, frameId=None): @@ -1188,17 +1281,43 @@ def request_completions(self, text, frameId=None): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) - - def request_modules(self, startModule: int, moduleCount: int): - return self.send_recv( - { - "command": "modules", - "type": "request", - "arguments": {"startModule": startModule, "moduleCount": moduleCount}, - } + return self._send_recv(command_dict) + + def request_modules( + self, + start_module: Optional[int] = None, + module_count: Optional[int] = None, + ): + args_dict = {} + + if start_module is not None: + args_dict["startModule"] = start_module + if module_count is not None: + args_dict["moduleCount"] = module_count + + return self._send_recv( + {"command": "modules", "type": "request", "arguments": args_dict} ) + def request_moduleSymbols( + self, + moduleId: str = "", + moduleName: str = "", + startIndex: int = 0, + count: int = 0, + ): + command_dict = { + "command": "__lldb_moduleSymbols", + "type": "request", + "arguments": { + "moduleId": moduleId, + "moduleName": moduleName, + "startIndex": startIndex, + "count": count, + }, + } + return self._send_recv(command_dict) + def request_stackTrace( self, threadId=None, startFrame=None, levels=None, format=None, dump=False ): @@ -1216,7 +1335,7 @@ def request_stackTrace( "type": "request", "arguments": args_dict, } - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if dump: for idx, frame in enumerate(response["body"]["stackFrames"]): name = frame["name"] @@ -1231,18 +1350,30 @@ def request_stackTrace( print("[%3u] %s" % (idx, name)) return response - def request_source(self, sourceReference): + def request_source( + self, *, source: Optional[Source] = None, sourceReference: Optional[int] = None + ): """Request a source from a 'Source' reference.""" + if source is None and sourceReference is None: + raise ValueError("request_source requires either source or sourceReference") + elif source is not None: + sourceReference = source["sourceReference"] + elif sourceReference is not None: + source = {"sourceReference": sourceReference} + else: + raise ValueError( + "request_source requires either source or sourceReference not both" + ) command_dict = { "command": "source", "type": "request", "arguments": { - "source": {"sourceReference": sourceReference}, + "source": source, # legacy version of the request "sourceReference": sourceReference, }, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_threads(self): """Request a list of all threads and combine any information from any @@ -1250,7 +1381,7 @@ def request_threads(self): thread actually stopped. Returns an array of thread dictionaries with information about all threads""" command_dict = {"command": "threads", "type": "request", "arguments": {}} - response = self.send_recv(command_dict) + response = self._send_recv(command_dict) if not response["success"]: self.threads = None return response @@ -1290,7 +1421,7 @@ def request_variables( "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_setVariable(self, containingVarRef, name, value, id=None): args_dict = { @@ -1305,7 +1436,7 @@ def request_setVariable(self, containingVarRef, name, value, id=None): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_locations(self, locationReference): args_dict = { @@ -1316,7 +1447,7 @@ def request_locations(self, locationReference): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def request_testGetTargetBreakpoints(self): """A request packet used in the LLDB test suite to get all currently @@ -1328,12 +1459,12 @@ def request_testGetTargetBreakpoints(self): "type": "request", "arguments": {}, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) def terminate(self): self.send.close() - if self.recv_thread.is_alive(): - self.recv_thread.join() + if self._recv_thread.is_alive(): + self._recv_thread.join() def request_setInstructionBreakpoints(self, memory_reference=[]): breakpoints = [] @@ -1348,7 +1479,7 @@ def request_setInstructionBreakpoints(self, memory_reference=[]): "type": "request", "arguments": args_dict, } - return self.send_recv(command_dict) + return self._send_recv(command_dict) class DebugAdapterServer(DebugCommunication): diff --git a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py index c51b4b1892951..158cd7d938e09 100644 --- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py +++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py @@ -1,16 +1,20 @@ import os import time -from typing import Optional +from typing import Optional, Callable, Any, List, Union import uuid import dap_server from dap_server import Source +from lldbsuite.test.decorators import skipIf from lldbsuite.test.lldbtest import * from lldbsuite.test import lldbplatformutil import lldbgdbserverutils import base64 +# DAP tests as a whole have been flakey on the Windows on Arm bot. See: +# https://github.com/llvm/llvm-project/issues/137660 +@skipIf(oslist=["windows"], archs=["aarch64"]) class DAPTestCaseBase(TestBase): # set timeout based on whether ASAN was enabled or not. Increase # timeout by a factor of 10 if ASAN is enabled. @@ -67,7 +71,10 @@ def set_source_breakpoints_assembly( self, source_reference, lines, data=None, wait_for_resolve=True ): return self.set_source_breakpoints_from_source( - Source(source_reference=source_reference), lines, data, wait_for_resolve + Source.build(source_reference=source_reference), + lines, + data, + wait_for_resolve, ) def set_source_breakpoints_from_source( @@ -120,11 +127,19 @@ def wait_for_breakpoints_to_resolve( f"Expected to resolve all breakpoints. Unresolved breakpoint ids: {unresolved_breakpoints}", ) - def waitUntil(self, condition_callback): - for _ in range(20): - if condition_callback(): + def wait_until( + self, + predicate: Callable[[], bool], + delay: float = 0.5, + timeout: float = DEFAULT_TIMEOUT, + ) -> bool: + """Repeatedly run the predicate until either the predicate returns True + or a timeout has occurred.""" + deadline = time.monotonic() + timeout + while deadline > time.monotonic(): + if predicate(): return True - time.sleep(0.5) + time.sleep(delay) return False def assertCapabilityIsSet(self, key: str, msg: Optional[str] = None) -> None: @@ -137,13 +152,16 @@ def assertCapabilityIsNotSet(self, key: str, msg: Optional[str] = None) -> None: if key in self.dap_server.capabilities: self.assertEqual(self.dap_server.capabilities[key], False, msg) - def verify_breakpoint_hit(self, breakpoint_ids, timeout=DEFAULT_TIMEOUT): + def verify_breakpoint_hit( + self, breakpoint_ids: List[Union[int, str]], timeout: float = DEFAULT_TIMEOUT + ): """Wait for the process we are debugging to stop, and verify we hit any breakpoint location in the "breakpoint_ids" array. "breakpoint_ids" should be a list of breakpoint ID strings (["1", "2"]). The return value from self.set_source_breakpoints() or self.set_function_breakpoints() can be passed to this function""" stopped_events = self.dap_server.wait_for_stopped(timeout) + normalized_bp_ids = [str(b) for b in breakpoint_ids] for stopped_event in stopped_events: if "body" in stopped_event: body = stopped_event["body"] @@ -154,22 +172,16 @@ def verify_breakpoint_hit(self, breakpoint_ids, timeout=DEFAULT_TIMEOUT): and body["reason"] != "instruction breakpoint" ): continue - if "description" not in body: + if "hitBreakpointIds" not in body: continue - # Descriptions for breakpoints will be in the form - # "breakpoint 1.1", so look for any description that matches - # ("breakpoint 1.") in the description field as verification - # that one of the breakpoint locations was hit. DAP doesn't - # allow breakpoints to have multiple locations, but LLDB does. - # So when looking at the description we just want to make sure - # the right breakpoint matches and not worry about the actual - # location. - description = body["description"] - for breakpoint_id in breakpoint_ids: - match_desc = f"breakpoint {breakpoint_id}." - if match_desc in description: + hit_breakpoint_ids = body["hitBreakpointIds"] + for bp in hit_breakpoint_ids: + if str(bp) in normalized_bp_ids: return - self.assertTrue(False, f"breakpoint not hit, stopped_events={stopped_events}") + self.assertTrue( + False, + f"breakpoint not hit, wanted breakpoint_ids {breakpoint_ids} in stopped_events {stopped_events}", + ) def verify_all_breakpoints_hit(self, breakpoint_ids, timeout=DEFAULT_TIMEOUT): """Wait for the process we are debugging to stop, and verify we hit @@ -213,7 +225,7 @@ def verify_stop_exception_info(self, expected_description, timeout=DEFAULT_TIMEO return True return False - def verify_commands(self, flavor, output, commands): + def verify_commands(self, flavor: str, output: str, commands: list[str]): self.assertTrue(output and len(output) > 0, "expect console output") lines = output.splitlines() prefix = "(lldb) " @@ -226,10 +238,11 @@ def verify_commands(self, flavor, output, commands): found = True break self.assertTrue( - found, "verify '%s' found in console output for '%s'" % (cmd, flavor) + found, + f"Command '{flavor}' - '{cmd}' not found in output: {output}", ) - def get_dict_value(self, d, key_path): + def get_dict_value(self, d: dict, key_path: list[str]) -> Any: """Verify each key in the key_path array is in contained in each dictionary within "d". Assert if any key isn't in the corresponding dictionary. This is handy for grabbing values from VS @@ -298,28 +311,34 @@ def get_source_and_line(self, threadId=None, frameIndex=0): return (source["path"], stackFrame["line"]) return ("", 0) - def get_stdout(self, timeout=0.0): - return self.dap_server.get_output("stdout", timeout=timeout) + def get_stdout(self): + return self.dap_server.get_output("stdout") - def get_console(self, timeout=0.0): - return self.dap_server.get_output("console", timeout=timeout) + def get_console(self): + return self.dap_server.get_output("console") - def get_important(self, timeout=0.0): - return self.dap_server.get_output("important", timeout=timeout) + def get_important(self): + return self.dap_server.get_output("important") - def collect_stdout(self, timeout_secs, pattern=None): + def collect_stdout( + self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None + ) -> str: return self.dap_server.collect_output( - "stdout", timeout_secs=timeout_secs, pattern=pattern + "stdout", timeout=timeout, pattern=pattern ) - def collect_console(self, timeout_secs, pattern=None): + def collect_console( + self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None + ) -> str: return self.dap_server.collect_output( - "console", timeout_secs=timeout_secs, pattern=pattern + "console", timeout=timeout, pattern=pattern ) - def collect_important(self, timeout_secs, pattern=None): + def collect_important( + self, timeout: float = DEFAULT_TIMEOUT, pattern: Optional[str] = None + ) -> str: return self.dap_server.collect_output( - "important", timeout_secs=timeout_secs, pattern=pattern + "important", timeout=timeout, pattern=pattern ) def get_local_as_int(self, name, threadId=None): diff --git a/lldb/source/API/SBSymbol.cpp b/lldb/source/API/SBSymbol.cpp index 79477dd3a70fc..3b59119494f37 100644 --- a/lldb/source/API/SBSymbol.cpp +++ b/lldb/source/API/SBSymbol.cpp @@ -193,6 +193,14 @@ SymbolType SBSymbol::GetType() { return eSymbolTypeInvalid; } +uint32_t SBSymbol::GetID() { + LLDB_INSTRUMENT_VA(this); + + if (m_opaque_ptr) + return m_opaque_ptr->GetID(); + return 0; +} + bool SBSymbol::IsExternal() { LLDB_INSTRUMENT_VA(this); @@ -208,3 +216,23 @@ bool SBSymbol::IsSynthetic() { return m_opaque_ptr->IsSynthetic(); return false; } + +bool SBSymbol::IsDebug() { + LLDB_INSTRUMENT_VA(this); + + if (m_opaque_ptr) + return m_opaque_ptr->IsDebug(); + return false; +} + +const char *SBSymbol::GetTypeAsString(lldb::SymbolType symbol_type) { + LLDB_INSTRUMENT_VA(symbol_type); + + return Symbol::GetTypeAsString(symbol_type); +} + +lldb::SymbolType SBSymbol::GetTypeFromString(const char *str) { + LLDB_INSTRUMENT_VA(str); + + return Symbol::GetTypeFromString(str); +} diff --git a/lldb/source/API/SBTarget.cpp b/lldb/source/API/SBTarget.cpp index 34f3c261719b2..3e5eaa2e582df 100644 --- a/lldb/source/API/SBTarget.cpp +++ b/lldb/source/API/SBTarget.cpp @@ -1606,6 +1606,18 @@ SBModule SBTarget::FindModule(const SBFileSpec &sb_file_spec) { return sb_module; } +SBModule SBTarget::FindModule(const SBModuleSpec &sb_module_spec) { + LLDB_INSTRUMENT_VA(this, sb_module_spec); + + SBModule sb_module; + if (TargetSP target_sp = GetSP(); target_sp && sb_module_spec.IsValid()) { + // The module list is thread safe, no need to lock. + sb_module.SetSP( + target_sp->GetImages().FindFirstModule(*sb_module_spec.m_opaque_up)); + } + return sb_module; +} + SBSymbolContextList SBTarget::FindCompileUnits(const SBFileSpec &sb_file_spec) { LLDB_INSTRUMENT_VA(this, sb_file_spec); diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index 546c12c8f7114..c4b42eafc85d3 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -7,171 +7,26 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/JSONTransport.h" -#include "lldb/Utility/IOObject.h" -#include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" -#include "lldb/Utility/SelectHelper.h" #include "lldb/Utility/Status.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Error.h" #include "llvm/Support/raw_ostream.h" -#include #include -#include using namespace llvm; using namespace lldb; using namespace lldb_private; -/// ReadFull attempts to read the specified number of bytes. If EOF is -/// encountered, an empty string is returned. -static Expected -ReadFull(IOObject &descriptor, size_t length, - std::optional timeout = std::nullopt) { - if (!descriptor.IsValid()) - return llvm::make_error(); +char TransportUnhandledContentsError::ID; - bool timeout_supported = true; - // FIXME: SelectHelper does not work with NativeFile on Win32. -#if _WIN32 - timeout_supported = descriptor.GetFdType() == IOObject::eFDTypeSocket; -#endif +TransportUnhandledContentsError::TransportUnhandledContentsError( + std::string unhandled_contents) + : m_unhandled_contents(unhandled_contents) {} - if (timeout && timeout_supported) { - SelectHelper sh; - sh.SetTimeout(*timeout); - sh.FDSetRead( - reinterpret_cast(descriptor.GetWaitableHandle())); - Status status = sh.Select(); - if (status.Fail()) { - // Convert timeouts into a specific error. - if (status.GetType() == lldb::eErrorTypePOSIX && - status.GetError() == ETIMEDOUT) - return make_error(); - return status.takeError(); - } - } - - std::string data; - data.resize(length); - Status status = descriptor.Read(data.data(), length); - if (status.Fail()) - return status.takeError(); - - // Read returns '' on EOF. - if (length == 0) - return make_error(); - - // Return the actual number of bytes read. - return data.substr(0, length); -} - -static Expected -ReadUntil(IOObject &descriptor, StringRef delimiter, - std::optional timeout = std::nullopt) { - std::string buffer; - buffer.reserve(delimiter.size() + 1); - while (!llvm::StringRef(buffer).ends_with(delimiter)) { - Expected next = - ReadFull(descriptor, buffer.empty() ? delimiter.size() : 1, timeout); - if (auto Err = next.takeError()) - return std::move(Err); - buffer += *next; - } - return buffer.substr(0, buffer.size() - delimiter.size()); -} - -JSONTransport::JSONTransport(IOObjectSP input, IOObjectSP output) - : m_input(std::move(input)), m_output(std::move(output)) {} - -void JSONTransport::Log(llvm::StringRef message) { - LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); +void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { + OS << "transport EOF with unhandled contents: '" << m_unhandled_contents + << "'"; } - -Expected -HTTPDelimitedJSONTransport::ReadImpl(const std::chrono::microseconds &timeout) { - if (!m_input || !m_input->IsValid()) - return llvm::make_error(); - - IOObject *input = m_input.get(); - Expected message_header = - ReadFull(*input, kHeaderContentLength.size(), timeout); - if (!message_header) - return message_header.takeError(); - if (*message_header != kHeaderContentLength) - return createStringError(formatv("expected '{0}' and got '{1}'", - kHeaderContentLength, *message_header) - .str()); - - Expected raw_length = ReadUntil(*input, kHeaderSeparator); - if (!raw_length) - return handleErrors(raw_length.takeError(), - [&](const TransportEOFError &E) -> llvm::Error { - return createStringError( - "unexpected EOF while reading header separator"); - }); - - size_t length; - if (!to_integer(*raw_length, length)) - return createStringError( - formatv("invalid content length {0}", *raw_length).str()); - - Expected raw_json = ReadFull(*input, length); - if (!raw_json) - return handleErrors( - raw_json.takeError(), [&](const TransportEOFError &E) -> llvm::Error { - return createStringError("unexpected EOF while reading JSON"); - }); - - Log(llvm::formatv("--> {0}", *raw_json).str()); - - return raw_json; +std::error_code TransportUnhandledContentsError::convertToErrorCode() const { + return std::make_error_code(std::errc::bad_message); } - -Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { - if (!m_output || !m_output->IsValid()) - return llvm::make_error(); - - Log(llvm::formatv("<-- {0}", message).str()); - - std::string Output; - raw_string_ostream OS(Output); - OS << kHeaderContentLength << message.length() << kHeaderSeparator << message; - size_t num_bytes = Output.size(); - return m_output->Write(Output.data(), num_bytes).takeError(); -} - -Expected -JSONRPCTransport::ReadImpl(const std::chrono::microseconds &timeout) { - if (!m_input || !m_input->IsValid()) - return make_error(); - - IOObject *input = m_input.get(); - Expected raw_json = - ReadUntil(*input, kMessageSeparator, timeout); - if (!raw_json) - return raw_json.takeError(); - - Log(llvm::formatv("--> {0}", *raw_json).str()); - - return *raw_json; -} - -Error JSONRPCTransport::WriteImpl(const std::string &message) { - if (!m_output || !m_output->IsValid()) - return llvm::make_error(); - - Log(llvm::formatv("<-- {0}", message).str()); - - std::string Output; - llvm::raw_string_ostream OS(Output); - OS << message << kMessageSeparator; - size_t num_bytes = Output.size(); - return m_output->Write(Output.data(), num_bytes).takeError(); -} - -char TransportEOFError::ID; -char TransportTimeoutError::ID; -char TransportInvalidError::ID; diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index c359663239dcc..57132534cf680 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -26,24 +26,10 @@ using namespace llvm; LLDB_PLUGIN_DEFINE(ProtocolServerMCP) -static constexpr size_t kChunkSize = 1024; static constexpr llvm::StringLiteral kName = "lldb-mcp"; static constexpr llvm::StringLiteral kVersion = "0.1.0"; -ProtocolServerMCP::ProtocolServerMCP() - : ProtocolServer(), - lldb_protocol::mcp::Server(std::string(kName), std::string(kVersion)) { - AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); - - AddTool( - std::make_unique("lldb_command", "Run an lldb command.")); - - AddResourceProvider(std::make_unique()); -} +ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {} ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } @@ -64,57 +50,37 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } +void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { + server.AddNotificationHandler("notifications/initialized", + [](const lldb_protocol::mcp::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), + "MCP initialization complete"); + }); + server.AddTool( + std::make_unique("lldb_command", "Run an lldb command.")); + server.AddResourceProvider(std::make_unique()); +} + void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { - LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", - m_clients.size() + 1); + Log *log = GetLog(LLDBLog::Host); + std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1); + LLDB_LOG(log, "New MCP client connected: {0}", client_name); lldb::IOObjectSP io_sp = std::move(socket); - auto client_up = std::make_unique(); - client_up->io_sp = io_sp; - Client *client = client_up.get(); - - Status status; - auto read_handle_up = m_loop.RegisterReadObject( - io_sp, - [this, client](MainLoopBase &loop) { - if (llvm::Error error = ReadCallback(*client)) { - LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); - client->read_handle_up.reset(); - } - }, - status); - if (status.Fail()) + auto transport_up = std::make_unique( + io_sp, io_sp, std::move(client_name), [&](llvm::StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); + }); + auto instance_up = std::make_unique( + std::string(kName), std::string(kVersion), std::move(transport_up), + m_loop); + Extend(*instance_up); + llvm::Error error = instance_up->Run(); + if (error) { + LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}"); return; - - client_up->read_handle_up = std::move(read_handle_up); - m_clients.emplace_back(std::move(client_up)); -} - -llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { - char chunk[kChunkSize]; - size_t bytes_read = sizeof(chunk); - if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) - return status.takeError(); - client.buffer.append(chunk, bytes_read); - - for (std::string::size_type pos; - (pos = client.buffer.find('\n')) != std::string::npos;) { - llvm::Expected> message = - HandleData(StringRef(client.buffer.data(), pos)); - client.buffer = client.buffer.erase(0, pos + 1); - if (!message) - return message.takeError(); - - if (*message) { - std::string Output; - llvm::raw_string_ostream OS(Output); - OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; - size_t num_bytes = Output.size(); - return client.io_sp->Write(Output.data(), num_bytes).takeError(); - } } - - return llvm::Error::success(); + m_instances.push_back(std::move(instance_up)); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -158,27 +124,11 @@ llvm::Error ProtocolServerMCP::Stop() { // Stop the main loop. m_loop.AddPendingCallback( - [](MainLoopBase &loop) { loop.RequestTermination(); }); + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); // Wait for the main loop to exit. if (m_loop_thread.joinable()) m_loop_thread.join(); - { - std::lock_guard guard(m_mutex); - m_listener.reset(); - m_listen_handlers.clear(); - m_clients.clear(); - } - return llvm::Error::success(); } - -lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() { - lldb_protocol::mcp::Capabilities capabilities; - capabilities.tools.listChanged = true; - // FIXME: Support sending notifications when a debugger/target are - // added/removed. - capabilities.resources.listChanged = false; - return capabilities; -} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 7fe909a728b85..fc650ffe0dfa7 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -18,8 +18,7 @@ namespace lldb_private::mcp { -class ProtocolServerMCP : public ProtocolServer, - public lldb_protocol::mcp::Server { +class ProtocolServerMCP : public ProtocolServer { public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -39,26 +38,24 @@ class ProtocolServerMCP : public ProtocolServer, Socket *GetSocket() const override { return m_listener.get(); } +protected: + // This adds tools and resource providers that + // are specific to this server. Overridable by the unit tests. + virtual void Extend(lldb_protocol::mcp::Server &server) const; + private: void AcceptCallback(std::unique_ptr socket); - lldb_protocol::mcp::Capabilities GetCapabilities() override; - bool m_running = false; - MainLoop m_loop; + lldb_private::MainLoop m_loop; std::thread m_loop_thread; + std::mutex m_mutex; std::unique_ptr m_listener; - std::vector m_listen_handlers; - struct Client { - lldb::IOObjectSP io_sp; - MainLoopBase::ReadHandleUP read_handle_up; - std::string buffer; - }; - llvm::Error ReadCallback(Client &client); - std::vector> m_clients; + std::vector m_listen_handlers; + std::vector> m_instances; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.cpp b/lldb/source/Plugins/Protocol/MCP/Resource.cpp index e94d2cdd65e07..581424510d4cf 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Resource.cpp @@ -8,7 +8,6 @@ #include "lldb/Core/Debugger.h" #include "lldb/Core/Module.h" #include "lldb/Protocol/MCP/MCPError.h" -#include "lldb/Target/Platform.h" using namespace lldb_private; using namespace lldb_private::mcp; @@ -124,7 +123,7 @@ DebuggerResourceProvider::GetResources() const { return resources; } -llvm::Expected +llvm::Expected DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { auto [protocol, path] = uri.split("://"); @@ -161,7 +160,7 @@ DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { return ReadDebuggerResource(uri, debugger_idx); } -llvm::Expected +llvm::Expected DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id) { lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); @@ -173,17 +172,17 @@ DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, debugger_resource.name = debugger_sp->GetInstanceName(); debugger_resource.num_targets = debugger_sp->GetTargetList().GetNumTargets(); - lldb_protocol::mcp::ResourceContents contents; + lldb_protocol::mcp::TextResourceContents contents; contents.uri = uri; contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(debugger_resource)); - lldb_protocol::mcp::ResourceResult result; + lldb_protocol::mcp::ReadResourceResult result; result.contents.push_back(contents); return result; } -llvm::Expected +llvm::Expected DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx) { @@ -209,12 +208,12 @@ DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, if (lldb::PlatformSP platform_sp = target_sp->GetPlatform()) target_resource.platform = platform_sp->GetName(); - lldb_protocol::mcp::ResourceContents contents; + lldb_protocol::mcp::TextResourceContents contents; contents.uri = uri; contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(target_resource)); - lldb_protocol::mcp::ResourceResult result; + lldb_protocol::mcp::ReadResourceResult result; result.contents.push_back(contents); return result; } diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.h b/lldb/source/Plugins/Protocol/MCP/Resource.h index e2382a74f796b..0c6576602905e 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.h +++ b/lldb/source/Plugins/Protocol/MCP/Resource.h @@ -11,7 +11,11 @@ #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" -#include "lldb/lldb-private.h" +#include "lldb/lldb-forward.h" +#include "lldb/lldb-types.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include #include namespace lldb_private::mcp { @@ -21,9 +25,8 @@ class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider { using ResourceProvider::ResourceProvider; virtual ~DebuggerResourceProvider() = default; - virtual std::vector - GetResources() const override; - virtual llvm::Expected + std::vector GetResources() const override; + llvm::Expected ReadResource(llvm::StringRef uri) const override; private: @@ -31,9 +34,9 @@ class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider { static lldb_protocol::mcp::Resource GetTargetResource(size_t target_idx, Target &target); - static llvm::Expected + static llvm::Expected ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id); - static llvm::Expected + static llvm::Expected ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx); }; diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index 143470702a6fd..2f451bf76e81d 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -7,9 +7,9 @@ //===----------------------------------------------------------------------===// #include "Tool.h" -#include "lldb/Core/Module.h" #include "lldb/Interpreter/CommandInterpreter.h" #include "lldb/Interpreter/CommandReturnObject.h" +#include "lldb/Protocol/MCP/Protocol.h" using namespace lldb_private; using namespace lldb_protocol; @@ -29,10 +29,10 @@ bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, O.mapOptional("arguments", A.arguments); } -/// Helper function to create a TextResult from a string output. -static lldb_protocol::mcp::TextResult createTextResult(std::string output, - bool is_error = false) { - lldb_protocol::mcp::TextResult text_result; +/// Helper function to create a CallToolResult from a string output. +static lldb_protocol::mcp::CallToolResult +createTextResult(std::string output, bool is_error = false) { + lldb_protocol::mcp::CallToolResult text_result; text_result.content.emplace_back( lldb_protocol::mcp::TextContent{{std::move(output)}}); text_result.isError = is_error; @@ -41,7 +41,7 @@ static lldb_protocol::mcp::TextResult createTextResult(std::string output, } // namespace -llvm::Expected +llvm::Expected CommandTool::Call(const lldb_protocol::mcp::ToolArguments &args) { if (!std::holds_alternative(args)) return createStringError("CommandTool requires arguments"); diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index b7b1756eb38d7..1886525b9168f 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -9,11 +9,11 @@ #ifndef LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H #define LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H -#include "lldb/Core/Debugger.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Tool.h" +#include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" -#include +#include namespace lldb_private::mcp { @@ -22,10 +22,10 @@ class CommandTool : public lldb_protocol::mcp::Tool { using lldb_protocol::mcp::Tool::Tool; ~CommandTool() = default; - virtual llvm::Expected + llvm::Expected Call(const lldb_protocol::mcp::ToolArguments &args) override; - virtual std::optional GetSchema() const override; + std::optional GetSchema() const override; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/CMakeLists.txt b/lldb/source/Protocol/MCP/CMakeLists.txt index a73e7e6a7cab1..a4f270a83c43b 100644 --- a/lldb/source/Protocol/MCP/CMakeLists.txt +++ b/lldb/source/Protocol/MCP/CMakeLists.txt @@ -7,6 +7,7 @@ add_lldb_library(lldbProtocolMCP NO_PLUGIN_DEPENDENCIES LINK_COMPONENTS Support LINK_LIBS + lldbHost lldbUtility ) diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp index d9b11bd766686..0988f456adc26 100644 --- a/lldb/source/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Protocol/MCP/Protocol.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" using namespace llvm; @@ -166,32 +167,6 @@ bool operator==(const Notification &a, const Notification &b) { return a.method == b.method && a.params == b.params; } -llvm::json::Value toJSON(const ToolCapability &TC) { - return llvm::json::Object{{"listChanged", TC.listChanged}}; -} - -bool fromJSON(const llvm::json::Value &V, ToolCapability &TC, - llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("listChanged", TC.listChanged); -} - -llvm::json::Value toJSON(const ResourceCapability &RC) { - return llvm::json::Object{{"listChanged", RC.listChanged}, - {"subscribe", RC.subscribe}}; -} - -bool fromJSON(const llvm::json::Value &V, ResourceCapability &RC, - llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("listChanged", RC.listChanged) && - O.map("subscribe", RC.subscribe); -} - -llvm::json::Value toJSON(const Capabilities &C) { - return llvm::json::Object{{"tools", C.tools}, {"resources", C.resources}}; -} - bool fromJSON(const llvm::json::Value &V, Resource &R, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("uri", R.uri) && O.map("name", R.name) && @@ -208,30 +183,25 @@ llvm::json::Value toJSON(const Resource &R) { return Result; } -bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("tools", C.tools); -} - -llvm::json::Value toJSON(const ResourceContents &RC) { +llvm::json::Value toJSON(const TextResourceContents &RC) { llvm::json::Object Result{{"uri", RC.uri}, {"text", RC.text}}; if (!RC.mimeType.empty()) Result.insert({"mimeType", RC.mimeType}); return Result; } -bool fromJSON(const llvm::json::Value &V, ResourceContents &RC, +bool fromJSON(const llvm::json::Value &V, TextResourceContents &RC, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("uri", RC.uri) && O.map("text", RC.text) && O.mapOptional("mimeType", RC.mimeType); } -llvm::json::Value toJSON(const ResourceResult &RR) { +llvm::json::Value toJSON(const ReadResourceResult &RR) { return llvm::json::Object{{"contents", RR.contents}}; } -bool fromJSON(const llvm::json::Value &V, ResourceResult &RR, +bool fromJSON(const llvm::json::Value &V, ReadResourceResult &RR, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("contents", RR.contents); @@ -246,15 +216,6 @@ bool fromJSON(const llvm::json::Value &V, TextContent &TC, llvm::json::Path P) { return O && O.map("text", TC.text); } -llvm::json::Value toJSON(const TextResult &TR) { - return llvm::json::Object{{"content", TR.content}, {"isError", TR.isError}}; -} - -bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("content", TR.content) && O.map("isError", TR.isError); -} - llvm::json::Value toJSON(const ToolDefinition &TD) { llvm::json::Object Result{{"name", TD.name}}; if (!TD.description.empty()) @@ -324,4 +285,159 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { return false; } +json::Value toJSON(const Implementation &I) { + json::Object result{{"name", I.name}, {"version", I.version}}; + + if (!I.title.empty()) + result.insert({"title", I.title}); + + return result; +} + +bool fromJSON(const json::Value &V, Implementation &I, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("name", I.name) && O.mapOptional("title", I.title) && + O.mapOptional("version", I.version); +} + +json::Value toJSON(const ClientCapabilities &C) { return json::Object{}; } + +bool fromJSON(const json::Value &, ClientCapabilities &, json::Path) { + return true; +} + +json::Value toJSON(const ServerCapabilities &C) { + json::Object result{}; + + if (C.supportsToolsList) + result.insert({"tools", json::Object{{"listChanged", true}}}); + + if (C.supportsResourcesList || C.supportsResourcesSubscribe) { + json::Object resources; + if (C.supportsResourcesList) + resources.insert({"listChanged", true}); + if (C.supportsResourcesSubscribe) + resources.insert({"subscribe", true}); + result.insert({"resources", std::move(resources)}); + } + + if (C.supportsCompletions) + result.insert({"completions", json::Object{}}); + + if (C.supportsLogging) + result.insert({"logging", json::Object{}}); + + return result; +} + +bool fromJSON(const json::Value &V, ServerCapabilities &C, json::Path P) { + const json::Object *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + if (O->find("tools") != O->end()) + C.supportsToolsList = true; + + return true; +} + +json::Value toJSON(const InitializeParams &P) { + return json::Object{ + {"protocolVersion", P.protocolVersion}, + {"capabilities", P.capabilities}, + {"clientInfo", P.clientInfo}, + }; +} + +bool fromJSON(const json::Value &V, InitializeParams &I, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("protocolVersion", I.protocolVersion) && + O.map("capabilities", I.capabilities) && + O.map("clientInfo", I.clientInfo); +} + +json::Value toJSON(const InitializeResult &R) { + json::Object result{{"protocolVersion", R.protocolVersion}, + {"capabilities", R.capabilities}, + {"serverInfo", R.serverInfo}}; + + if (!R.instructions.empty()) + result.insert({"instructions", R.instructions}); + + return result; +} + +bool fromJSON(const json::Value &V, InitializeResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("protocolVersion", R.protocolVersion) && + O.map("capabilities", R.capabilities) && + O.map("serverInfo", R.serverInfo) && + O.mapOptional("instructions", R.instructions); +} + +json::Value toJSON(const ListToolsResult &R) { + return json::Object{{"tools", R.tools}}; +} + +bool fromJSON(const json::Value &V, ListToolsResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("tools", R.tools); +} + +json::Value toJSON(const CallToolResult &R) { + json::Object result{{"content", R.content}}; + + if (R.isError) + result.insert({"isError", R.isError}); + if (R.structuredContent) + result.insert({"structuredContent", *R.structuredContent}); + + return result; +} + +bool fromJSON(const json::Value &V, CallToolResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("content", R.content) && + O.mapOptional("isError", R.isError) && + mapRaw(V, "structuredContent", R.structuredContent, P); +} + +json::Value toJSON(const CallToolParams &R) { + json::Object result{{"name", R.name}}; + + if (R.arguments) + result.insert({"arguments", *R.arguments}); + + return result; +} + +bool fromJSON(const json::Value &V, CallToolParams &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("name", R.name) && mapRaw(V, "arguments", R.arguments, P); +} + +json::Value toJSON(const ReadResourceParams &R) { + return json::Object{{"uri", R.uri}}; +} + +bool fromJSON(const json::Value &V, ReadResourceParams &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("uri", R.uri); +} + +json::Value toJSON(const ListResourcesResult &R) { + return json::Object{{"resources", R.resources}}; +} + +bool fromJSON(const json::Value &V, ListResourcesResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("resources", R.resources); +} + +json::Value toJSON(const Void &R) { return json::Object{}; } + +bool fromJSON(const json::Value &V, Void &R, json::Path P) { return true; } + } // namespace lldb_protocol::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index a9c1482e3e378..63c2d01d17922 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -8,12 +8,17 @@ #include "lldb/Protocol/MCP/Server.h" #include "lldb/Protocol/MCP/MCPError.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/JSON.h" using namespace lldb_protocol::mcp; using namespace llvm; -Server::Server(std::string name, std::string version) - : m_name(std::move(name)), m_version(std::move(version)) { +Server::Server(std::string name, std::string version, + std::unique_ptr transport_up, + lldb_private::MainLoop &loop) + : m_name(std::move(name)), m_version(std::move(version)), + m_transport_up(std::move(transport_up)), m_loop(loop) { AddRequestHandlers(); } @@ -30,7 +35,7 @@ void Server::AddRequestHandlers() { this, std::placeholders::_1)); } -llvm::Expected Server::Handle(Request request) { +llvm::Expected Server::Handle(const Request &request) { auto it = m_request_handlers.find(request.method); if (it != m_request_handlers.end()) { llvm::Expected response = it->second(request); @@ -44,7 +49,7 @@ llvm::Expected Server::Handle(Request request) { llvm::formatv("no handler for request: {0}", request.method).str()); } -void Server::Handle(Notification notification) { +void Server::Handle(const Notification ¬ification) { auto it = m_notification_handlers.find(notification.method); if (it != m_notification_handlers.end()) { it->second(notification); @@ -52,49 +57,7 @@ void Server::Handle(Notification notification) { } } -llvm::Expected> -Server::HandleData(llvm::StringRef data) { - auto message = llvm::json::parse(/*JSON=*/data); - if (!message) - return message.takeError(); - - if (const Request *request = std::get_if(&(*message))) { - llvm::Expected response = Handle(*request); - - // Handle failures by converting them into an Error message. - if (!response) { - Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request->id; - error_response.result = std::move(protocol_error); - return error_response; - } - - return *response; - } - - if (const Notification *notification = - std::get_if(&(*message))) { - Handle(*notification); - return std::nullopt; - } - - if (std::get_if(&(*message))) - return llvm::createStringError("unexpected MCP message: response"); - - llvm_unreachable("all message types handled"); -} - void Server::AddTool(std::unique_ptr tool) { - std::lock_guard guard(m_mutex); - if (!tool) return; m_tools[tool->GetName()] = std::move(tool); @@ -102,42 +65,39 @@ void Server::AddTool(std::unique_ptr tool) { void Server::AddResourceProvider( std::unique_ptr resource_provider) { - std::lock_guard guard(m_mutex); - if (!resource_provider) return; m_resource_providers.push_back(std::move(resource_provider)); } void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - std::lock_guard guard(m_mutex); m_request_handlers[method] = std::move(handler); } void Server::AddNotificationHandler(llvm::StringRef method, NotificationHandler handler) { - std::lock_guard guard(m_mutex); m_notification_handlers[method] = std::move(handler); } llvm::Expected Server::InitializeHandler(const Request &request) { Response response; - response.result = llvm::json::Object{ - {"protocolVersion", mcp::kProtocolVersion}, - {"capabilities", GetCapabilities()}, - {"serverInfo", - llvm::json::Object{{"name", m_name}, {"version", m_version}}}}; + InitializeResult result; + result.protocolVersion = mcp::kProtocolVersion; + result.capabilities = GetCapabilities(); + result.serverInfo.name = m_name; + result.serverInfo.version = m_version; + response.result = std::move(result); return response; } llvm::Expected Server::ToolsListHandler(const Request &request) { Response response; - llvm::json::Array tools; + ListToolsResult result; for (const auto &tool : m_tools) - tools.emplace_back(toJSON(tool.second->GetDefinition())); + result.tools.emplace_back(tool.second->GetDefinition()); - response.result = llvm::json::Object{{"tools", std::move(tools)}}; + response.result = std::move(result); return response; } @@ -147,16 +107,12 @@ llvm::Expected Server::ToolsCallHandler(const Request &request) { if (!request.params) return llvm::createStringError("no tool parameters"); + CallToolParams params; + json::Path::Root root("params"); + if (!fromJSON(request.params, params, root)) + return root.getError(); - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no tool parameters"); - - const json::Value *name = param_obj->get("name"); - if (!name) - return llvm::createStringError("no tool name"); - - llvm::StringRef tool_name = name->getAsString().value_or(""); + llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -165,10 +121,10 @@ llvm::Expected Server::ToolsCallHandler(const Request &request) { return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); ToolArguments tool_args; - if (const json::Value *args = param_obj->get("arguments")) - tool_args = *args; + if (params.arguments) + tool_args = *params.arguments; - llvm::Expected text_result = it->second->Call(tool_args); + llvm::Expected text_result = it->second->Call(tool_args); if (!text_result) return text_result.takeError(); @@ -180,15 +136,13 @@ llvm::Expected Server::ToolsCallHandler(const Request &request) { llvm::Expected Server::ResourcesListHandler(const Request &request) { Response response; - llvm::json::Array resources; - - std::lock_guard guard(m_mutex); + ListResourcesResult result; for (std::unique_ptr &resource_provider_up : - m_resource_providers) { + m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) - resources.push_back(resource); - } - response.result = llvm::json::Object{{"resources", std::move(resources)}}; + result.resources.push_back(resource); + + response.result = std::move(result); return response; } @@ -199,22 +153,18 @@ llvm::Expected Server::ResourcesReadHandler(const Request &request) { if (!request.params) return llvm::createStringError("no resource parameters"); - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no resource parameters"); + ReadResourceParams params; + json::Path::Root root("params"); + if (!fromJSON(request.params, params, root)) + return root.getError(); - const json::Value *uri = param_obj->get("uri"); - if (!uri) - return llvm::createStringError("no resource uri"); - - llvm::StringRef uri_str = uri->getAsString().value_or(""); + llvm::StringRef uri_str = params.uri; if (uri_str.empty()) return llvm::createStringError("no resource uri"); - std::lock_guard guard(m_mutex); for (std::unique_ptr &resource_provider_up : m_resource_providers) { - llvm::Expected result = + llvm::Expected result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA()) { llvm::consumeError(result.takeError()); @@ -232,3 +182,71 @@ llvm::Expected Server::ResourcesReadHandler(const Request &request) { llvm::formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } + +ServerCapabilities Server::GetCapabilities() { + lldb_protocol::mcp::ServerCapabilities capabilities; + capabilities.supportsToolsList = true; + // FIXME: Support sending notifications when a debugger/target are + // added/removed. + capabilities.supportsResourcesList = false; + return capabilities; +} + +llvm::Error Server::Run() { + auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); + if (!handle) + return handle.takeError(); + + lldb_private::Status status = m_loop.Run(); + if (status.Fail()) + return status.takeError(); + + return llvm::Error::success(); +} + +void Server::Received(const Request &request) { + auto SendResponse = [this](const Response &response) { + if (llvm::Error error = m_transport_up->Send(response)) + m_transport_up->Log(llvm::toString(std::move(error))); + }; + + llvm::Expected response = Handle(request); + if (response) + return SendResponse(*response); + + lldb_protocol::mcp::Error protocol_error; + llvm::handleAllErrors( + response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); + }); + Response error_response; + error_response.id = request.id; + error_response.result = std::move(protocol_error); + SendResponse(error_response); +} + +void Server::Received(const Response &response) { + m_transport_up->Log("unexpected MCP message: response"); +} + +void Server::Received(const Notification ¬ification) { + Handle(notification); +} + +void Server::OnError(llvm::Error error) { + m_transport_up->Log(llvm::toString(std::move(error))); + TerminateLoop(); +} + +void Server::OnClosed() { + m_transport_up->Log("EOF"); + TerminateLoop(); +} + +void Server::TerminateLoop() { + m_loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); +} diff --git a/lldb/source/Symbol/Symbol.cpp b/lldb/source/Symbol/Symbol.cpp index ff6f34f51c325..679edc381b1bf 100644 --- a/lldb/source/Symbol/Symbol.cpp +++ b/lldb/source/Symbol/Symbol.cpp @@ -392,45 +392,8 @@ bool Symbol::Compare(ConstString name, SymbolType type) const { return false; } -#define ENUM_TO_CSTRING(x) \ - case eSymbolType##x: \ - return #x; - const char *Symbol::GetTypeAsString() const { - switch (m_type) { - ENUM_TO_CSTRING(Invalid); - ENUM_TO_CSTRING(Absolute); - ENUM_TO_CSTRING(Code); - ENUM_TO_CSTRING(Resolver); - ENUM_TO_CSTRING(Data); - ENUM_TO_CSTRING(Trampoline); - ENUM_TO_CSTRING(Runtime); - ENUM_TO_CSTRING(Exception); - ENUM_TO_CSTRING(SourceFile); - ENUM_TO_CSTRING(HeaderFile); - ENUM_TO_CSTRING(ObjectFile); - ENUM_TO_CSTRING(CommonBlock); - ENUM_TO_CSTRING(Block); - ENUM_TO_CSTRING(Local); - ENUM_TO_CSTRING(Param); - ENUM_TO_CSTRING(Variable); - ENUM_TO_CSTRING(VariableType); - ENUM_TO_CSTRING(LineEntry); - ENUM_TO_CSTRING(LineHeader); - ENUM_TO_CSTRING(ScopeBegin); - ENUM_TO_CSTRING(ScopeEnd); - ENUM_TO_CSTRING(Additional); - ENUM_TO_CSTRING(Compiler); - ENUM_TO_CSTRING(Instrumentation); - ENUM_TO_CSTRING(Undefined); - ENUM_TO_CSTRING(ObjCClass); - ENUM_TO_CSTRING(ObjCMetaClass); - ENUM_TO_CSTRING(ObjCIVar); - ENUM_TO_CSTRING(ReExported); - default: - break; - } - return ""; + return GetTypeAsString(static_cast(m_type)); } void Symbol::CalculateSymbolContext(SymbolContext *sc) { @@ -774,6 +737,79 @@ bool Symbol::operator==(const Symbol &rhs) const { return true; } +#define ENUM_TO_CSTRING(x) \ + case eSymbolType##x: \ + return #x; + +const char *Symbol::GetTypeAsString(lldb::SymbolType symbol_type) { + switch (symbol_type) { + ENUM_TO_CSTRING(Invalid); + ENUM_TO_CSTRING(Absolute); + ENUM_TO_CSTRING(Code); + ENUM_TO_CSTRING(Resolver); + ENUM_TO_CSTRING(Data); + ENUM_TO_CSTRING(Trampoline); + ENUM_TO_CSTRING(Runtime); + ENUM_TO_CSTRING(Exception); + ENUM_TO_CSTRING(SourceFile); + ENUM_TO_CSTRING(HeaderFile); + ENUM_TO_CSTRING(ObjectFile); + ENUM_TO_CSTRING(CommonBlock); + ENUM_TO_CSTRING(Block); + ENUM_TO_CSTRING(Local); + ENUM_TO_CSTRING(Param); + ENUM_TO_CSTRING(Variable); + ENUM_TO_CSTRING(VariableType); + ENUM_TO_CSTRING(LineEntry); + ENUM_TO_CSTRING(LineHeader); + ENUM_TO_CSTRING(ScopeBegin); + ENUM_TO_CSTRING(ScopeEnd); + ENUM_TO_CSTRING(Additional); + ENUM_TO_CSTRING(Compiler); + ENUM_TO_CSTRING(Instrumentation); + ENUM_TO_CSTRING(Undefined); + ENUM_TO_CSTRING(ObjCClass); + ENUM_TO_CSTRING(ObjCMetaClass); + ENUM_TO_CSTRING(ObjCIVar); + ENUM_TO_CSTRING(ReExported); + } + return ""; +} + +lldb::SymbolType Symbol::GetTypeFromString(const char *str) { + std::string str_lower = llvm::StringRef(str).lower(); + return llvm::StringSwitch(str_lower) + .Case("absolute", eSymbolTypeAbsolute) + .Case("code", eSymbolTypeCode) + .Case("resolver", eSymbolTypeResolver) + .Case("data", eSymbolTypeData) + .Case("trampoline", eSymbolTypeTrampoline) + .Case("runtime", eSymbolTypeRuntime) + .Case("exception", eSymbolTypeException) + .Case("sourcefile", eSymbolTypeSourceFile) + .Case("headerfile", eSymbolTypeHeaderFile) + .Case("objectfile", eSymbolTypeObjectFile) + .Case("commonblock", eSymbolTypeCommonBlock) + .Case("block", eSymbolTypeBlock) + .Case("local", eSymbolTypeLocal) + .Case("param", eSymbolTypeParam) + .Case("variable", eSymbolTypeVariable) + .Case("variableType", eSymbolTypeVariableType) + .Case("lineentry", eSymbolTypeLineEntry) + .Case("lineheader", eSymbolTypeLineHeader) + .Case("scopebegin", eSymbolTypeScopeBegin) + .Case("scopeend", eSymbolTypeScopeEnd) + .Case("additional,", eSymbolTypeAdditional) + .Case("compiler", eSymbolTypeCompiler) + .Case("instrumentation", eSymbolTypeInstrumentation) + .Case("undefined", eSymbolTypeUndefined) + .Case("objcclass", eSymbolTypeObjCClass) + .Case("objcmetaclass", eSymbolTypeObjCMetaClass) + .Case("objcivar", eSymbolTypeObjCIVar) + .Case("reexported", eSymbolTypeReExported) + .Default(eSymbolTypeInvalid); +} + namespace llvm { namespace json { @@ -804,36 +840,8 @@ bool fromJSON(const llvm::json::Value &value, lldb_private::JSONSymbol &symbol, bool fromJSON(const llvm::json::Value &value, lldb::SymbolType &type, llvm::json::Path path) { if (auto str = value.getAsString()) { - type = llvm::StringSwitch(*str) - .Case("absolute", eSymbolTypeAbsolute) - .Case("code", eSymbolTypeCode) - .Case("resolver", eSymbolTypeResolver) - .Case("data", eSymbolTypeData) - .Case("trampoline", eSymbolTypeTrampoline) - .Case("runtime", eSymbolTypeRuntime) - .Case("exception", eSymbolTypeException) - .Case("sourcefile", eSymbolTypeSourceFile) - .Case("headerfile", eSymbolTypeHeaderFile) - .Case("objectfile", eSymbolTypeObjectFile) - .Case("commonblock", eSymbolTypeCommonBlock) - .Case("block", eSymbolTypeBlock) - .Case("local", eSymbolTypeLocal) - .Case("param", eSymbolTypeParam) - .Case("variable", eSymbolTypeVariable) - .Case("variableType", eSymbolTypeVariableType) - .Case("lineentry", eSymbolTypeLineEntry) - .Case("lineheader", eSymbolTypeLineHeader) - .Case("scopebegin", eSymbolTypeScopeBegin) - .Case("scopeend", eSymbolTypeScopeEnd) - .Case("additional,", eSymbolTypeAdditional) - .Case("compiler", eSymbolTypeCompiler) - .Case("instrumentation", eSymbolTypeInstrumentation) - .Case("undefined", eSymbolTypeUndefined) - .Case("objcclass", eSymbolTypeObjCClass) - .Case("objcmetaClass", eSymbolTypeObjCMetaClass) - .Case("objcivar", eSymbolTypeObjCIVar) - .Case("reexporte", eSymbolTypeReExported) - .Default(eSymbolTypeInvalid); + llvm::StringRef str_ref = str.value_or(""); + type = Symbol::GetTypeFromString(str_ref.data()); if (type == eSymbolTypeInvalid) { path.report("invalid symbol type"); diff --git a/lldb/test/API/tools/lldb-dap/attach/TestDAP_attach.py b/lldb/test/API/tools/lldb-dap/attach/TestDAP_attach.py index 55557e6e0030e..c54e21c1b973a 100644 --- a/lldb/test/API/tools/lldb-dap/attach/TestDAP_attach.py +++ b/lldb/test/API/tools/lldb-dap/attach/TestDAP_attach.py @@ -153,7 +153,7 @@ def test_commands(self): breakpoint_ids = self.set_function_breakpoints(functions) self.assertEqual(len(breakpoint_ids), len(functions), "expect one breakpoint") self.continue_to_breakpoints(breakpoint_ids) - output = self.collect_console(timeout_secs=10, pattern=stopCommands[-1]) + output = self.collect_console(timeout=10, pattern=stopCommands[-1]) self.verify_commands("stopCommands", output, stopCommands) # Continue after launch and hit the "pause()" call and stop the target. @@ -163,7 +163,7 @@ def test_commands(self): time.sleep(0.5) self.dap_server.request_pause() self.dap_server.wait_for_stopped() - output = self.collect_console(timeout_secs=10, pattern=stopCommands[-1]) + output = self.collect_console(timeout=10, pattern=stopCommands[-1]) self.verify_commands("stopCommands", output, stopCommands) # Continue until the program exits @@ -172,7 +172,7 @@ def test_commands(self): # "exitCommands" that were run after the second breakpoint was hit # and the "terminateCommands" due to the debugging session ending output = self.collect_console( - timeout_secs=10.0, + timeout=10.0, pattern=terminateCommands[0], ) self.verify_commands("exitCommands", output, exitCommands) @@ -223,7 +223,7 @@ def test_terminate_commands(self): # "terminateCommands" self.dap_server.request_disconnect(terminateDebuggee=True) output = self.collect_console( - timeout_secs=1.0, + timeout=1.0, pattern=terminateCommands[0], ) self.verify_commands("terminateCommands", output, terminateCommands) diff --git a/lldb/test/API/tools/lldb-dap/breakpoint-assembly/TestDAP_breakpointAssembly.py b/lldb/test/API/tools/lldb-dap/breakpoint-assembly/TestDAP_breakpointAssembly.py index 7552a77d2280a..fab109c93a17b 100644 --- a/lldb/test/API/tools/lldb-dap/breakpoint-assembly/TestDAP_breakpointAssembly.py +++ b/lldb/test/API/tools/lldb-dap/breakpoint-assembly/TestDAP_breakpointAssembly.py @@ -2,7 +2,6 @@ Test lldb-dap setBreakpoints request in assembly source references. """ - from lldbsuite.test.decorators import * from dap_server import Source import lldbdap_testcase @@ -52,7 +51,7 @@ def test_break_on_invalid_source_reference(self): # Verify that setting a breakpoint on an invalid source reference fails response = self.dap_server.request_setBreakpoints( - Source(source_reference=-1), [1] + Source.build(source_reference=-1), [1] ) self.assertIsNotNone(response) breakpoints = response["body"]["breakpoints"] @@ -69,7 +68,7 @@ def test_break_on_invalid_source_reference(self): # Verify that setting a breakpoint on a source reference that is not created fails response = self.dap_server.request_setBreakpoints( - Source(source_reference=200), [1] + Source.build(source_reference=200), [1] ) self.assertIsNotNone(response) breakpoints = response["body"]["breakpoints"] @@ -116,7 +115,7 @@ def test_persistent_assembly_breakpoint(self): persistent_breakpoint_source = self.dap_server.resolved_breakpoints[ persistent_breakpoint_ids[0] - ].source() + ]["source"] self.assertIn( "adapterData", persistent_breakpoint_source, @@ -139,7 +138,7 @@ def test_persistent_assembly_breakpoint(self): self.dap_server.request_initialize() self.dap_server.request_launch(program) new_session_breakpoints_ids = self.set_source_breakpoints_from_source( - Source(raw_dict=persistent_breakpoint_source), + Source(persistent_breakpoint_source), [persistent_breakpoint_line], ) diff --git a/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py b/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py index 8f03244bc6572..151ad761a5044 100644 --- a/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py +++ b/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py @@ -58,7 +58,7 @@ def test_breakpoint_events(self): # Set breakpoints and verify that they got set correctly dap_breakpoint_ids = [] response = self.dap_server.request_setBreakpoints( - Source(main_source_path), [main_bp_line] + Source.build(path=main_source_path), [main_bp_line] ) self.assertTrue(response["success"]) breakpoints = response["body"]["breakpoints"] @@ -70,7 +70,7 @@ def test_breakpoint_events(self): ) response = self.dap_server.request_setBreakpoints( - Source(foo_source_path), [foo_bp1_line] + Source.build(path=foo_source_path), [foo_bp1_line] ) self.assertTrue(response["success"]) breakpoints = response["body"]["breakpoints"] diff --git a/lldb/test/API/tools/lldb-dap/breakpoint/TestDAP_setBreakpoints.py b/lldb/test/API/tools/lldb-dap/breakpoint/TestDAP_setBreakpoints.py index 2e860ff5d5e17..3309800c1dd10 100644 --- a/lldb/test/API/tools/lldb-dap/breakpoint/TestDAP_setBreakpoints.py +++ b/lldb/test/API/tools/lldb-dap/breakpoint/TestDAP_setBreakpoints.py @@ -2,7 +2,6 @@ Test lldb-dap setBreakpoints request """ - from dap_server import Source import shutil from lldbsuite.test.decorators import * @@ -58,7 +57,7 @@ def test_source_map(self): # breakpoint in main.cpp response = self.dap_server.request_setBreakpoints( - Source(new_main_path), [main_line] + Source.build(path=new_main_path), [main_line] ) breakpoints = response["body"]["breakpoints"] self.assertEqual(len(breakpoints), 1) @@ -70,7 +69,7 @@ def test_source_map(self): # 2nd breakpoint, which is from a dynamically loaded library response = self.dap_server.request_setBreakpoints( - Source(new_other_path), [other_line] + Source.build(path=new_other_path), [other_line] ) breakpoints = response["body"]["breakpoints"] breakpoint = breakpoints[0] @@ -85,7 +84,7 @@ def test_source_map(self): # 2nd breakpoint again, which should be valid at this point response = self.dap_server.request_setBreakpoints( - Source(new_other_path), [other_line] + Source.build(path=new_other_path), [other_line] ) breakpoints = response["body"]["breakpoints"] breakpoint = breakpoints[0] @@ -129,7 +128,9 @@ def test_set_and_clear(self): self.build_and_launch(program) # Set 3 breakpoints and verify that they got set correctly - response = self.dap_server.request_setBreakpoints(Source(self.main_path), lines) + response = self.dap_server.request_setBreakpoints( + Source.build(path=self.main_path), lines + ) line_to_id = {} breakpoints = response["body"]["breakpoints"] self.assertEqual( @@ -154,7 +155,9 @@ def test_set_and_clear(self): lines.remove(second_line) # Set 2 breakpoints and verify that the previous breakpoints that were # set above are still set. - response = self.dap_server.request_setBreakpoints(Source(self.main_path), lines) + response = self.dap_server.request_setBreakpoints( + Source.build(path=self.main_path), lines + ) breakpoints = response["body"]["breakpoints"] self.assertEqual( len(breakpoints), @@ -199,7 +202,9 @@ def test_set_and_clear(self): # Now clear all breakpoints for the source file by passing down an # empty lines array lines = [] - response = self.dap_server.request_setBreakpoints(Source(self.main_path), lines) + response = self.dap_server.request_setBreakpoints( + Source.build(path=self.main_path), lines + ) breakpoints = response["body"]["breakpoints"] self.assertEqual( len(breakpoints), @@ -219,7 +224,9 @@ def test_set_and_clear(self): # Now set a breakpoint again in the same source file and verify it # was added. lines = [second_line] - response = self.dap_server.request_setBreakpoints(Source(self.main_path), lines) + response = self.dap_server.request_setBreakpoints( + Source.build(path=self.main_path), lines + ) if response: breakpoints = response["body"]["breakpoints"] self.assertEqual( @@ -270,7 +277,9 @@ def test_clear_breakpoints_unset_breakpoints(self): self.build_and_launch(program) # Set one breakpoint and verify that it got set correctly. - response = self.dap_server.request_setBreakpoints(Source(self.main_path), lines) + response = self.dap_server.request_setBreakpoints( + Source.build(path=self.main_path), lines + ) line_to_id = {} breakpoints = response["body"]["breakpoints"] self.assertEqual( @@ -286,7 +295,9 @@ def test_clear_breakpoints_unset_breakpoints(self): # Now clear all breakpoints for the source file by not setting the # lines array. lines = None - response = self.dap_server.request_setBreakpoints(Source(self.main_path), lines) + response = self.dap_server.request_setBreakpoints( + Source.build(path=self.main_path), lines + ) breakpoints = response["body"]["breakpoints"] self.assertEqual(len(breakpoints), 0, "expect no source breakpoints") @@ -362,7 +373,7 @@ def test_column_breakpoints(self): # Set two breakpoints on the loop line at different columns. columns = [13, 39] response = self.dap_server.request_setBreakpoints( - Source(self.main_path), + Source.build(path=self.main_path), [loop_line, loop_line], list({"column": c} for c in columns), ) diff --git a/lldb/test/API/tools/lldb-dap/cancel/TestDAP_cancel.py b/lldb/test/API/tools/lldb-dap/cancel/TestDAP_cancel.py index 824ed8fe3bb97..e722fcea9283a 100644 --- a/lldb/test/API/tools/lldb-dap/cancel/TestDAP_cancel.py +++ b/lldb/test/API/tools/lldb-dap/cancel/TestDAP_cancel.py @@ -10,16 +10,14 @@ class TestDAP_cancel(lldbdap_testcase.DAPTestCaseBase): - def send_async_req(self, command: str, arguments={}) -> int: - seq = self.dap_server.sequence - self.dap_server.send_packet( + def send_async_req(self, command: str, arguments: dict = {}) -> int: + return self.dap_server.send_packet( { "type": "request", "command": command, "arguments": arguments, } ) - return seq def async_blocking_request(self, duration: float) -> int: """ @@ -54,18 +52,18 @@ def test_pending_request(self): pending_seq = self.async_blocking_request(duration=self.DEFAULT_TIMEOUT / 2) cancel_seq = self.async_cancel(requestId=pending_seq) - blocking_resp = self.dap_server.recv_packet(filter_type=["response"]) + blocking_resp = self.dap_server.receive_response(blocking_seq) self.assertEqual(blocking_resp["request_seq"], blocking_seq) self.assertEqual(blocking_resp["command"], "evaluate") self.assertEqual(blocking_resp["success"], True) - pending_resp = self.dap_server.recv_packet(filter_type=["response"]) + pending_resp = self.dap_server.receive_response(pending_seq) self.assertEqual(pending_resp["request_seq"], pending_seq) self.assertEqual(pending_resp["command"], "evaluate") self.assertEqual(pending_resp["success"], False) self.assertEqual(pending_resp["message"], "cancelled") - cancel_resp = self.dap_server.recv_packet(filter_type=["response"]) + cancel_resp = self.dap_server.receive_response(cancel_seq) self.assertEqual(cancel_resp["request_seq"], cancel_seq) self.assertEqual(cancel_resp["command"], "cancel") self.assertEqual(cancel_resp["success"], True) @@ -80,19 +78,16 @@ def test_inflight_request(self): blocking_seq = self.async_blocking_request(duration=self.DEFAULT_TIMEOUT / 2) # Wait for the sleep to start to cancel the inflight request. - self.collect_console( - timeout_secs=self.DEFAULT_TIMEOUT, - pattern="starting sleep", - ) + self.collect_console(pattern="starting sleep") cancel_seq = self.async_cancel(requestId=blocking_seq) - blocking_resp = self.dap_server.recv_packet(filter_type=["response"]) + blocking_resp = self.dap_server.receive_response(blocking_seq) self.assertEqual(blocking_resp["request_seq"], blocking_seq) self.assertEqual(blocking_resp["command"], "evaluate") self.assertEqual(blocking_resp["success"], False) self.assertEqual(blocking_resp["message"], "cancelled") - cancel_resp = self.dap_server.recv_packet(filter_type=["response"]) + cancel_resp = self.dap_server.receive_response(cancel_seq) self.assertEqual(cancel_resp["request_seq"], cancel_seq) self.assertEqual(cancel_resp["command"], "cancel") self.assertEqual(cancel_resp["success"], True) diff --git a/lldb/test/API/tools/lldb-dap/commands/TestDAP_commands.py b/lldb/test/API/tools/lldb-dap/commands/TestDAP_commands.py index ea6b2ea7f28ab..e61d2480ea4bb 100644 --- a/lldb/test/API/tools/lldb-dap/commands/TestDAP_commands.py +++ b/lldb/test/API/tools/lldb-dap/commands/TestDAP_commands.py @@ -1,8 +1,8 @@ -import os +""" +Test lldb-dap command hooks +""" -import dap_server import lldbdap_testcase -from lldbsuite.test import lldbtest, lldbutil from lldbsuite.test.decorators import * @@ -23,7 +23,7 @@ def test_command_directive_quiet_on_success(self): exitCommands=["?" + command_quiet, command_not_quiet], ) full_output = self.collect_console( - timeout_secs=1.0, + timeout=1.0, pattern=command_not_quiet, ) self.assertNotIn(command_quiet, full_output) @@ -51,7 +51,7 @@ def do_test_abort_on_error( expectFailure=True, ) full_output = self.collect_console( - timeout_secs=1.0, + timeout=1.0, pattern=command_abort_on_error, ) self.assertNotIn(command_quiet, full_output) @@ -81,9 +81,6 @@ def test_command_directive_abort_on_error_attach_commands(self): expectFailure=True, ) self.assertFalse(resp["success"], "expected 'attach' failure") - full_output = self.collect_console( - timeout_secs=1.0, - pattern=command_abort_on_error, - ) + full_output = self.collect_console(pattern=command_abort_on_error) self.assertNotIn(command_quiet, full_output) self.assertIn(command_abort_on_error, full_output) diff --git a/lldb/test/API/tools/lldb-dap/console/TestDAP_console.py b/lldb/test/API/tools/lldb-dap/console/TestDAP_console.py index 811843dfdf7af..ceddaeb50cd3b 100644 --- a/lldb/test/API/tools/lldb-dap/console/TestDAP_console.py +++ b/lldb/test/API/tools/lldb-dap/console/TestDAP_console.py @@ -139,9 +139,7 @@ def test_exit_status_message_sigterm(self): process.wait() # Get the console output - console_output = self.collect_console( - timeout_secs=10.0, pattern="exited with status" - ) + console_output = self.collect_console(pattern="exited with status") # Verify the exit status message is printed. self.assertRegex( @@ -156,9 +154,7 @@ def test_exit_status_message_ok(self): self.continue_to_exit() # Get the console output - console_output = self.collect_console( - timeout_secs=10.0, pattern="exited with status" - ) + console_output = self.collect_console(pattern="exited with status") # Verify the exit status message is printed. self.assertIn( @@ -177,9 +173,7 @@ def test_diagnositcs(self): f"target create --core {core}", context="repl" ) - diagnostics = self.collect_important( - timeout_secs=self.DEFAULT_TIMEOUT, pattern="minidump file" - ) + diagnostics = self.collect_important(pattern="minidump file") self.assertIn( "warning: unable to retrieve process ID from minidump file", diff --git a/lldb/test/API/tools/lldb-dap/instruction-breakpoint/TestDAP_instruction_breakpoint.py b/lldb/test/API/tools/lldb-dap/instruction-breakpoint/TestDAP_instruction_breakpoint.py index b8b266beaf182..8bb9ea2be5a9f 100644 --- a/lldb/test/API/tools/lldb-dap/instruction-breakpoint/TestDAP_instruction_breakpoint.py +++ b/lldb/test/API/tools/lldb-dap/instruction-breakpoint/TestDAP_instruction_breakpoint.py @@ -34,7 +34,7 @@ def instruction_breakpoint_test(self): # Set source breakpoint 1 response = self.dap_server.request_setBreakpoints( - Source(self.main_path), [main_line] + Source.build(path=self.main_path), [main_line] ) breakpoints = response["body"]["breakpoints"] self.assertEqual(len(breakpoints), 1) diff --git a/lldb/test/API/tools/lldb-dap/io/TestDAP_io.py b/lldb/test/API/tools/lldb-dap/io/TestDAP_io.py index b72b98de412b4..af5c62a8c4eb5 100644 --- a/lldb/test/API/tools/lldb-dap/io/TestDAP_io.py +++ b/lldb/test/API/tools/lldb-dap/io/TestDAP_io.py @@ -8,6 +8,9 @@ import lldbdap_testcase import dap_server +EXIT_FAILURE = 1 +EXIT_SUCCESS = 0 + class TestDAP_io(lldbdap_testcase.DAPTestCaseBase): def launch(self): @@ -41,40 +44,44 @@ def test_eof_immediately(self): """ process = self.launch() process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 0) + self.assertEqual(process.wait(timeout=5.0), EXIT_SUCCESS) def test_invalid_header(self): """ - lldb-dap handles invalid message headers. + lldb-dap returns a failure exit code when the input stream is closed + with a malformed request header. """ process = self.launch() - process.stdin.write(b"not the corret message header") + process.stdin.write(b"not the correct message header") process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 1) + self.assertEqual(process.wait(timeout=5.0), EXIT_FAILURE) def test_partial_header(self): """ - lldb-dap handles parital message headers. + lldb-dap returns a failure exit code when the input stream is closed + with an incomplete message header is in the message buffer. """ process = self.launch() process.stdin.write(b"Content-Length: ") process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 1) + self.assertEqual(process.wait(timeout=5.0), EXIT_FAILURE) def test_incorrect_content_length(self): """ - lldb-dap handles malformed content length headers. + lldb-dap returns a failure exit code when reading malformed content + length headers. """ process = self.launch() process.stdin.write(b"Content-Length: abc") process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 1) + self.assertEqual(process.wait(timeout=5.0), EXIT_FAILURE) def test_partial_content_length(self): """ - lldb-dap handles partial messages. + lldb-dap returns a failure exit code when the input stream is closed + with a partial message in the message buffer. """ process = self.launch() process.stdin.write(b"Content-Length: 10\r\n\r\n{") process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 1) + self.assertEqual(process.wait(timeout=5.0), EXIT_FAILURE) diff --git a/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py b/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py index 2c0c5a583c58a..8bfceec1a636b 100644 --- a/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py +++ b/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py @@ -2,12 +2,9 @@ Test lldb-dap setBreakpoints request """ -import dap_server from lldbsuite.test.decorators import * from lldbsuite.test.lldbtest import * -from lldbsuite.test import lldbutil import lldbdap_testcase -import time import os import re @@ -208,7 +205,7 @@ def test_disableSTDIO(self): self.continue_to_exit() # Now get the STDOUT and verify our program argument is correct output = self.get_stdout() - self.assertEqual(output, None, "expect no program output") + self.assertEqual(output, "", "expect no program output") @skipIfWindows @skipIfLinux # shell argument expansion doesn't seem to work on Linux @@ -409,14 +406,14 @@ def test_commands(self): # Get output from the console. This should contain both the # "stopCommands" that were run after the first breakpoint was hit self.continue_to_breakpoints(breakpoint_ids) - output = self.get_console(timeout=self.DEFAULT_TIMEOUT) + output = self.get_console() self.verify_commands("stopCommands", output, stopCommands) # Continue again and hit the second breakpoint. # Get output from the console. This should contain both the # "stopCommands" that were run after the second breakpoint was hit self.continue_to_breakpoints(breakpoint_ids) - output = self.get_console(timeout=self.DEFAULT_TIMEOUT) + output = self.get_console() self.verify_commands("stopCommands", output, stopCommands) # Continue until the program exits @@ -424,10 +421,7 @@ def test_commands(self): # Get output from the console. This should contain both the # "exitCommands" that were run after the second breakpoint was hit # and the "terminateCommands" due to the debugging session ending - output = self.collect_console( - timeout_secs=1.0, - pattern=terminateCommands[0], - ) + output = self.collect_console(pattern=terminateCommands[0]) self.verify_commands("exitCommands", output, exitCommands) self.verify_commands("terminateCommands", output, terminateCommands) @@ -480,21 +474,21 @@ def test_extra_launch_commands(self): self.verify_commands("launchCommands", output, launchCommands) # Verify the "stopCommands" here self.continue_to_next_stop() - output = self.get_console(timeout=self.DEFAULT_TIMEOUT) + output = self.get_console() self.verify_commands("stopCommands", output, stopCommands) # Continue and hit the second breakpoint. # Get output from the console. This should contain both the # "stopCommands" that were run after the first breakpoint was hit self.continue_to_next_stop() - output = self.get_console(timeout=self.DEFAULT_TIMEOUT) + output = self.get_console() self.verify_commands("stopCommands", output, stopCommands) # Continue until the program exits self.continue_to_exit() # Get output from the console. This should contain both the # "exitCommands" that were run after the second breakpoint was hit - output = self.get_console(timeout=self.DEFAULT_TIMEOUT) + output = self.get_console() self.verify_commands("exitCommands", output, exitCommands) def test_failing_launch_commands(self): @@ -558,10 +552,7 @@ def test_terminate_commands(self): # Once it's disconnected the console should contain the # "terminateCommands" self.dap_server.request_disconnect(terminateDebuggee=True) - output = self.collect_console( - timeout_secs=1.0, - pattern=terminateCommands[0], - ) + output = self.collect_console(pattern=terminateCommands[0]) self.verify_commands("terminateCommands", output, terminateCommands) @skipIfWindows diff --git a/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py b/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py index 64ed4154b035d..bb835af12f5ef 100644 --- a/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py +++ b/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py @@ -23,15 +23,15 @@ def test_module_event(self): self.continue_to_breakpoints(breakpoint_ids) # We're now stopped at breakpoint 1 before the dlopen. Flush all the module events. - event = self.dap_server.wait_for_event("module", 0.25) + event = self.dap_server.wait_for_event(["module"], 0.25) while event is not None: - event = self.dap_server.wait_for_event("module", 0.25) + event = self.dap_server.wait_for_event(["module"], 0.25) # Continue to the second breakpoint, before the dlclose. self.continue_to_breakpoints(breakpoint_ids) # Make sure we got a module event for libother. - event = self.dap_server.wait_for_event("module", 5) + event = self.dap_server.wait_for_event(["module"], 5) self.assertIsNotNone(event, "didn't get a module event") module_name = event["body"]["module"]["name"] module_id = event["body"]["module"]["id"] @@ -42,7 +42,7 @@ def test_module_event(self): self.continue_to_breakpoints(breakpoint_ids) # Make sure we got a module event for libother. - event = self.dap_server.wait_for_event("module", 5) + event = self.dap_server.wait_for_event(["module"], 5) self.assertIsNotNone(event, "didn't get a module event") reason = event["body"]["reason"] self.assertEqual(reason, "removed") @@ -56,7 +56,7 @@ def test_module_event(self): self.assertEqual(module_data["name"], "", "expects empty name.") # Make sure we do not send another event - event = self.dap_server.wait_for_event("module", 3) + event = self.dap_server.wait_for_event(["module"], 3) self.assertIsNone(event, "expects no events.") self.continue_to_exit() diff --git a/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py b/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py index c9091df64f487..74743d9182ab4 100644 --- a/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py +++ b/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py @@ -1,11 +1,9 @@ """ -Test lldb-dap setBreakpoints request +Test lldb-dap module request """ -import dap_server from lldbsuite.test.decorators import * from lldbsuite.test.lldbtest import * -from lldbsuite.test import lldbutil import lldbdap_testcase import re @@ -55,7 +53,7 @@ def check_symbols_loaded_with_size(): if expect_debug_info_size: self.assertTrue( - self.waitUntil(check_symbols_loaded_with_size), + self.wait_until(check_symbols_loaded_with_size), "expect has debug info size", ) @@ -68,7 +66,7 @@ def check_symbols_loaded_with_size(): # Collect all the module names we saw as events. module_new_names = [] module_changed_names = [] - module_event = self.dap_server.wait_for_event("module", 1) + module_event = self.dap_server.wait_for_event(["module"], 1) while module_event is not None: reason = module_event["body"]["reason"] if reason == "new": @@ -76,7 +74,7 @@ def check_symbols_loaded_with_size(): elif reason == "changed": module_changed_names.append(module_event["body"]["module"]["name"]) - module_event = self.dap_server.wait_for_event("module", 1) + module_event = self.dap_server.wait_for_event(["module"], 1) # Make sure we got an event for every active module. self.assertNotEqual(len(module_new_names), 0) diff --git a/lldb/test/API/tools/lldb-dap/moduleSymbols/Makefile b/lldb/test/API/tools/lldb-dap/moduleSymbols/Makefile new file mode 100644 index 0000000000000..10495940055b6 --- /dev/null +++ b/lldb/test/API/tools/lldb-dap/moduleSymbols/Makefile @@ -0,0 +1,3 @@ +C_SOURCES := main.c + +include Makefile.rules diff --git a/lldb/test/API/tools/lldb-dap/moduleSymbols/TestDAP_moduleSymbols.py b/lldb/test/API/tools/lldb-dap/moduleSymbols/TestDAP_moduleSymbols.py new file mode 100644 index 0000000000000..2336b9f2a5a1a --- /dev/null +++ b/lldb/test/API/tools/lldb-dap/moduleSymbols/TestDAP_moduleSymbols.py @@ -0,0 +1,40 @@ +""" +Test lldb-dap moduleSymbols request +""" + +import lldbdap_testcase +from lldbsuite.test.decorators import * + + +class TestDAP_moduleSymbols(lldbdap_testcase.DAPTestCaseBase): + # On windows LLDB doesn't recognize symbols in a.out. + @skipIfWindows + def test_moduleSymbols(self): + """ + Test that the moduleSymbols request returns correct symbols from the module. + """ + program = self.getBuildArtifact("a.out") + self.build_and_launch(program) + + symbol_names = [] + i = 0 + while True: + next_symbol = self.dap_server.request_moduleSymbols( + moduleName="a.out", startIndex=i, count=1 + ) + self.assertIn("symbols", next_symbol["body"]) + result_symbols = next_symbol["body"]["symbols"] + self.assertLessEqual(len(result_symbols), 1) + if len(result_symbols) == 0: + break + + self.assertIn("name", result_symbols[0]) + symbol_names.append(result_symbols[0]["name"]) + i += 1 + if i >= 1000: + break + + self.assertGreater(len(symbol_names), 0) + self.assertIn("main", symbol_names) + self.assertIn("func1", symbol_names) + self.assertIn("func2", symbol_names) diff --git a/lldb/test/API/tools/lldb-dap/moduleSymbols/main.c b/lldb/test/API/tools/lldb-dap/moduleSymbols/main.c new file mode 100644 index 0000000000000..b038b10480b80 --- /dev/null +++ b/lldb/test/API/tools/lldb-dap/moduleSymbols/main.c @@ -0,0 +1,9 @@ +int func1() { return 42; } + +int func2() { return 84; } + +int main() { + func1(); + func2(); + return 0; +} diff --git a/lldb/test/API/tools/lldb-dap/output/TestDAP_output.py b/lldb/test/API/tools/lldb-dap/output/TestDAP_output.py index 0425b55a5e552..fe978a9a73351 100644 --- a/lldb/test/API/tools/lldb-dap/output/TestDAP_output.py +++ b/lldb/test/API/tools/lldb-dap/output/TestDAP_output.py @@ -29,7 +29,7 @@ def test_output(self): self.continue_to_breakpoints(breakpoint_ids) # Ensure partial messages are still sent. - output = self.collect_stdout(timeout_secs=1.0, pattern="abcdef") + output = self.collect_stdout(timeout=1.0, pattern="abcdef") self.assertTrue(output and len(output) > 0, "expect program stdout") self.continue_to_exit() @@ -37,14 +37,14 @@ def test_output(self): # Disconnecting from the server to ensure any pending IO is flushed. self.dap_server.request_disconnect() - output += self.get_stdout(timeout=self.DEFAULT_TIMEOUT) + output += self.get_stdout() self.assertTrue(output and len(output) > 0, "expect program stdout") self.assertIn( "abcdefghi\r\nhello world\r\nfinally\0\0", output, "full stdout not found in: " + repr(output), ) - console = self.get_console(timeout=self.DEFAULT_TIMEOUT) + console = self.get_console() self.assertTrue(console and len(console) > 0, "expect dap messages") self.assertIn( "out\0\0\r\nerr\0\0\r\n", console, f"full console message not found" diff --git a/lldb/test/API/tools/lldb-dap/progress/TestDAP_Progress.py b/lldb/test/API/tools/lldb-dap/progress/TestDAP_Progress.py index b47d52968f8a1..3f57dfb66024d 100755 --- a/lldb/test/API/tools/lldb-dap/progress/TestDAP_Progress.py +++ b/lldb/test/API/tools/lldb-dap/progress/TestDAP_Progress.py @@ -21,7 +21,7 @@ def verify_progress_events( expected_not_in_message=None, only_verify_first_update=False, ): - self.dap_server.wait_for_event("progressEnd") + self.dap_server.wait_for_event(["progressEnd"]) self.assertTrue(len(self.dap_server.progress_events) > 0) start_found = False update_found = False diff --git a/lldb/tools/CMakeLists.txt b/lldb/tools/CMakeLists.txt index a15082fe0b48b..2b68343ab8a59 100644 --- a/lldb/tools/CMakeLists.txt +++ b/lldb/tools/CMakeLists.txt @@ -10,6 +10,7 @@ add_subdirectory(lldb-fuzzer EXCLUDE_FROM_ALL) add_lldb_tool_subdirectory(lldb-instr) add_lldb_tool_subdirectory(lldb-dap) +add_lldb_tool_subdirectory(lldb-mcp) if (LLDB_BUILD_LLDBRPC) add_lldb_tool_subdirectory(lldb-rpc-gen) endif() diff --git a/lldb/tools/lldb-dap/CMakeLists.txt b/lldb/tools/lldb-dap/CMakeLists.txt index 5e0ad53b82f89..7db334ca56bcf 100644 --- a/lldb/tools/lldb-dap/CMakeLists.txt +++ b/lldb/tools/lldb-dap/CMakeLists.txt @@ -45,6 +45,7 @@ add_lldb_library(lldbDAP Handler/LaunchRequestHandler.cpp Handler/LocationsRequestHandler.cpp Handler/ModulesRequestHandler.cpp + Handler/ModuleSymbolsRequestHandler.cpp Handler/NextRequestHandler.cpp Handler/PauseRequestHandler.cpp Handler/ReadMemoryRequestHandler.cpp diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index 849712f724c69..b1ad38d983893 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -23,13 +23,14 @@ #include "Transport.h" #include "lldb/API/SBBreakpoint.h" #include "lldb/API/SBCommandInterpreter.h" -#include "lldb/API/SBCommandReturnObject.h" #include "lldb/API/SBEvent.h" #include "lldb/API/SBLanguageRuntime.h" #include "lldb/API/SBListener.h" #include "lldb/API/SBProcess.h" #include "lldb/API/SBStream.h" -#include "lldb/Utility/IOObject.h" +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-defines.h" #include "lldb/lldb-enumerations.h" @@ -52,7 +53,7 @@ #include #include #include -#include +#include #include #include #include @@ -120,11 +121,12 @@ static std::string capitalize(llvm::StringRef str) { llvm::StringRef DAP::debug_adapter_path = ""; DAP::DAP(Log *log, const ReplMode default_repl_mode, - std::vector pre_init_commands, Transport &transport) + std::vector pre_init_commands, + llvm::StringRef client_name, DAPTransport &transport, MainLoop &loop) : log(log), transport(transport), broadcaster("lldb-dap"), progress_event_reporter( [&](const ProgressEvent &event) { SendJSON(event.ToJSON()); }), - repl_mode(default_repl_mode) { + repl_mode(default_repl_mode), m_client_name(client_name), m_loop(loop) { configuration.preInitCommands = std::move(pre_init_commands); RegisterRequests(); } @@ -257,36 +259,49 @@ void DAP::SendJSON(const llvm::json::Value &json) { llvm::json::Path::Root root; if (!fromJSON(json, message, root)) { DAP_LOG_ERROR(log, root.getError(), "({1}) encoding failed: {0}", - transport.GetClientName()); + m_client_name); return; } Send(message); } void DAP::Send(const Message &message) { - // FIXME: After all the requests have migrated from LegacyRequestHandler > - // RequestHandler<> this should be handled in RequestHandler<>::operator(). - if (auto *resp = std::get_if(&message); - resp && debugger.InterruptRequested()) { - // Clear the interrupt request. - debugger.CancelInterruptRequest(); - - // If the debugger was interrupted, convert this response into a 'cancelled' - // response because we might have a partial result. - Response cancelled{/*request_seq=*/resp->request_seq, - /*command=*/resp->command, - /*success=*/false, - /*message=*/eResponseMessageCancelled, - /*body=*/std::nullopt}; - if (llvm::Error err = transport.Write(cancelled)) - DAP_LOG_ERROR(log, std::move(err), "({1}) write failed: {0}", - transport.GetClientName()); + if (const protocol::Event *event = std::get_if(&message)) { + if (llvm::Error err = transport.Send(*event)) + DAP_LOG_ERROR(log, std::move(err), "({0}) sending event failed", + m_client_name); return; } - if (llvm::Error err = transport.Write(message)) - DAP_LOG_ERROR(log, std::move(err), "({1}) write failed: {0}", - transport.GetClientName()); + if (const Request *req = std::get_if(&message)) { + if (llvm::Error err = transport.Send(*req)) + DAP_LOG_ERROR(log, std::move(err), "({0}) sending request failed", + m_client_name); + return; + } + + if (const Response *resp = std::get_if(&message)) { + // FIXME: After all the requests have migrated from LegacyRequestHandler > + // RequestHandler<> this should be handled in RequestHandler<>::operator(). + // If the debugger was interrupted, convert this response into a + // 'cancelled' response because we might have a partial result. + llvm::Error err = + (debugger.InterruptRequested()) + ? transport.Send({/*request_seq=*/resp->request_seq, + /*command=*/resp->command, + /*success=*/false, + /*message=*/eResponseMessageCancelled, + /*body=*/std::nullopt}) + : transport.Send(*resp); + if (err) { + DAP_LOG_ERROR(log, std::move(err), "({0}) sending response failed", + m_client_name); + return; + } + return; + } + + llvm_unreachable("Unexpected message type"); } // "OutputEvent": { @@ -551,6 +566,9 @@ lldb::SBThread DAP::GetLLDBThread(const llvm::json::Object &arguments) { } lldb::SBFrame DAP::GetLLDBFrame(uint64_t frame_id) { + if (frame_id == LLDB_DAP_INVALID_FRAME_ID) + return lldb::SBFrame(); + lldb::SBProcess process = target.GetProcess(); // Upper 32 bits is the thread index ID lldb::SBThread thread = @@ -560,8 +578,8 @@ lldb::SBFrame DAP::GetLLDBFrame(uint64_t frame_id) { } lldb::SBFrame DAP::GetLLDBFrame(const llvm::json::Object &arguments) { - const auto frame_id = - GetInteger(arguments, "frameId").value_or(UINT64_MAX); + const auto frame_id = GetInteger(arguments, "frameId") + .value_or(LLDB_DAP_INVALID_FRAME_ID); return GetLLDBFrame(frame_id); } @@ -754,7 +772,6 @@ void DAP::RunTerminateCommands() { } lldb::SBTarget DAP::CreateTarget(lldb::SBError &error) { - // Grab the name of the program we need to debug and create a target using // the given program as an argument. Executable file can be a source of target // architecture and platform, if they differ from the host. Setting exe path // in launch info is useless because Target.Launch() will not change @@ -794,7 +811,7 @@ void DAP::SetTarget(const lldb::SBTarget target) { bool DAP::HandleObject(const Message &M) { TelemetryDispatcher dispatcher(&debugger); - dispatcher.Set("client_name", transport.GetClientName().str()); + dispatcher.Set("client_name", m_client_name.str()); if (const auto *req = std::get_if(&M)) { { std::lock_guard guard(m_active_request_mutex); @@ -820,8 +837,8 @@ bool DAP::HandleObject(const Message &M) { dispatcher.Set("error", llvm::Twine("unhandled-command:" + req->command).str()); - DAP_LOG(log, "({0}) error: unhandled command '{1}'", - transport.GetClientName(), req->command); + DAP_LOG(log, "({0}) error: unhandled command '{1}'", m_client_name, + req->command); return false; // Fail } @@ -917,9 +934,7 @@ llvm::Error DAP::Disconnect(bool terminateDebuggee) { } SendTerminatedEvent(); - - disconnecting = true; - + TerminateLoop(); return ToError(error); } @@ -935,91 +950,121 @@ void DAP::ClearCancelRequest(const CancelArguments &args) { } template -static std::optional getArgumentsIfRequest(const Message &pm, +static std::optional getArgumentsIfRequest(const Request &req, llvm::StringLiteral command) { - auto *const req = std::get_if(&pm); - if (!req || req->command != command) + if (req.command != command) return std::nullopt; T args; llvm::json::Path::Root root; - if (!fromJSON(req->arguments, args, root)) + if (!fromJSON(req.arguments, args, root)) return std::nullopt; return args; } -llvm::Error DAP::Loop() { - // Can't use \a std::future because it doesn't compile on - // Windows. - std::future queue_reader = - std::async(std::launch::async, [&]() -> lldb::SBError { - llvm::set_thread_name(transport.GetClientName() + ".transport_handler"); - auto cleanup = llvm::make_scope_exit([&]() { - // Ensure we're marked as disconnecting when the reader exits. - disconnecting = true; - m_queue_cv.notify_all(); - }); - - while (!disconnecting) { - llvm::Expected next = - transport.Read(std::chrono::seconds(1)); - if (next.errorIsA()) { - consumeError(next.takeError()); - break; - } +void DAP::Received(const protocol::Event &event) { + // no-op, no supported events from the client to the server as of DAP v1.68. +} - // If the read timed out, continue to check if we should disconnect. - if (next.errorIsA()) { - consumeError(next.takeError()); - continue; - } +void DAP::Received(const protocol::Request &request) { + if (request.command == "disconnect") + m_disconnecting = true; - if (llvm::Error err = next.takeError()) { - lldb::SBError errWrapper; - errWrapper.SetErrorString(llvm::toString(std::move(err)).c_str()); - return errWrapper; - } + const std::optional cancel_args = + getArgumentsIfRequest(request, "cancel"); + if (cancel_args) { + { + std::lock_guard guard(m_cancelled_requests_mutex); + if (cancel_args->requestId) + m_cancelled_requests.insert(*cancel_args->requestId); + } - if (const protocol::Request *req = - std::get_if(&*next); - req && req->command == "disconnect") - disconnecting = true; - - const std::optional cancel_args = - getArgumentsIfRequest(*next, "cancel"); - if (cancel_args) { - { - std::lock_guard guard(m_cancelled_requests_mutex); - if (cancel_args->requestId) - m_cancelled_requests.insert(*cancel_args->requestId); - } + // If a cancel is requested for the active request, make a best + // effort attempt to interrupt. + std::lock_guard guard(m_active_request_mutex); + if (m_active_request && cancel_args->requestId == m_active_request->seq) { + DAP_LOG(log, "({0}) interrupting inflight request (command={1} seq={2})", + m_client_name, m_active_request->command, m_active_request->seq); + debugger.RequestInterrupt(); + } + } - // If a cancel is requested for the active request, make a best - // effort attempt to interrupt. - std::lock_guard guard(m_active_request_mutex); - if (m_active_request && - cancel_args->requestId == m_active_request->seq) { - DAP_LOG( - log, - "({0}) interrupting inflight request (command={1} seq={2})", - transport.GetClientName(), m_active_request->command, - m_active_request->seq); - debugger.RequestInterrupt(); - } - } + std::lock_guard guard(m_queue_mutex); + DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, + request.command, request.seq); + m_queue.push_back(request); + m_queue_cv.notify_one(); +} - { - std::lock_guard guard(m_queue_mutex); - m_queue.push_back(std::move(*next)); - } - m_queue_cv.notify_one(); - } +void DAP::Received(const protocol::Response &response) { + std::lock_guard guard(m_queue_mutex); + DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, + response.command, response.request_seq); + m_queue.push_back(response); + m_queue_cv.notify_one(); +} + +void DAP::OnError(llvm::Error error) { + DAP_LOG_ERROR(log, std::move(error), "({1}) received error: {0}", + m_client_name); + TerminateLoop(/*failed=*/true); +} + +void DAP::OnClosed() { + DAP_LOG(log, "({0}) received EOF", m_client_name); + TerminateLoop(); +} - return lldb::SBError(); - }); +void DAP::TerminateLoop(bool failed) { + std::lock_guard guard(m_queue_mutex); + if (m_disconnecting) + return; // Already disconnecting. - auto cleanup = llvm::make_scope_exit([&]() { + m_error_occurred = failed; + m_disconnecting = true; + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); +} + +void DAP::TransportHandler() { + auto scope_guard = llvm::make_scope_exit([this] { + std::lock_guard guard(m_queue_mutex); + // Ensure we're marked as disconnecting when the reader exits. + m_disconnecting = true; + m_queue_cv.notify_all(); + }); + + auto handle = transport.RegisterMessageHandler(m_loop, *this); + if (!handle) { + DAP_LOG_ERROR(log, handle.takeError(), + "({1}) registering message handler failed: {0}", + m_client_name); + std::lock_guard guard(m_queue_mutex); + m_error_occurred = true; + return; + } + + if (Status status = m_loop.Run(); status.Fail()) { + DAP_LOG_ERROR(log, status.takeError(), "({1}) MainLoop run failed: {0}", + m_client_name); + std::lock_guard guard(m_queue_mutex); + m_error_occurred = true; + return; + } +} + +llvm::Error DAP::Loop() { + { + // Reset disconnect flag once we start the loop. + std::lock_guard guard(m_queue_mutex); + m_disconnecting = false; + } + + auto thread = std::thread(std::bind(&DAP::TransportHandler, this)); + + auto cleanup = llvm::make_scope_exit([this]() { + // FIXME: Merge these into the MainLoop handler. out.Stop(); err.Stop(); StopEventHandlers(); @@ -1027,9 +1072,9 @@ llvm::Error DAP::Loop() { while (true) { std::unique_lock lock(m_queue_mutex); - m_queue_cv.wait(lock, [&] { return disconnecting || !m_queue.empty(); }); + m_queue_cv.wait(lock, [&] { return m_disconnecting || !m_queue.empty(); }); - if (disconnecting && m_queue.empty()) + if (m_disconnecting && m_queue.empty()) break; Message next = m_queue.front(); @@ -1043,7 +1088,15 @@ llvm::Error DAP::Loop() { "unhandled packet"); } - return ToError(queue_reader.get()); + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + thread.join(); + + if (m_error_occurred) + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "DAP Loop terminated due to an internal " + "error, see DAP Logs for more information."); + return llvm::Error::success(); } lldb::SBError DAP::WaitForProcessToStop(std::chrono::seconds seconds) { @@ -1208,6 +1261,27 @@ protocol::Capabilities DAP::GetCapabilities() { return capabilities; } +protocol::Capabilities DAP::GetCustomCapabilities() { + protocol::Capabilities capabilities; + + // Add all custom capabilities here. + const llvm::DenseSet all_custom_features = { + protocol::eAdapterFeatureSupportsModuleSymbolsRequest, + }; + + for (auto &kv : request_handlers) { + llvm::SmallDenseSet features = + kv.second->GetSupportedFeatures(); + + for (auto &feature : features) { + if (all_custom_features.contains(feature)) + capabilities.supportedFeatures.insert(feature); + } + } + + return capabilities; +} + void DAP::StartEventThread() { event_thread = std::thread(&DAP::EventThread, this); } @@ -1282,7 +1356,7 @@ void DAP::ProgressEventThread() { // them prevent multiple threads from writing simultaneously so no locking // is required. void DAP::EventThread() { - llvm::set_thread_name(transport.GetClientName() + ".event_handler"); + llvm::set_thread_name("lldb.DAP.client." + m_client_name + ".event_handler"); lldb::SBEvent event; lldb::SBListener listener = debugger.GetListener(); broadcaster.AddListener(listener, eBroadcastBitStopEventThread); @@ -1314,7 +1388,7 @@ void DAP::EventThread() { if (llvm::Error err = SendThreadStoppedEvent(*this)) DAP_LOG_ERROR(log, std::move(err), "({1}) reporting thread stopped: {0}", - transport.GetClientName()); + m_client_name); } break; case lldb::eStateRunning: @@ -1564,6 +1638,7 @@ void DAP::RegisterRequests() { // Custom requests RegisterRequest(); RegisterRequest(); + RegisterRequest(); // Testing requests RegisterRequest(); diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index af4aabaafaae8..04f70f76a09cd 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -31,6 +31,8 @@ #include "lldb/API/SBMutex.h" #include "lldb/API/SBTarget.h" #include "lldb/API/SBThread.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Utility/Status.h" #include "lldb/lldb-types.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -76,12 +78,16 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command, Auto }; -struct DAP { +using DAPTransport = + lldb_private::Transport; + +struct DAP final : private DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. static llvm::StringRef debug_adapter_path; Log *log; - Transport &transport; + DAPTransport &transport; lldb::SBFile in; OutputRedirector out; OutputRedirector err; @@ -112,7 +118,6 @@ struct DAP { /// The focused thread for this DAP session. lldb::tid_t focus_tid = LLDB_INVALID_THREAD_ID; - bool disconnecting = false; llvm::once_flag terminated_event_flag; bool stop_at_entry = false; bool is_attach = false; @@ -175,8 +180,11 @@ struct DAP { /// allocated. /// \param[in] transport /// Transport for this debug session. + /// \param[in] loop + /// Main loop associated with this instance. DAP(Log *log, const ReplMode default_repl_mode, - std::vector pre_init_commands, Transport &transport); + std::vector pre_init_commands, llvm::StringRef client_name, + DAPTransport &transport, lldb_private::MainLoop &loop); ~DAP(); @@ -315,7 +323,7 @@ struct DAP { lldb::SBTarget CreateTarget(lldb::SBError &error); /// Set given target object as a current target for lldb-dap and start - /// listeing for its breakpoint events. + /// listening for its breakpoint events. void SetTarget(const lldb::SBTarget target); bool HandleObject(const protocol::Message &M); @@ -359,6 +367,9 @@ struct DAP { /// The set of capabilities supported by this adapter. protocol::Capabilities GetCapabilities(); + /// The set of custom capabilities supported by this adapter. + protocol::Capabilities GetCustomCapabilities(); + /// Debuggee will continue from stopped state. void WillContinue() { variables.Clear(); } @@ -418,12 +429,21 @@ struct DAP { const std::optional> &breakpoints); + void Received(const protocol::Event &) override; + void Received(const protocol::Request &) override; + void Received(const protocol::Response &) override; + void OnError(llvm::Error) override; + void OnClosed() override; + private: std::vector SetSourceBreakpoints( const protocol::Source &source, const std::optional> &breakpoints, SourceBreakpointMap &existing_breakpoints); + void TransportHandler(); + void TerminateLoop(bool failed = false); + /// Registration of request handler. /// @{ void RegisterRequests(); @@ -442,6 +462,8 @@ struct DAP { std::thread progress_event_thread; /// @} + const llvm::StringRef m_client_name; + /// List of addresses mapped by sourceReference. std::vector m_source_references; std::mutex m_source_references_mutex; @@ -450,6 +472,11 @@ struct DAP { std::deque m_queue; std::mutex m_queue_mutex; std::condition_variable m_queue_cv; + bool m_disconnecting = false; + bool m_error_occurred = false; + + // Loop for managing reading from the client. + lldb_private::MainLoop &m_loop; std::mutex m_cancelled_requests_mutex; llvm::SmallSet m_cancelled_requests; diff --git a/lldb/tools/lldb-dap/EventHelper.cpp b/lldb/tools/lldb-dap/EventHelper.cpp index 364cc7ab4ef8c..bfb05a387d04d 100644 --- a/lldb/tools/lldb-dap/EventHelper.cpp +++ b/lldb/tools/lldb-dap/EventHelper.cpp @@ -38,25 +38,37 @@ static void SendThreadExitedEvent(DAP &dap, lldb::tid_t tid) { dap.SendJSON(llvm::json::Value(std::move(event))); } -void SendTargetBasedCapabilities(DAP &dap) { +/// Get capabilities based on the configured target. +static llvm::DenseSet GetTargetBasedCapabilities(DAP &dap) { + llvm::DenseSet capabilities; if (!dap.target.IsValid()) - return; - - protocol::CapabilitiesEventBody body; + return capabilities; const llvm::StringRef target_triple = dap.target.GetTriple(); if (target_triple.starts_with("x86")) - body.capabilities.supportedFeatures.insert( - protocol::eAdapterFeatureStepInTargetsRequest); + capabilities.insert(protocol::eAdapterFeatureStepInTargetsRequest); // We only support restarting launch requests not attach requests. if (dap.last_launch_request) - body.capabilities.supportedFeatures.insert( - protocol::eAdapterFeatureRestartRequest); + capabilities.insert(protocol::eAdapterFeatureRestartRequest); + + return capabilities; +} + +void SendExtraCapabilities(DAP &dap) { + protocol::Capabilities capabilities = dap.GetCustomCapabilities(); + llvm::DenseSet target_capabilities = + GetTargetBasedCapabilities(dap); + + capabilities.supportedFeatures.insert(target_capabilities.begin(), + target_capabilities.end()); + + protocol::CapabilitiesEventBody body; + body.capabilities = std::move(capabilities); // Only notify the client if supportedFeatures changed. if (!body.capabilities.supportedFeatures.empty()) - dap.Send(protocol::Event{"capabilities", body}); + dap.Send(protocol::Event{"capabilities", std::move(body)}); } // "ProcessEvent": { diff --git a/lldb/tools/lldb-dap/EventHelper.h b/lldb/tools/lldb-dap/EventHelper.h index 72ad5308a2b0c..592c1b81c46af 100644 --- a/lldb/tools/lldb-dap/EventHelper.h +++ b/lldb/tools/lldb-dap/EventHelper.h @@ -17,8 +17,8 @@ struct DAP; enum LaunchMethod { Launch, Attach, AttachForSuspendedLaunch }; -/// Update capabilities based on the configured target. -void SendTargetBasedCapabilities(DAP &dap); +/// Sends target based capabilities and lldb-dap custom capabilities. +void SendExtraCapabilities(DAP &dap); void SendProcessEvent(DAP &dap, LaunchMethod launch_method); diff --git a/lldb/tools/lldb-dap/Handler/CompletionsHandler.cpp b/lldb/tools/lldb-dap/Handler/CompletionsHandler.cpp index c72fc5686cd5b..de9a15dcb73f4 100644 --- a/lldb/tools/lldb-dap/Handler/CompletionsHandler.cpp +++ b/lldb/tools/lldb-dap/Handler/CompletionsHandler.cpp @@ -8,156 +8,46 @@ #include "DAP.h" #include "JSONUtils.h" +#include "Protocol/ProtocolRequests.h" +#include "Protocol/ProtocolTypes.h" #include "RequestHandler.h" #include "lldb/API/SBStringList.h" -namespace lldb_dap { +using namespace llvm; +using namespace lldb_dap; +using namespace lldb_dap::protocol; -// "CompletionsRequest": { -// "allOf": [ { "$ref": "#/definitions/Request" }, { -// "type": "object", -// "description": "Returns a list of possible completions for a given caret -// position and text.\nThe CompletionsRequest may only be called if the -// 'supportsCompletionsRequest' capability exists and is true.", -// "properties": { -// "command": { -// "type": "string", -// "enum": [ "completions" ] -// }, -// "arguments": { -// "$ref": "#/definitions/CompletionsArguments" -// } -// }, -// "required": [ "command", "arguments" ] -// }] -// }, -// "CompletionsArguments": { -// "type": "object", -// "description": "Arguments for 'completions' request.", -// "properties": { -// "frameId": { -// "type": "integer", -// "description": "Returns completions in the scope of this stack frame. -// If not specified, the completions are returned for the global scope." -// }, -// "text": { -// "type": "string", -// "description": "One or more source lines. Typically this is the text a -// user has typed into the debug console before he asked for completion." -// }, -// "column": { -// "type": "integer", -// "description": "The character position for which to determine the -// completion proposals." -// }, -// "line": { -// "type": "integer", -// "description": "An optional line for which to determine the completion -// proposals. If missing the first line of the text is assumed." -// } -// }, -// "required": [ "text", "column" ] -// }, -// "CompletionsResponse": { -// "allOf": [ { "$ref": "#/definitions/Response" }, { -// "type": "object", -// "description": "Response to 'completions' request.", -// "properties": { -// "body": { -// "type": "object", -// "properties": { -// "targets": { -// "type": "array", -// "items": { -// "$ref": "#/definitions/CompletionItem" -// }, -// "description": "The possible completions for ." -// } -// }, -// "required": [ "targets" ] -// } -// }, -// "required": [ "body" ] -// }] -// }, -// "CompletionItem": { -// "type": "object", -// "description": "CompletionItems are the suggestions returned from the -// CompletionsRequest.", "properties": { -// "label": { -// "type": "string", -// "description": "The label of this completion item. By default this is -// also the text that is inserted when selecting this completion." -// }, -// "text": { -// "type": "string", -// "description": "If text is not falsy then it is inserted instead of the -// label." -// }, -// "sortText": { -// "type": "string", -// "description": "A string that should be used when comparing this item -// with other items. When `falsy` the label is used." -// }, -// "type": { -// "$ref": "#/definitions/CompletionItemType", -// "description": "The item's type. Typically the client uses this -// information to render the item in the UI with an icon." -// }, -// "start": { -// "type": "integer", -// "description": "This value determines the location (in the -// CompletionsRequest's 'text' attribute) where the completion text is -// added.\nIf missing the text is added at the location specified by the -// CompletionsRequest's 'column' attribute." -// }, -// "length": { -// "type": "integer", -// "description": "This value determines how many characters are -// overwritten by the completion text.\nIf missing the value 0 is assumed -// which results in the completion text being inserted." -// } -// }, -// "required": [ "label" ] -// }, -// "CompletionItemType": { -// "type": "string", -// "description": "Some predefined types for the CompletionItem. Please note -// that not all clients have specific icons for all of them.", "enum": [ -// "method", "function", "constructor", "field", "variable", "class", -// "interface", "module", "property", "unit", "value", "enum", "keyword", -// "snippet", "text", "color", "file", "reference", "customcolor" ] -// } -void CompletionsRequestHandler::operator()( - const llvm::json::Object &request) const { - llvm::json::Object response; - FillResponse(request, response); - llvm::json::Object body; - const auto *arguments = request.getObject("arguments"); +namespace lldb_dap { +/// Returns a list of possible completions for a given caret position and text. +/// +/// Clients should only call this request if the corresponding capability +/// `supportsCompletionsRequest` is true. +Expected +CompletionsRequestHandler::Run(const CompletionsArguments &args) const { // If we have a frame, try to set the context for variable completions. - lldb::SBFrame frame = dap.GetLLDBFrame(*arguments); + lldb::SBFrame frame = dap.GetLLDBFrame(args.frameId); if (frame.IsValid()) { frame.GetThread().GetProcess().SetSelectedThread(frame.GetThread()); frame.GetThread().SetSelectedFrame(frame.GetFrameID()); } - std::string text = GetString(arguments, "text").value_or("").str(); - auto original_column = - GetInteger(arguments, "column").value_or(text.size()); - auto original_line = GetInteger(arguments, "line").value_or(1); + std::string text = args.text; + auto original_column = args.column; + auto original_line = args.line; auto offset = original_column - 1; if (original_line > 1) { - llvm::SmallVector<::llvm::StringRef, 2> lines; - llvm::StringRef(text).split(lines, '\n'); + SmallVector lines; + StringRef(text).split(lines, '\n'); for (int i = 0; i < original_line - 1; i++) { offset += lines[i].size(); } } - llvm::json::Array targets; + + std::vector targets; bool had_escape_prefix = - llvm::StringRef(text).starts_with(dap.configuration.commandEscapePrefix); + StringRef(text).starts_with(dap.configuration.commandEscapePrefix); ReplMode completion_mode = dap.DetectReplMode(frame, text, true); // Handle the offset change introduced by stripping out the @@ -165,10 +55,7 @@ void CompletionsRequestHandler::operator()( if (had_escape_prefix) { if (offset < static_cast(dap.configuration.commandEscapePrefix.size())) { - body.try_emplace("targets", std::move(targets)); - response.try_emplace("body", std::move(body)); - dap.SendJSON(llvm::json::Value(std::move(response))); - return; + return CompletionsResponseBody{std::move(targets)}; } offset -= dap.configuration.commandEscapePrefix.size(); } @@ -198,27 +85,25 @@ void CompletionsRequestHandler::operator()( std::string match = matches.GetStringAtIndex(i); std::string description = descriptions.GetStringAtIndex(i); - llvm::json::Object item; - llvm::StringRef match_ref = match; - for (llvm::StringRef commit_point : {".", "->"}) { + CompletionItem item; + StringRef match_ref = match; + for (StringRef commit_point : {".", "->"}) { if (match_ref.contains(commit_point)) { match_ref = match_ref.rsplit(commit_point).second; } } - EmplaceSafeString(item, "text", match_ref); + item.text = match_ref; if (description.empty()) - EmplaceSafeString(item, "label", match); + item.label = match; else - EmplaceSafeString(item, "label", match + " -- " + description); + item.label = match + " -- " + description; targets.emplace_back(std::move(item)); } } - body.try_emplace("targets", std::move(targets)); - response.try_emplace("body", std::move(body)); - dap.SendJSON(llvm::json::Value(std::move(response))); + return CompletionsResponseBody{std::move(targets)}; } } // namespace lldb_dap diff --git a/lldb/tools/lldb-dap/Handler/ConfigurationDoneRequestHandler.cpp b/lldb/tools/lldb-dap/Handler/ConfigurationDoneRequestHandler.cpp index e7735a705d0aa..1bfe7b7f6ef5c 100644 --- a/lldb/tools/lldb-dap/Handler/ConfigurationDoneRequestHandler.cpp +++ b/lldb/tools/lldb-dap/Handler/ConfigurationDoneRequestHandler.cpp @@ -9,6 +9,7 @@ #include "DAP.h" #include "EventHelper.h" #include "LLDBUtils.h" +#include "Protocol/ProtocolEvents.h" #include "Protocol/ProtocolRequests.h" #include "ProtocolUtils.h" #include "RequestHandler.h" @@ -44,7 +45,10 @@ ConfigurationDoneRequestHandler::Run(const ConfigurationDoneArguments &) const { // Waiting until 'configurationDone' to send target based capabilities in case // the launch or attach scripts adjust the target. The initial dummy target // may have different capabilities than the final target. - SendTargetBasedCapabilities(dap); + + /// Also send here custom capabilities to the client, which is consumed by the + /// lldb-dap specific editor extension. + SendExtraCapabilities(dap); // Clients can request a baseline of currently existing threads after // we acknowledge the configurationDone request. diff --git a/lldb/tools/lldb-dap/Handler/DataBreakpointInfoRequestHandler.cpp b/lldb/tools/lldb-dap/Handler/DataBreakpointInfoRequestHandler.cpp index 8cb25d0603449..87b93fc999ecd 100644 --- a/lldb/tools/lldb-dap/Handler/DataBreakpointInfoRequestHandler.cpp +++ b/lldb/tools/lldb-dap/Handler/DataBreakpointInfoRequestHandler.cpp @@ -23,7 +23,7 @@ llvm::Expected DataBreakpointInfoRequestHandler::Run( const protocol::DataBreakpointInfoArguments &args) const { protocol::DataBreakpointInfoResponseBody response; - lldb::SBFrame frame = dap.GetLLDBFrame(args.frameId.value_or(UINT64_MAX)); + lldb::SBFrame frame = dap.GetLLDBFrame(args.frameId); lldb::SBValue variable = dap.variables.FindVariable( args.variablesReference.value_or(0), args.name); std::string addr, size; diff --git a/lldb/tools/lldb-dap/Handler/ModuleSymbolsRequestHandler.cpp b/lldb/tools/lldb-dap/Handler/ModuleSymbolsRequestHandler.cpp new file mode 100644 index 0000000000000..4a9d256cfa975 --- /dev/null +++ b/lldb/tools/lldb-dap/Handler/ModuleSymbolsRequestHandler.cpp @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "DAP.h" +#include "DAPError.h" +#include "Protocol/DAPTypes.h" +#include "RequestHandler.h" +#include "lldb/API/SBAddress.h" +#include "lldb/API/SBFileSpec.h" +#include "lldb/API/SBModule.h" +#include "lldb/API/SBModuleSpec.h" +#include "lldb/Utility/UUID.h" +#include "llvm/Support/Error.h" +#include + +using namespace lldb_dap::protocol; +namespace lldb_dap { + +llvm::Expected +ModuleSymbolsRequestHandler::Run(const ModuleSymbolsArguments &args) const { + ModuleSymbolsResponseBody response; + + lldb::SBModuleSpec module_spec; + if (!args.moduleId.empty()) { + llvm::SmallVector uuid_bytes; + if (!lldb_private::UUID::DecodeUUIDBytesFromString(args.moduleId, + uuid_bytes) + .empty()) + return llvm::make_error("invalid module ID"); + + module_spec.SetUUIDBytes(uuid_bytes.data(), uuid_bytes.size()); + } + + if (!args.moduleName.empty()) { + lldb::SBFileSpec file_spec; + file_spec.SetFilename(args.moduleName.c_str()); + module_spec.SetFileSpec(file_spec); + } + + // Empty request, return empty response. + if (!module_spec.IsValid()) + return response; + + std::vector &symbols = response.symbols; + lldb::SBModule module = dap.target.FindModule(module_spec); + if (!module.IsValid()) + return llvm::make_error("module not found"); + + const size_t num_symbols = module.GetNumSymbols(); + const size_t start_index = args.startIndex.value_or(0); + const size_t end_index = + std::min(start_index + args.count.value_or(num_symbols), num_symbols); + for (size_t i = start_index; i < end_index; ++i) { + lldb::SBSymbol symbol = module.GetSymbolAtIndex(i); + if (!symbol.IsValid()) + continue; + + Symbol dap_symbol; + dap_symbol.id = symbol.GetID(); + dap_symbol.type = symbol.GetType(); + dap_symbol.isDebug = symbol.IsDebug(); + dap_symbol.isSynthetic = symbol.IsSynthetic(); + dap_symbol.isExternal = symbol.IsExternal(); + + lldb::SBAddress start_address = symbol.GetStartAddress(); + if (start_address.IsValid()) { + lldb::addr_t file_address = start_address.GetFileAddress(); + if (file_address != LLDB_INVALID_ADDRESS) + dap_symbol.fileAddress = file_address; + + lldb::addr_t load_address = start_address.GetLoadAddress(dap.target); + if (load_address != LLDB_INVALID_ADDRESS) + dap_symbol.loadAddress = load_address; + } + + dap_symbol.size = symbol.GetSize(); + if (const char *symbol_name = symbol.GetName()) + dap_symbol.name = symbol_name; + symbols.push_back(std::move(dap_symbol)); + } + + return response; +} + +} // namespace lldb_dap diff --git a/lldb/tools/lldb-dap/Handler/RequestHandler.h b/lldb/tools/lldb-dap/Handler/RequestHandler.h index 16f8062f97d7b..977a247996750 100644 --- a/lldb/tools/lldb-dap/Handler/RequestHandler.h +++ b/lldb/tools/lldb-dap/Handler/RequestHandler.h @@ -243,14 +243,17 @@ class BreakpointLocationsRequestHandler uint32_t end_line) const; }; -class CompletionsRequestHandler : public LegacyRequestHandler { +class CompletionsRequestHandler + : public RequestHandler> { public: - using LegacyRequestHandler::LegacyRequestHandler; + using RequestHandler::RequestHandler; static llvm::StringLiteral GetCommand() { return "completions"; } FeatureSet GetSupportedFeatures() const override { return {protocol::eAdapterFeatureCompletionsRequest}; } - void operator()(const llvm::json::Object &request) const override; + llvm::Expected + Run(const protocol::CompletionsArguments &args) const override; }; class ContinueRequestHandler @@ -594,6 +597,20 @@ class CancelRequestHandler : public RequestHandler> { +public: + using RequestHandler::RequestHandler; + static llvm::StringLiteral GetCommand() { return "__lldb_moduleSymbols"; } + FeatureSet GetSupportedFeatures() const override { + return {protocol::eAdapterFeatureSupportsModuleSymbolsRequest}; + } + llvm::Expected + Run(const protocol::ModuleSymbolsArguments &args) const override; +}; + /// A request used in testing to get the details on all breakpoints that are /// currently set in the target. This helps us to test "setBreakpoints" and /// "setFunctionBreakpoints" requests to verify we have the correct set of diff --git a/lldb/tools/lldb-dap/Protocol/DAPTypes.cpp b/lldb/tools/lldb-dap/Protocol/DAPTypes.cpp index ecb4baef56e80..a14ed9e521f48 100644 --- a/lldb/tools/lldb-dap/Protocol/DAPTypes.cpp +++ b/lldb/tools/lldb-dap/Protocol/DAPTypes.cpp @@ -1,4 +1,6 @@ #include "Protocol/DAPTypes.h" +#include "lldb/API/SBSymbol.h" +#include "lldb/lldb-enumerations.h" using namespace llvm; @@ -33,4 +35,35 @@ llvm::json::Value toJSON(const SourceLLDBData &SLD) { return result; } -} // namespace lldb_dap::protocol \ No newline at end of file +bool fromJSON(const llvm::json::Value &Params, Symbol &DS, llvm::json::Path P) { + json::ObjectMapper O(Params, P); + std::string type_str; + if (!(O && O.map("id", DS.id) && O.map("isDebug", DS.isDebug) && + O.map("isSynthetic", DS.isSynthetic) && + O.map("isExternal", DS.isExternal) && O.map("type", type_str) && + O.map("fileAddress", DS.fileAddress) && + O.mapOptional("loadAddress", DS.loadAddress) && + O.map("size", DS.size) && O.map("name", DS.name))) + return false; + + DS.type = lldb::SBSymbol::GetTypeFromString(type_str.c_str()); + return true; +} + +llvm::json::Value toJSON(const Symbol &DS) { + json::Object result{ + {"id", DS.id}, + {"isDebug", DS.isDebug}, + {"isSynthetic", DS.isSynthetic}, + {"isExternal", DS.isExternal}, + {"type", lldb::SBSymbol::GetTypeAsString(DS.type)}, + {"fileAddress", DS.fileAddress}, + {"loadAddress", DS.loadAddress}, + {"size", DS.size}, + {"name", DS.name}, + }; + + return result; +} + +} // namespace lldb_dap::protocol diff --git a/lldb/tools/lldb-dap/Protocol/DAPTypes.h b/lldb/tools/lldb-dap/Protocol/DAPTypes.h index 716d8b491b258..7fccf1359a737 100644 --- a/lldb/tools/lldb-dap/Protocol/DAPTypes.h +++ b/lldb/tools/lldb-dap/Protocol/DAPTypes.h @@ -48,6 +48,38 @@ struct SourceLLDBData { bool fromJSON(const llvm::json::Value &, SourceLLDBData &, llvm::json::Path); llvm::json::Value toJSON(const SourceLLDBData &); +struct Symbol { + /// The symbol id, usually the original symbol table index. + uint32_t id; + + /// True if this symbol is debug information in a symbol. + bool isDebug; + + /// True if this symbol is not actually in the symbol table, but synthesized + /// from other info in the object file. + bool isSynthetic; + + /// True if this symbol is globally visible. + bool isExternal; + + /// The symbol type. + lldb::SymbolType type; + + /// The symbol file address. + lldb::addr_t fileAddress; + + /// The symbol load address. + std::optional loadAddress; + + /// The symbol size. + lldb::addr_t size; + + /// The symbol name. + std::string name; +}; +bool fromJSON(const llvm::json::Value &, Symbol &, llvm::json::Path); +llvm::json::Value toJSON(const Symbol &); + } // namespace lldb_dap::protocol #endif diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp index bc4fee4aa8b8d..9cd9028d879e9 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp @@ -98,6 +98,10 @@ bool fromJSON(json::Value const &Params, Request &R, json::Path P) { return mapRaw(Params, "arguments", R.arguments, P); } +bool operator==(const Request &a, const Request &b) { + return a.seq == b.seq && a.command == b.command && a.arguments == b.arguments; +} + json::Value toJSON(const Response &R) { json::Object Result{{"type", "response"}, {"seq", 0}, @@ -177,6 +181,11 @@ bool fromJSON(json::Value const &Params, Response &R, json::Path P) { mapRaw(Params, "body", R.body, P); } +bool operator==(const Response &a, const Response &b) { + return a.request_seq == b.request_seq && a.command == b.command && + a.success == b.success && a.message == b.message && a.body == b.body; +} + json::Value toJSON(const ErrorMessage &EM) { json::Object Result{{"id", EM.id}, {"format", EM.format}}; @@ -248,6 +257,10 @@ bool fromJSON(json::Value const &Params, Event &E, json::Path P) { return mapRaw(Params, "body", E.body, P); } +bool operator==(const Event &a, const Event &b) { + return a.event == b.event && a.body == b.body; +} + bool fromJSON(const json::Value &Params, Message &PM, json::Path P) { json::ObjectMapper O(Params, P); if (!O) diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h index 81496380d412f..0a9ef538a7398 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h @@ -52,6 +52,7 @@ struct Request { }; llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); +bool operator==(const Request &, const Request &); /// A debug adapter initiated event. struct Event { @@ -63,6 +64,7 @@ struct Event { }; llvm::json::Value toJSON(const Event &); bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); +bool operator==(const Event &, const Event &); enum ResponseMessage : unsigned { /// The request was cancelled @@ -101,6 +103,7 @@ struct Response { }; bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); llvm::json::Value toJSON(const Response &); +bool operator==(const Response &, const Response &); /// A structured message object. Used to return errors from requests. struct ErrorMessage { @@ -140,6 +143,7 @@ llvm::json::Value toJSON(const ErrorMessage &); using Message = std::variant; bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); llvm::json::Value toJSON(const Message &); +bool operator==(const Message &, const Message &); inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Message &V) { OS << toJSON(V); diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.cpp index 29855ca50e9e0..e1806d6230a80 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.cpp +++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.cpp @@ -219,7 +219,7 @@ bool fromJSON(const json::Value &Params, InitializeRequestArguments &IRA, OM.map("clientName", IRA.clientName) && OM.map("locale", IRA.locale) && OM.map("linesStartAt1", IRA.linesStartAt1) && OM.map("columnsStartAt1", IRA.columnsStartAt1) && - OM.map("pathFormat", IRA.pathFormat) && + OM.mapOptional("pathFormat", IRA.pathFormat) && OM.map("$__lldb_sourceInitFile", IRA.lldbExtSourceInitFile); } @@ -329,6 +329,17 @@ json::Value toJSON(const ContinueResponseBody &CRB) { return std::move(Body); } +bool fromJSON(const json::Value &Params, CompletionsArguments &CA, + json::Path P) { + json::ObjectMapper O(Params, P); + return O && O.map("text", CA.text) && O.map("column", CA.column) && + O.mapOptional("frameId", CA.frameId) && O.mapOptional("line", CA.line); +} + +json::Value toJSON(const CompletionsResponseBody &CRB) { + return json::Object{{"targets", CRB.targets}}; +} + bool fromJSON(const json::Value &Params, SetVariableArguments &SVA, json::Path P) { json::ObjectMapper O(Params, P); @@ -598,4 +609,19 @@ json::Value toJSON(const WriteMemoryResponseBody &WMR) { return result; } +bool fromJSON(const llvm::json::Value &Params, ModuleSymbolsArguments &Args, + llvm::json::Path P) { + json::ObjectMapper O(Params, P); + return O && O.map("moduleId", Args.moduleId) && + O.map("moduleName", Args.moduleName) && + O.mapOptional("startIndex", Args.startIndex) && + O.mapOptional("count", Args.count); +} + +llvm::json::Value toJSON(const ModuleSymbolsResponseBody &DGMSR) { + json::Object result; + result.insert({"symbols", DGMSR.symbols}); + return result; +} + } // namespace lldb_dap::protocol diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h index c45ee10e77d1c..0848ee53b4410 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolRequests.h @@ -310,6 +310,8 @@ bool fromJSON(const llvm::json::Value &, LaunchRequestArguments &, using LaunchResponse = VoidResponse; #define LLDB_DAP_INVALID_PORT -1 +/// An invalid 'frameId' default value. +#define LLDB_DAP_INVALID_FRAME_ID UINT64_MAX /// lldb-dap specific attach arguments. struct AttachRequestArguments { @@ -376,6 +378,35 @@ struct ContinueResponseBody { }; llvm::json::Value toJSON(const ContinueResponseBody &); +/// Arguments for `completions` request. +struct CompletionsArguments { + /// Returns completions in the scope of this stack frame. If not specified, + /// the completions are returned for the global scope. + uint64_t frameId = LLDB_DAP_INVALID_FRAME_ID; + + /// One or more source lines. Typically this is the text users have typed into + /// the debug console before they asked for completion. + std::string text; + + /// The position within `text` for which to determine the completion + /// proposals. It is measured in UTF-16 code units and the client capability + /// `columnsStartAt1` determines whether it is 0- or 1-based. + int64_t column = 0; + + /// A line for which to determine the completion proposals. If missing the + /// first line of the text is assumed. + int64_t line = 0; +}; +bool fromJSON(const llvm::json::Value &, CompletionsArguments &, + llvm::json::Path); + +/// Response to `completions` request. +struct CompletionsResponseBody { + /// The possible completions for a given caret position and text. + std::vector targets; +}; +llvm::json::Value toJSON(const CompletionsResponseBody &); + /// Arguments for `configurationDone` request. using ConfigurationDoneArguments = EmptyArguments; @@ -455,7 +486,7 @@ struct ScopesArguments { /// Retrieve the scopes for the stack frame identified by `frameId`. The /// `frameId` must have been obtained in the current suspended state. See /// 'Lifetime of Object References' in the Overview section for details. - uint64_t frameId = LLDB_INVALID_FRAME_ID; + uint64_t frameId = LLDB_DAP_INVALID_FRAME_ID; }; bool fromJSON(const llvm::json::Value &, ScopesArguments &, llvm::json::Path); @@ -541,7 +572,7 @@ using StepInResponse = VoidResponse; /// Arguments for `stepInTargets` request. struct StepInTargetsArguments { /// The stack frame for which to retrieve the possible step-in targets. - uint64_t frameId = LLDB_INVALID_FRAME_ID; + uint64_t frameId = LLDB_DAP_INVALID_FRAME_ID; }; bool fromJSON(const llvm::json::Value &, StepInTargetsArguments &, llvm::json::Path); @@ -690,7 +721,7 @@ struct DataBreakpointInfoArguments { /// When `name` is an expression, evaluate it in the scope of this stack /// frame. If not specified, the expression is evaluated in the global scope. /// When `asAddress` is true, the `frameId` is ignored. - std::optional frameId; + uint64_t frameId = LLDB_DAP_INVALID_FRAME_ID; /// If specified, a debug adapter should return information for the range of /// memory extending `bytes` number of bytes from the address or variable @@ -981,6 +1012,30 @@ struct WriteMemoryResponseBody { }; llvm::json::Value toJSON(const WriteMemoryResponseBody &); +struct ModuleSymbolsArguments { + /// The module UUID for which to retrieve symbols. + std::string moduleId; + + /// The module path. + std::string moduleName; + + /// The index of the first symbol to return; if omitted, start at the + /// beginning. + std::optional startIndex; + + /// The number of symbols to return; if omitted, all symbols are returned. + std::optional count; +}; +bool fromJSON(const llvm::json::Value &, ModuleSymbolsArguments &, + llvm::json::Path); + +/// Response to `getModuleSymbols` request. +struct ModuleSymbolsResponseBody { + /// The symbols for the specified module. + std::vector symbols; +}; +llvm::json::Value toJSON(const ModuleSymbolsResponseBody &); + } // namespace lldb_dap::protocol #endif diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.cpp index 369858c3a5f4b..dc8edaadcd9bb 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.cpp +++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.cpp @@ -200,6 +200,123 @@ bool fromJSON(const llvm::json::Value &Params, ChecksumAlgorithm &CA, return true; } +bool fromJSON(const json::Value &Params, CompletionItemType &CIT, + json::Path P) { + auto raw_item_type = Params.getAsString(); + if (!raw_item_type) { + P.report("expected a string"); + return false; + } + + std::optional item_type = + StringSwitch>(*raw_item_type) + .Case("method", eCompletionItemTypeMethod) + .Case("function", eCompletionItemTypeFunction) + .Case("constructor", eCompletionItemTypeConstructor) + .Case("field", eCompletionItemTypeField) + .Case("variable", eCompletionItemTypeVariable) + .Case("class", eCompletionItemTypeClass) + .Case("interface", eCompletionItemTypeInterface) + .Case("module", eCompletionItemTypeModule) + .Case("property", eCompletionItemTypeProperty) + .Case("unit", eCompletionItemTypeUnit) + .Case("value", eCompletionItemTypeValue) + .Case("enum", eCompletionItemTypeEnum) + .Case("keyword", eCompletionItemTypeKeyword) + .Case("snippet", eCompletionItemTypeSnippet) + .Case("text", eCompletionItemTypeText) + .Case("color", eCompletionItemTypeColor) + .Case("file", eCompletionItemTypeFile) + .Case("reference", eCompletionItemTypeReference) + .Case("customcolor", eCompletionItemTypeCustomColor) + .Default(std::nullopt); + + if (!item_type) { + P.report("unexpected value"); + return false; + } + + CIT = *item_type; + return true; +} + +json::Value toJSON(const CompletionItemType &CIT) { + switch (CIT) { + case eCompletionItemTypeMethod: + return "method"; + case eCompletionItemTypeFunction: + return "function"; + case eCompletionItemTypeConstructor: + return "constructor"; + case eCompletionItemTypeField: + return "field"; + case eCompletionItemTypeVariable: + return "variable"; + case eCompletionItemTypeClass: + return "class"; + case eCompletionItemTypeInterface: + return "interface"; + case eCompletionItemTypeModule: + return "module"; + case eCompletionItemTypeProperty: + return "property"; + case eCompletionItemTypeUnit: + return "unit"; + case eCompletionItemTypeValue: + return "value"; + case eCompletionItemTypeEnum: + return "enum"; + case eCompletionItemTypeKeyword: + return "keyword"; + case eCompletionItemTypeSnippet: + return "snippet"; + case eCompletionItemTypeText: + return "text"; + case eCompletionItemTypeColor: + return "color"; + case eCompletionItemTypeFile: + return "file"; + case eCompletionItemTypeReference: + return "reference"; + case eCompletionItemTypeCustomColor: + return "customcolor"; + } + llvm_unreachable("unhandled CompletionItemType."); +} + +bool fromJSON(const json::Value &Params, CompletionItem &CI, json::Path P) { + json::ObjectMapper O(Params, P); + return O && O.map("label", CI.label) && O.mapOptional("text", CI.text) && + O.mapOptional("sortText", CI.sortText) && + O.mapOptional("detail", CI.detail) && O.mapOptional("type", CI.type) && + O.mapOptional("start", CI.start) && + O.mapOptional("length", CI.length) && + O.mapOptional("selectionStart", CI.selectionStart) && + O.mapOptional("selectionLength", CI.selectionLength); +} +json::Value toJSON(const CompletionItem &CI) { + json::Object result{{"label", CI.label}}; + + if (!CI.text.empty()) + result.insert({"text", CI.text}); + if (!CI.sortText.empty()) + result.insert({"sortText", CI.sortText}); + if (!CI.detail.empty()) + result.insert({"detail", CI.detail}); + if (CI.type) + result.insert({"type", CI.type}); + if (CI.start) + result.insert({"start", CI.start}); + if (CI.length) + result.insert({"length", CI.length}); + if (CI.selectionStart) + result.insert({"selectionStart", CI.selectionStart}); + if (CI.selectionLength) + result.insert({"selectionLength", CI.selectionLength}); + + return result; +} + json::Value toJSON(const BreakpointModeApplicability &BMA) { switch (BMA) { case eBreakpointModeApplicabilitySource: @@ -335,6 +452,8 @@ static llvm::StringLiteral ToString(AdapterFeature feature) { return "supportsWriteMemoryRequest"; case eAdapterFeatureTerminateDebuggee: return "supportTerminateDebuggee"; + case eAdapterFeatureSupportsModuleSymbolsRequest: + return "supportsModuleSymbolsRequest"; } llvm_unreachable("unhandled adapter feature."); } @@ -406,6 +525,8 @@ bool fromJSON(const llvm::json::Value &Params, AdapterFeature &feature, eAdapterFeatureValueFormattingOptions) .Case("supportsWriteMemoryRequest", eAdapterFeatureWriteMemoryRequest) .Case("supportTerminateDebuggee", eAdapterFeatureTerminateDebuggee) + .Case("supportsModuleSymbolsRequest", + eAdapterFeatureSupportsModuleSymbolsRequest) .Default(std::nullopt); if (!parsedFeature) { diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h index c4be7911a662b..7077df90a85b5 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolTypes.h @@ -109,6 +109,84 @@ enum ChecksumAlgorithm : unsigned { bool fromJSON(const llvm::json::Value &, ChecksumAlgorithm &, llvm::json::Path); llvm::json::Value toJSON(const ChecksumAlgorithm &); +/// Some predefined types for the CompletionItem. Please note that not all +/// clients have specific icons for all of them. +enum CompletionItemType : unsigned { + eCompletionItemTypeMethod, + eCompletionItemTypeFunction, + eCompletionItemTypeConstructor, + eCompletionItemTypeField, + eCompletionItemTypeVariable, + eCompletionItemTypeClass, + eCompletionItemTypeInterface, + eCompletionItemTypeModule, + eCompletionItemTypeProperty, + eCompletionItemTypeUnit, + eCompletionItemTypeValue, + eCompletionItemTypeEnum, + eCompletionItemTypeKeyword, + eCompletionItemTypeSnippet, + eCompletionItemTypeText, + eCompletionItemTypeColor, + eCompletionItemTypeFile, + eCompletionItemTypeReference, + eCompletionItemTypeCustomColor, +}; +bool fromJSON(const llvm::json::Value &, CompletionItemType &, + llvm::json::Path); +llvm::json::Value toJSON(const CompletionItemType &); + +/// `CompletionItems` are the suggestions returned from the `completions` +/// request. +struct CompletionItem { + /// The label of this completion item. By default this is also the text that + /// is inserted when selecting this completion. + std::string label; + + /// If text is returned and not an empty string, then it is inserted instead + /// of the label. + std::string text; + + /// A string that should be used when comparing this item with other items. If + /// not returned or an empty string, the `label` is used instead. + std::string sortText; + + /// A human-readable string with additional information about this item, like + /// type or symbol information. + std::string detail; + + /// The item's type. Typically the client uses this information to render the + /// item in the UI with an icon. + std::optional type; + + /// Start position (within the `text` attribute of the `completions` + /// request) where the completion text is added. The position is measured in + /// UTF-16 code units and the client capability `columnsStartAt1` determines + /// whether it is 0- or 1-based. If the start position is omitted the text + /// is added at the location specified by the `column` attribute of the + /// `completions` request. + int64_t start = 0; + + /// Length determines how many characters are overwritten by the completion + /// text and it is measured in UTF-16 code units. If missing the value 0 is + /// assumed which results in the completion text being inserted. + int64_t length = 0; + + /// Determines the start of the new selection after the text has been + /// inserted (or replaced). `selectionStart` is measured in UTF-16 code + /// units and must be in the range 0 and length of the completion text. If + /// omitted the selection starts at the end of the completion text. + int64_t selectionStart = 0; + + /// Determines the length of the new selection after the text has been + /// inserted (or replaced) and it is measured in UTF-16 code units. The + /// selection can not extend beyond the bounds of the completion text. If + /// omitted the length is assumed to be 0. + int64_t selectionLength = 0; +}; +bool fromJSON(const llvm::json::Value &, CompletionItem &, llvm::json::Path); +llvm::json::Value toJSON(const CompletionItem &); + /// Describes one or more type of breakpoint a BreakpointMode applies to. This /// is a non-exhaustive enumeration and may expand as future breakpoint types /// are added. @@ -242,8 +320,11 @@ enum AdapterFeature : unsigned { /// The debug adapter supports the `terminateDebuggee` attribute on the /// `disconnect` request. eAdapterFeatureTerminateDebuggee, + /// The debug adapter supports the `supportsModuleSymbols` request. + /// This request is a custom request of lldb-dap. + eAdapterFeatureSupportsModuleSymbolsRequest, eAdapterFeatureFirst = eAdapterFeatureANSIStyling, - eAdapterFeatureLast = eAdapterFeatureTerminateDebuggee, + eAdapterFeatureLast = eAdapterFeatureSupportsModuleSymbolsRequest, }; bool fromJSON(const llvm::json::Value &, AdapterFeature &, llvm::json::Path); llvm::json::Value toJSON(const AdapterFeature &); diff --git a/lldb/tools/lldb-dap/Transport.cpp b/lldb/tools/lldb-dap/Transport.cpp index d602920da34e3..8f71f88cae1f7 100644 --- a/lldb/tools/lldb-dap/Transport.cpp +++ b/lldb/tools/lldb-dap/Transport.cpp @@ -14,7 +14,8 @@ using namespace llvm; using namespace lldb; using namespace lldb_private; -using namespace lldb_dap; + +namespace lldb_dap { Transport::Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output) @@ -24,3 +25,5 @@ Transport::Transport(llvm::StringRef client_name, lldb_dap::Log *log, void Transport::Log(llvm::StringRef message) { DAP_LOG(m_log, "({0}) {1}", m_client_name, message); } + +} // namespace lldb_dap diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 51f62e718a0d0..4a9dd76c2303e 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -15,6 +15,7 @@ #define LLDB_TOOLS_LLDB_DAP_TRANSPORT_H #include "DAPForward.h" +#include "Protocol/ProtocolBase.h" #include "lldb/Host/JSONTransport.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" @@ -23,17 +24,15 @@ namespace lldb_dap { /// A transport class that performs the Debug Adapter Protocol communication /// with the client. -class Transport : public lldb_private::HTTPDelimitedJSONTransport { +class Transport final + : public lldb_private::HTTPDelimitedJSONTransport< + protocol::Request, protocol::Response, protocol::Event> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); virtual ~Transport() = default; - virtual void Log(llvm::StringRef message) override; - - /// Returns the name of this transport client, for example `stdin/stdout` or - /// `client_1`. - llvm::StringRef GetClientName() { return m_client_name; } + void Log(llvm::StringRef message) override; private: llvm::StringRef m_client_name; diff --git a/lldb/tools/lldb-dap/package-lock.json b/lldb/tools/lldb-dap/package-lock.json index 1969b196accc6..26db1ce6df2fd 100644 --- a/lldb/tools/lldb-dap/package-lock.json +++ b/lldb/tools/lldb-dap/package-lock.json @@ -1,20 +1,24 @@ { "name": "lldb-dap", - "version": "0.2.15", + "version": "0.2.16", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "lldb-dap", - "version": "0.2.15", + "version": "0.2.16", "license": "Apache 2.0 License with LLVM exceptions", "devDependencies": { "@types/node": "^18.19.41", + "@types/tabulator-tables": "^6.2.10", "@types/vscode": "1.75.0", + "@types/vscode-webview": "^1.57.5", "@vscode/debugprotocol": "^1.68.0", "@vscode/vsce": "^3.2.2", + "esbuild": "^0.25.9", "prettier": "^3.4.2", "prettier-plugin-curly": "^0.3.1", + "tabulator-tables": "^6.3.1", "typescript": "^5.7.3" }, "engines": { @@ -318,6 +322,448 @@ "node": ">=6.9.0" } }, + "node_modules/@esbuild/aix-ppc64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.9.tgz", + "integrity": "sha512-OaGtL73Jck6pBKjNIe24BnFE6agGl+6KxDtTfHhy1HmhthfKouEcOhqpSL64K4/0WCtbKFLOdzD/44cJ4k9opA==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.25.9.tgz", + "integrity": "sha512-5WNI1DaMtxQ7t7B6xa572XMXpHAaI/9Hnhk8lcxF4zVN4xstUgTlvuGDorBguKEnZO70qwEcLpfifMLoxiPqHQ==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.25.9.tgz", + "integrity": "sha512-IDrddSmpSv51ftWslJMvl3Q2ZT98fUSL2/rlUXuVqRXHCs5EUF1/f+jbjF5+NG9UffUDMCiTyh8iec7u8RlTLg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.25.9.tgz", + "integrity": "sha512-I853iMZ1hWZdNllhVZKm34f4wErd4lMyeV7BLzEExGEIZYsOzqDWDf+y082izYUE8gtJnYHdeDpN/6tUdwvfiw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.25.9.tgz", + "integrity": "sha512-XIpIDMAjOELi/9PB30vEbVMs3GV1v2zkkPnuyRRURbhqjyzIINwj+nbQATh4H9GxUgH1kFsEyQMxwiLFKUS6Rg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.25.9.tgz", + "integrity": "sha512-jhHfBzjYTA1IQu8VyrjCX4ApJDnH+ez+IYVEoJHeqJm9VhG9Dh2BYaJritkYK3vMaXrf7Ogr/0MQ8/MeIefsPQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.25.9.tgz", + "integrity": "sha512-z93DmbnY6fX9+KdD4Ue/H6sYs+bhFQJNCPZsi4XWJoYblUqT06MQUdBCpcSfuiN72AbqeBFu5LVQTjfXDE2A6Q==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.25.9.tgz", + "integrity": "sha512-mrKX6H/vOyo5v71YfXWJxLVxgy1kyt1MQaD8wZJgJfG4gq4DpQGpgTB74e5yBeQdyMTbgxp0YtNj7NuHN0PoZg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.25.9.tgz", + "integrity": "sha512-HBU2Xv78SMgaydBmdor38lg8YDnFKSARg1Q6AT0/y2ezUAKiZvc211RDFHlEZRFNRVhcMamiToo7bDx3VEOYQw==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.25.9.tgz", + "integrity": "sha512-BlB7bIcLT3G26urh5Dmse7fiLmLXnRlopw4s8DalgZ8ef79Jj4aUcYbk90g8iCa2467HX8SAIidbL7gsqXHdRw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.25.9.tgz", + "integrity": "sha512-e7S3MOJPZGp2QW6AK6+Ly81rC7oOSerQ+P8L0ta4FhVi+/j/v2yZzx5CqqDaWjtPFfYz21Vi1S0auHrap3Ma3A==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.25.9.tgz", + "integrity": "sha512-Sbe10Bnn0oUAB2AalYztvGcK+o6YFFA/9829PhOCUS9vkJElXGdphz0A3DbMdP8gmKkqPmPcMJmJOrI3VYB1JQ==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.25.9.tgz", + "integrity": "sha512-YcM5br0mVyZw2jcQeLIkhWtKPeVfAerES5PvOzaDxVtIyZ2NUBZKNLjC5z3/fUlDgT6w89VsxP2qzNipOaaDyA==", + "cpu": [ + "mips64el" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.25.9.tgz", + "integrity": "sha512-++0HQvasdo20JytyDpFvQtNrEsAgNG2CY1CLMwGXfFTKGBGQT3bOeLSYE2l1fYdvML5KUuwn9Z8L1EWe2tzs1w==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.25.9.tgz", + "integrity": "sha512-uNIBa279Y3fkjV+2cUjx36xkx7eSjb8IvnL01eXUKXez/CBHNRw5ekCGMPM0BcmqBxBcdgUWuUXmVWwm4CH9kg==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.25.9.tgz", + "integrity": "sha512-Mfiphvp3MjC/lctb+7D287Xw1DGzqJPb/J2aHHcHxflUo+8tmN/6d4k6I2yFR7BVo5/g7x2Monq4+Yew0EHRIA==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.25.9.tgz", + "integrity": "sha512-iSwByxzRe48YVkmpbgoxVzn76BXjlYFXC7NvLYq+b+kDjyyk30J0JY47DIn8z1MO3K0oSl9fZoRmZPQI4Hklzg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-arm64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.25.9.tgz", + "integrity": "sha512-9jNJl6FqaUG+COdQMjSCGW4QiMHH88xWbvZ+kRVblZsWrkXlABuGdFJ1E9L7HK+T0Yqd4akKNa/lO0+jDxQD4Q==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.25.9.tgz", + "integrity": "sha512-RLLdkflmqRG8KanPGOU7Rpg829ZHu8nFy5Pqdi9U01VYtG9Y0zOG6Vr2z4/S+/3zIyOxiK6cCeYNWOFR9QP87g==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-arm64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.25.9.tgz", + "integrity": "sha512-YaFBlPGeDasft5IIM+CQAhJAqS3St3nJzDEgsgFixcfZeyGPCd6eJBWzke5piZuZ7CtL656eOSYKk4Ls2C0FRQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.25.9.tgz", + "integrity": "sha512-1MkgTCuvMGWuqVtAvkpkXFmtL8XhWy+j4jaSO2wxfJtilVCi0ZE37b8uOdMItIHz4I6z1bWWtEX4CJwcKYLcuA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openharmony-arm64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/openharmony-arm64/-/openharmony-arm64-0.25.9.tgz", + "integrity": "sha512-4Xd0xNiMVXKh6Fa7HEJQbrpP3m3DDn43jKxMjxLLRjWnRsfxjORYJlXPO4JNcXtOyfajXorRKY9NkOpTHptErg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.25.9.tgz", + "integrity": "sha512-WjH4s6hzo00nNezhp3wFIAfmGZ8U7KtrJNlFMRKxiI9mxEK1scOMAaa9i4crUtu+tBr+0IN6JCuAcSBJZfnphw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.25.9.tgz", + "integrity": "sha512-mGFrVJHmZiRqmP8xFOc6b84/7xa5y5YvR1x8djzXpJBSv/UsNK6aqec+6JDjConTgvvQefdGhFDAs2DLAds6gQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.25.9.tgz", + "integrity": "sha512-b33gLVU2k11nVx1OhX3C8QQP6UHQK4ZtN56oFWvVXvz2VkDoe6fbG8TOgHFxEvqeqohmRnIHe5A1+HADk4OQww==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.25.9.tgz", + "integrity": "sha512-PPOl1mi6lpLNQxnGoyAfschAodRFYXJ+9fs6WHXz7CSWKbOqiMZsubC+BQsVKuul+3vKLuwTHsS2c2y9EoKwxQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, "node_modules/@isaacs/cliui": { "version": "8.0.2", "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", @@ -399,6 +845,13 @@ "undici-types": "~5.26.4" } }, + "node_modules/@types/tabulator-tables": { + "version": "6.2.10", + "resolved": "https://registry.npmjs.org/@types/tabulator-tables/-/tabulator-tables-6.2.10.tgz", + "integrity": "sha512-g6o0gG3lu/ozmxPw9rLY1p57T6rvV8OhbJKyzWwPwjdnN3JuSQ3gWxb06v2+dl2tdoqNXTvlylipSSKpS8UzzQ==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/vscode": { "version": "1.75.0", "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.75.0.tgz", @@ -406,6 +859,13 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/vscode-webview": { + "version": "1.57.5", + "resolved": "https://registry.npmjs.org/@types/vscode-webview/-/vscode-webview-1.57.5.tgz", + "integrity": "sha512-iBAUYNYkz+uk1kdsq05fEcoh8gJmwT3lqqFPN7MGyjQ3HVloViMdo7ZJ8DFIP8WOK74PjOEilosqAyxV2iUFUw==", + "dev": true, + "license": "MIT" + }, "node_modules/@vscode/debugprotocol": { "version": "1.68.0", "resolved": "https://registry.npmjs.org/@vscode/debugprotocol/-/debugprotocol-1.68.0.tgz", @@ -1169,6 +1629,48 @@ "node": ">= 0.4" } }, + "node_modules/esbuild": { + "version": "0.25.9", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.9.tgz", + "integrity": "sha512-CRbODhYyQx3qp7ZEwzxOk4JBqmD/seJrzPa/cGjY1VtIn5E09Oi9/dB4JwctnfZ8Q8iT7rioVv5k/FNT/uf54g==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.25.9", + "@esbuild/android-arm": "0.25.9", + "@esbuild/android-arm64": "0.25.9", + "@esbuild/android-x64": "0.25.9", + "@esbuild/darwin-arm64": "0.25.9", + "@esbuild/darwin-x64": "0.25.9", + "@esbuild/freebsd-arm64": "0.25.9", + "@esbuild/freebsd-x64": "0.25.9", + "@esbuild/linux-arm": "0.25.9", + "@esbuild/linux-arm64": "0.25.9", + "@esbuild/linux-ia32": "0.25.9", + "@esbuild/linux-loong64": "0.25.9", + "@esbuild/linux-mips64el": "0.25.9", + "@esbuild/linux-ppc64": "0.25.9", + "@esbuild/linux-riscv64": "0.25.9", + "@esbuild/linux-s390x": "0.25.9", + "@esbuild/linux-x64": "0.25.9", + "@esbuild/netbsd-arm64": "0.25.9", + "@esbuild/netbsd-x64": "0.25.9", + "@esbuild/openbsd-arm64": "0.25.9", + "@esbuild/openbsd-x64": "0.25.9", + "@esbuild/openharmony-arm64": "0.25.9", + "@esbuild/sunos-x64": "0.25.9", + "@esbuild/win32-arm64": "0.25.9", + "@esbuild/win32-ia32": "0.25.9", + "@esbuild/win32-x64": "0.25.9" + } + }, "node_modules/escape-string-regexp": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz", @@ -2557,6 +3059,13 @@ "node": ">=4" } }, + "node_modules/tabulator-tables": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/tabulator-tables/-/tabulator-tables-6.3.1.tgz", + "integrity": "sha512-qFW7kfadtcaISQIibKAIy0f3eeIXUVi8242Vly1iJfMD79kfEGzfczNuPBN/80hDxHzQJXYbmJ8VipI40hQtfA==", + "dev": true, + "license": "MIT" + }, "node_modules/tar-fs": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz", diff --git a/lldb/tools/lldb-dap/package.json b/lldb/tools/lldb-dap/package.json index d677a81cc7974..8c6c1b4ae6ebb 100644 --- a/lldb/tools/lldb-dap/package.json +++ b/lldb/tools/lldb-dap/package.json @@ -29,11 +29,15 @@ ], "devDependencies": { "@types/node": "^18.19.41", + "@types/tabulator-tables": "^6.2.10", "@types/vscode": "1.75.0", + "@types/vscode-webview": "^1.57.5", "@vscode/debugprotocol": "^1.68.0", "@vscode/vsce": "^3.2.2", + "esbuild": "^0.25.9", "prettier": "^3.4.2", "prettier-plugin-curly": "^0.3.1", + "tabulator-tables": "^6.3.1", "typescript": "^5.7.3" }, "activationEvents": [ @@ -42,8 +46,11 @@ ], "main": "./out/extension", "scripts": { - "vscode:prepublish": "tsc -p ./", - "watch": "tsc -watch -p ./", + "bundle-symbols-table-view": "npx tsc -p src-ts/webview --noEmit && npx esbuild src-ts/webview/symbols-table-view.ts --bundle --format=iife --outdir=./out/webview", + "bundle-tabulator": "cp node_modules/tabulator-tables/dist/js/tabulator.min.js ./out/webview/ && cp node_modules/tabulator-tables/dist/css/tabulator_midnight.min.css ./out/webview/ && cp node_modules/tabulator-tables/dist/css/tabulator_simple.min.css ./out/webview/", + "bundle-webview": "npm run bundle-symbols-table-view && npm run bundle-tabulator", + "vscode:prepublish": "npm run bundle-webview && tsc -p ./", + "watch": "npm run bundle-webview && tsc -watch -p ./", "format": "npx prettier './src-ts/' --write", "package": "rm -rf ./out/lldb-dap.vsix && vsce package --out ./out/lldb-dap.vsix", "publish": "vsce publish", @@ -259,6 +266,15 @@ { "command": "lldb-dap.modules.copyProperty", "title": "Copy Value" + }, + { + "command": "lldb-dap.modules.showSymbols", + "title": "Show Module Symbols" + }, + { + "category": "lldb-dap", + "command": "lldb-dap.debug.showSymbols", + "title": "Show Symbols of a Module" } ], "menus": { @@ -266,12 +282,24 @@ { "command": "lldb-dap.modules.copyProperty", "when": "false" + }, + { + "command": "lldb-dap.modules.showSymbols", + "when": "false" + }, + { + "command": "lldb-dap.debug.showSymbols", + "when": "debuggersAvailable && debugType == 'lldb-dap' && lldb-dap.supportsModuleSymbolsRequest" } ], "view/item/context": [ { "command": "lldb-dap.modules.copyProperty", "when": "view == lldb-dap.modules && viewItem == property" + }, + { + "command": "lldb-dap.modules.showSymbols", + "when": "view == lldb-dap.modules && viewItem == module && lldb-dap.supportsModuleSymbolsRequest" } ] }, @@ -370,6 +398,29 @@ }, "markdownDescription": "The list of additional arguments used to launch the debug adapter executable. Overrides any user or workspace settings." }, + "debugAdapterEnv": { + "anyOf": [ + { + "type": "object", + "markdownDescription": "Additional environment variables to set when launching the debug adapter executable. E.g. `{ \"FOO\": \"1\" }`", + "patternProperties": { + ".*": { + "type": "string" + } + }, + "default": {} + }, + { + "type": "array", + "markdownDescription": "Additional environment variables to set when launching the debug adapter executable. E.g. `[\"FOO=1\", \"BAR\"]`", + "items": { + "type": "string", + "pattern": "^((\\w+=.*)|^\\w+)$" + }, + "default": [] + } + ] + }, "program": { "type": "string", "description": "Path to the program to debug." diff --git a/lldb/tools/lldb-dap/src-ts/debug-adapter-factory.ts b/lldb/tools/lldb-dap/src-ts/debug-adapter-factory.ts index 157aa2ac76a1f..f7e92ee95ca32 100644 --- a/lldb/tools/lldb-dap/src-ts/debug-adapter-factory.ts +++ b/lldb/tools/lldb-dap/src-ts/debug-adapter-factory.ts @@ -68,6 +68,40 @@ async function findDAPExecutable(): Promise { return undefined; } +/** + * Validates the DAP environment provided in the debug configuration. + * It must be a dictionary of string keys and values OR an array of string values. + * + * @param debugConfigEnv The supposed DAP environment that will be validated + * @returns Whether or not the DAP environment is valid + */ +function validateDAPEnv(debugConfigEnv: any): boolean { + // If the env is an object, it should have string values. + // The keys are guaranteed to be strings. + if ( + typeof debugConfigEnv === "object" && + Object.values(debugConfigEnv).findIndex( + (entry) => typeof entry !== "string", + ) !== -1 + ) { + return false; + } + + // If the env is an array, it should have string values which match the regex. + if ( + Array.isArray(debugConfigEnv) && + debugConfigEnv.findIndex( + (entry) => + typeof entry !== "string" || !/^((\\w+=.*)|^\\w+)$/.test(entry), + ) !== -1 + ) { + return false; + } + + // The env is valid. + return true; +} + /** * Retrieves the lldb-dap executable path either from settings or the provided * {@link vscode.DebugConfiguration}. @@ -157,6 +191,51 @@ async function getDAPArguments( .get("arguments", []); } +/** + * Retrieves the environment that will be provided to lldb-dap either from settings or the provided + * {@link vscode.DebugConfiguration}. + * + * @param workspaceFolder The {@link vscode.WorkspaceFolder} that the debug session will be launched within + * @param configuration The {@link vscode.DebugConfiguration} that will be launched + * @throws An {@link ErrorWithNotification} if something went wrong + * @returns The environment that will be provided to lldb-dap + */ +async function getDAPEnvironment( + workspaceFolder: vscode.WorkspaceFolder | undefined, + configuration: vscode.DebugConfiguration, +): Promise<{ [key: string]: string }> { + const debugConfigEnv = configuration.debugAdapterEnv; + if (debugConfigEnv) { + if (validateDAPEnv(debugConfigEnv) === false) { + throw new ErrorWithNotification( + "The debugAdapterEnv property must be a dictionary of string keys and values OR an array of string values. Please update your launch configuration", + new ConfigureButton(), + ); + } + + // Transform, so that the returned value is always a dictionary. + if (Array.isArray(debugConfigEnv)) { + const ret: { [key: string]: string } = {}; + for (const envVar of debugConfigEnv as string[]) { + const equalSignPos = envVar.search("="); + if (equalSignPos >= 0) { + ret[envVar.substr(0, equalSignPos)] = envVar.substr(equalSignPos + 1); + } else { + ret[envVar] = ""; + } + } + return ret; + } else { + return debugConfigEnv; + } + } + + const config = vscode.workspace.workspaceFile + ? vscode.workspace.getConfiguration("lldb-dap") + : vscode.workspace.getConfiguration("lldb-dap", workspaceFolder); + return config.get<{ [key: string]: string }>("environment") || {}; +} + /** * Creates a new {@link vscode.DebugAdapterExecutable} based on the provided workspace folder and * debug configuration. Assumes that the given debug configuration is for a local launch of lldb-dap. @@ -182,12 +261,16 @@ export async function createDebugAdapterExecutable( if (log_path) { env["LLDBDAP_LOG"] = log_path; } else if ( - vscode.workspace.getConfiguration("lldb-dap").get("captureSessionLogs", false) + vscode.workspace + .getConfiguration("lldb-dap") + .get("captureSessionLogs", false) ) { env["LLDBDAP_LOG"] = logFilePath.get(LogType.DEBUG_SESSION); } - const configEnvironment = - config.get<{ [key: string]: string }>("environment") || {}; + const configEnvironment = await getDAPEnvironment( + workspaceFolder, + configuration, + ); const dapPath = await getDAPExecutable(workspaceFolder, configuration); const dbgOptions = { diff --git a/lldb/tools/lldb-dap/src-ts/debug-configuration-provider.ts b/lldb/tools/lldb-dap/src-ts/debug-configuration-provider.ts index 1e16dac031125..1ae87116141f1 100644 --- a/lldb/tools/lldb-dap/src-ts/debug-configuration-provider.ts +++ b/lldb/tools/lldb-dap/src-ts/debug-configuration-provider.ts @@ -69,6 +69,10 @@ const configurations: Record = { terminateCommands: { type: "stringArray", default: [] }, }; +export function getDefaultConfigKey(key: string): string | number | boolean | string[] | undefined { + return configurations[key]?.default; +} + export class LLDBDapConfigurationProvider implements vscode.DebugConfigurationProvider { diff --git a/lldb/tools/lldb-dap/src-ts/debug-session-tracker.ts b/lldb/tools/lldb-dap/src-ts/debug-session-tracker.ts index 7d7f73dbff92d..6e89d441bbcf0 100644 --- a/lldb/tools/lldb-dap/src-ts/debug-session-tracker.ts +++ b/lldb/tools/lldb-dap/src-ts/debug-session-tracker.ts @@ -1,11 +1,17 @@ import { DebugProtocol } from "@vscode/debugprotocol"; import * as vscode from "vscode"; +export interface LLDBDapCapabilities extends DebugProtocol.Capabilities { + /** The debug adapter supports the `moduleSymbols` request. */ + supportsModuleSymbolsRequest?: boolean; +} + /** A helper type for mapping event types to their corresponding data type. */ // prettier-ignore interface EventMap { "module": DebugProtocol.ModuleEvent; "exited": DebugProtocol.ExitedEvent; + "capabilities": DebugProtocol.CapabilitiesEvent; } /** A type assertion to check if a ProtocolMessage is an event or if it is a specific event. */ @@ -39,6 +45,9 @@ export class DebugSessionTracker private modulesChanged = new vscode.EventEmitter< vscode.DebugSession | undefined >(); + private sessionReceivedCapabilities = + new vscode.EventEmitter<[ vscode.DebugSession, LLDBDapCapabilities ]>(); + private sessionExited = new vscode.EventEmitter(); /** * Fired when modules are changed for any active debug session. @@ -48,6 +57,15 @@ export class DebugSessionTracker onDidChangeModules: vscode.Event = this.modulesChanged.event; + /** Fired when a debug session is initialized. */ + onDidReceiveSessionCapabilities: + vscode.Event<[ vscode.DebugSession, LLDBDapCapabilities ]> = + this.sessionReceivedCapabilities.event; + + /** Fired when a debug session is exiting. */ + onDidExitSession: vscode.Event = + this.sessionExited.event; + constructor(private logger: vscode.LogOutputChannel) { this.onDidChangeModules(this.moduleChangedListener, this); vscode.debug.onDidChangeActiveDebugSession((session) => @@ -146,6 +164,10 @@ export class DebugSessionTracker this.logger.info( `Session "${session.name}" exited with code ${exitCode}`, ); + + this.sessionExited.fire(session); + } else if (isEvent(message, "capabilities")) { + this.sessionReceivedCapabilities.fire([ session, message.body.capabilities ]); } } } diff --git a/lldb/tools/lldb-dap/src-ts/extension.ts b/lldb/tools/lldb-dap/src-ts/extension.ts index 4b7a35e6944c6..7119cba972fa4 100644 --- a/lldb/tools/lldb-dap/src-ts/extension.ts +++ b/lldb/tools/lldb-dap/src-ts/extension.ts @@ -12,6 +12,7 @@ import { ModuleProperty, } from "./ui/modules-data-provider"; import { LogFilePathProvider } from "./logging"; +import { SymbolsProvider } from "./ui/symbols-provider"; /** * This class represents the extension and manages its life cycle. Other extensions @@ -19,6 +20,7 @@ import { LogFilePathProvider } from "./logging"; */ export class LLDBDapExtension extends DisposableContext { constructor( + context: vscode.ExtensionContext, logger: vscode.LogOutputChannel, logFilePath: LogFilePathProvider, outputChannel: vscode.OutputChannel, @@ -52,10 +54,12 @@ export class LLDBDapExtension extends DisposableContext { vscode.window.registerUriHandler(new LaunchUriHandler()), ); - vscode.commands.registerCommand( + this.pushSubscription(vscode.commands.registerCommand( "lldb-dap.modules.copyProperty", (node: ModuleProperty) => vscode.env.clipboard.writeText(node.value), - ); + )); + + this.pushSubscription(new SymbolsProvider(sessionTracker, context)); } } @@ -67,7 +71,7 @@ export async function activate(context: vscode.ExtensionContext) { outputChannel.info("LLDB-DAP extension activating..."); const logFilePath = new LogFilePathProvider(context, outputChannel); context.subscriptions.push( - new LLDBDapExtension(outputChannel, logFilePath, outputChannel), + new LLDBDapExtension(context, outputChannel, logFilePath, outputChannel), ); outputChannel.info("LLDB-DAP extension activated"); } diff --git a/lldb/tools/lldb-dap/src-ts/index.d.ts b/lldb/tools/lldb-dap/src-ts/index.d.ts new file mode 100644 index 0000000000000..d4618f44dee7b --- /dev/null +++ b/lldb/tools/lldb-dap/src-ts/index.d.ts @@ -0,0 +1,14 @@ +export {}; + +/// The symbol type we get from the lldb-dap server +export declare interface SymbolType { + id: number; + isDebug: boolean; + isSynthetic: boolean; + isExternal: boolean; + type: string; + fileAddress: number; + loadAddress?: number; + size: number; + name: string; +} diff --git a/lldb/tools/lldb-dap/src-ts/lldb-dap-server.ts b/lldb/tools/lldb-dap/src-ts/lldb-dap-server.ts index 5f9d8efdcb3a3..300b12d1cce1b 100644 --- a/lldb/tools/lldb-dap/src-ts/lldb-dap-server.ts +++ b/lldb/tools/lldb-dap/src-ts/lldb-dap-server.ts @@ -11,6 +11,7 @@ import * as vscode from "vscode"; export class LLDBDapServer implements vscode.Disposable { private serverProcess?: child_process.ChildProcessWithoutNullStreams; private serverInfo?: Promise<{ host: string; port: number }>; + private serverSpawnInfo?: string[]; constructor() { vscode.commands.registerCommand( @@ -34,7 +35,7 @@ export class LLDBDapServer implements vscode.Disposable { options?: child_process.SpawnOptionsWithoutStdio, ): Promise<{ host: string; port: number } | undefined> { const dapArgs = [...args, "--connection", "listen://localhost:0"]; - if (!(await this.shouldContinueStartup(dapPath, dapArgs))) { + if (!(await this.shouldContinueStartup(dapPath, dapArgs, options?.env))) { return undefined; } @@ -70,6 +71,7 @@ export class LLDBDapServer implements vscode.Disposable { } }); this.serverProcess = process; + this.serverSpawnInfo = this.getSpawnInfo(dapPath, dapArgs, options?.env); }); return this.serverInfo; } @@ -85,12 +87,14 @@ export class LLDBDapServer implements vscode.Disposable { private async shouldContinueStartup( dapPath: string, args: string[], + env: NodeJS.ProcessEnv | { [key: string]: string } | undefined, ): Promise { - if (!this.serverProcess || !this.serverInfo) { + if (!this.serverProcess || !this.serverInfo || !this.serverSpawnInfo) { return true; } - if (isDeepStrictEqual(this.serverProcess.spawnargs, [dapPath, ...args])) { + const newSpawnInfo = this.getSpawnInfo(dapPath, args, env); + if (isDeepStrictEqual(this.serverSpawnInfo, newSpawnInfo)) { return true; } @@ -102,11 +106,11 @@ export class LLDBDapServer implements vscode.Disposable { The previous lldb-dap server was started with: -${this.serverProcess.spawnargs.join(" ")} +${this.serverSpawnInfo.join(" ")} The new lldb-dap server will be started with: -${dapPath} ${args.join(" ")} +${newSpawnInfo.join(" ")} Restarting the server will interrupt any existing debug sessions and start a new server.`, }, @@ -143,4 +147,18 @@ Restarting the server will interrupt any existing debug sessions and start a new this.serverInfo = undefined; } } + + getSpawnInfo( + path: string, + args: string[], + env: NodeJS.ProcessEnv | { [key: string]: string } | undefined, + ): string[] { + return [ + path, + ...args, + ...Object.entries(env ?? {}).map( + (entry) => String(entry[0]) + "=" + String(entry[1]), + ), + ]; + } } diff --git a/lldb/tools/lldb-dap/src-ts/ui/modules-data-provider.ts b/lldb/tools/lldb-dap/src-ts/ui/modules-data-provider.ts index d0fb9270c734f..96343cb0a8da6 100644 --- a/lldb/tools/lldb-dap/src-ts/ui/modules-data-provider.ts +++ b/lldb/tools/lldb-dap/src-ts/ui/modules-data-provider.ts @@ -19,6 +19,7 @@ class ModuleItem extends vscode.TreeItem { constructor(module: DebugProtocol.Module) { super(module.name, vscode.TreeItemCollapsibleState.Collapsed); this.description = module.symbolStatus; + this.contextValue = "module"; } static getProperties(module: DebugProtocol.Module): ModuleProperty[] { diff --git a/lldb/tools/lldb-dap/src-ts/ui/symbols-provider.ts b/lldb/tools/lldb-dap/src-ts/ui/symbols-provider.ts new file mode 100644 index 0000000000000..84b9387ffe49f --- /dev/null +++ b/lldb/tools/lldb-dap/src-ts/ui/symbols-provider.ts @@ -0,0 +1,127 @@ +import * as vscode from "vscode"; +import { DebugProtocol } from "@vscode/debugprotocol"; + +import { DebugSessionTracker } from "../debug-session-tracker"; +import { DisposableContext } from "../disposable-context"; + +import { SymbolType } from ".."; +import { getSymbolsTableHTMLContent } from "./symbols-webview-html"; +import { getDefaultConfigKey } from "../debug-configuration-provider"; + +export class SymbolsProvider extends DisposableContext { + constructor( + private readonly tracker: DebugSessionTracker, + private readonly extensionContext: vscode.ExtensionContext, + ) { + super(); + + this.pushSubscription(vscode.commands.registerCommand( + "lldb-dap.debug.showSymbols", + () => { + const session = vscode.debug.activeDebugSession; + if (!session) return; + + this.SelectModuleAndShowSymbols(session); + }, + )); + + this.pushSubscription(vscode.commands.registerCommand( + "lldb-dap.modules.showSymbols", + (moduleItem: DebugProtocol.Module) => { + const session = vscode.debug.activeDebugSession; + if (!session) return; + + this.showSymbolsForModule(session, moduleItem); + }, + )); + + this.tracker.onDidReceiveSessionCapabilities(([ _session, capabilities ]) => { + if (capabilities.supportsModuleSymbolsRequest) { + vscode.commands.executeCommand( + "setContext", "lldb-dap.supportsModuleSymbolsRequest", true); + } + }); + + this.tracker.onDidExitSession((_session) => { + vscode.commands.executeCommand("setContext", "lldb-dap.supportsModuleSymbolsRequest", false); + }); + } + + private async SelectModuleAndShowSymbols(session: vscode.DebugSession) { + const modules = this.tracker.debugSessionModules(session); + if (!modules || modules.length === 0) { + return; + } + + // Let the user select a module to show symbols for + const selectedModule = await vscode.window.showQuickPick(modules.map(m => new ModuleQuickPickItem(m)), { + placeHolder: "Select a module to show symbols for" + }); + if (!selectedModule) { + return; + } + + this.showSymbolsForModule(session, selectedModule.module); + } + + private async showSymbolsForModule(session: vscode.DebugSession, module: DebugProtocol.Module) { + try { + const symbols = await this.getSymbolsForModule(session, module.id.toString()); + this.showSymbolsInNewTab(module.name.toString(), symbols); + } catch (error) { + if (error instanceof Error) { + vscode.window.showErrorMessage("Failed to retrieve symbols: " + error.message); + } else { + vscode.window.showErrorMessage("Failed to retrieve symbols due to an unknown error."); + } + + return; + } + } + + private async getSymbolsForModule(session: vscode.DebugSession, moduleId: string): Promise { + const symbols_response: { symbols: Array } = await session.customRequest("__lldb_moduleSymbols", { moduleId, moduleName: '' }); + return symbols_response?.symbols || []; + } + + private async showSymbolsInNewTab(moduleName: string, symbols: SymbolType[]) { + const panel = vscode.window.createWebviewPanel( + "lldb-dap.symbols", + `Symbols for ${moduleName}`, + vscode.ViewColumn.Active, + { + enableScripts: true, + localResourceRoots: [ + this.getExtensionResourcePath() + ] + } + ); + + let tabulatorJsFilename = "tabulator_simple.min.css"; + if (vscode.window.activeColorTheme.kind === vscode.ColorThemeKind.Dark || vscode.window.activeColorTheme.kind === vscode.ColorThemeKind.HighContrast) { + tabulatorJsFilename = "tabulator_midnight.min.css"; + } + const tabulatorCssPath = panel.webview.asWebviewUri(vscode.Uri.joinPath(this.getExtensionResourcePath(), tabulatorJsFilename)); + const tabulatorJsPath = panel.webview.asWebviewUri(vscode.Uri.joinPath(this.getExtensionResourcePath(), "tabulator.min.js")); + const symbolsTableScriptPath = panel.webview.asWebviewUri(vscode.Uri.joinPath(this.getExtensionResourcePath(), "symbols-table-view.js")); + + panel.webview.html = getSymbolsTableHTMLContent(tabulatorJsPath, tabulatorCssPath, symbolsTableScriptPath); + panel.webview.postMessage({ command: "updateSymbols", symbols: symbols }); + } + + private getExtensionResourcePath(): vscode.Uri { + return vscode.Uri.joinPath(this.extensionContext.extensionUri, "out", "webview"); + } +} + +class ModuleQuickPickItem implements vscode.QuickPickItem { + constructor(public readonly module: DebugProtocol.Module) {} + + get label(): string { + return this.module.name; + } + + get description(): string { + return this.module.id.toString(); + } +} diff --git a/lldb/tools/lldb-dap/src-ts/ui/symbols-webview-html.ts b/lldb/tools/lldb-dap/src-ts/ui/symbols-webview-html.ts new file mode 100644 index 0000000000000..c00e0d462569a --- /dev/null +++ b/lldb/tools/lldb-dap/src-ts/ui/symbols-webview-html.ts @@ -0,0 +1,67 @@ +import * as vscode from "vscode"; + +export function getSymbolsTableHTMLContent(tabulatorJsPath: vscode.Uri, tabulatorCssPath: vscode.Uri, symbolsTableScriptPath: vscode.Uri): string { + return ` + + + + + + + +
+ + + +`; +} \ No newline at end of file diff --git a/lldb/tools/lldb-dap/src-ts/webview/symbols-table-view.ts b/lldb/tools/lldb-dap/src-ts/webview/symbols-table-view.ts new file mode 100644 index 0000000000000..9d346818e384a --- /dev/null +++ b/lldb/tools/lldb-dap/src-ts/webview/symbols-table-view.ts @@ -0,0 +1,115 @@ +import type { CellComponent, ColumnDefinition } from "tabulator-tables"; +import type { SymbolType } from ".." + +/// SVG from https://github.com/olifolkerd/tabulator/blob/master/src/js/modules/Format/defaults/formatters/tickCross.js +/// but with the default font color. +/// hopefully in the future we can set the color as parameter: https://github.com/olifolkerd/tabulator/pull/4791 +const TICK_ELEMENT = ``; + +function getTabulatorHexaFormatter(padding: number): (cell: CellComponent) => string { + return (cell: CellComponent) => { + const val = cell.getValue(); + if (val === undefined || val === null) { + return ""; + } + + return val !== undefined ? "0x" + val.toString(16).toLowerCase().padStart(padding, "0") : ""; + }; +} + +const SYMBOL_TABLE_COLUMNS: ColumnDefinition[] = [ + { title: "ID", field: "id", headerTooltip: true, sorter: "number", widthGrow: 0.6 }, + { + title: "Name", + field: "name", + headerTooltip: true, + sorter: "string", + widthGrow: 2.5, + minWidth: 200, + tooltip : (_event: MouseEvent, cell: CellComponent) => { + const rowData = cell.getRow().getData(); + return rowData.name; + } + }, + { + title: "Debug", + field: "isDebug", + headerTooltip: true, + hozAlign: "center", + widthGrow: 0.8, + formatter: "tickCross", + formatterParams: { + tickElement: TICK_ELEMENT, + crossElement: false, + } + }, + { + title: "Synthetic", + field: "isSynthetic", + headerTooltip: true, + hozAlign: "center", + widthGrow: 0.8, + formatter: "tickCross", + formatterParams: { + tickElement: TICK_ELEMENT, + crossElement: false, + } + }, + { + title: "External", + field: "isExternal", + headerTooltip: true, + hozAlign: "center", + widthGrow: 0.8, + formatter: "tickCross", + formatterParams: { + tickElement: TICK_ELEMENT, + crossElement: false, + } + }, + { title: "Type", field: "type", sorter: "string" }, + { + title: "File Address", + field: "fileAddress", + headerTooltip: true, + sorter: "number", + widthGrow : 1.25, + formatter: getTabulatorHexaFormatter(16), + }, + { + title: "Load Address", + field: "loadAddress", + headerTooltip: true, + sorter: "number", + widthGrow : 1.25, + formatter: getTabulatorHexaFormatter(16), + }, + { title: "Size", field: "size", headerTooltip: true, sorter: "number", formatter: getTabulatorHexaFormatter(8) }, +]; + +const vscode = acquireVsCodeApi(); +const previousState: any = vscode.getState(); + +declare const Tabulator: any; // HACK: real definition comes from tabulator.min.js +const SYMBOLS_TABLE = new Tabulator("#symbols-table", { + height: "100vh", + columns: SYMBOL_TABLE_COLUMNS, + layout: "fitColumns", + selectableRows: false, + data: previousState?.symbols || [], +}); + +function updateSymbolsTable(symbols: SymbolType[]) { + SYMBOLS_TABLE.setData(symbols); +} + +window.addEventListener("message", (event: MessageEvent) => { + const message = event.data; + switch (message.command) { + case "updateSymbols": + vscode.setState({ symbols: message.symbols }); + updateSymbolsTable(message.symbols); + break; + } +}); + diff --git a/lldb/tools/lldb-dap/src-ts/webview/tsconfig.json b/lldb/tools/lldb-dap/src-ts/webview/tsconfig.json new file mode 100644 index 0000000000000..cfe64fc4b989f --- /dev/null +++ b/lldb/tools/lldb-dap/src-ts/webview/tsconfig.json @@ -0,0 +1,15 @@ +{ + "compilerOptions": { + "moduleResolution": "node", + "module": "esnext", + "outDir": "out", + "rootDir": ".", + "sourceMap": true, + "strict": true, + "noEmit": true, + "target": "es2017" + }, + "include": [ + "./" + ], +} diff --git a/lldb/tools/lldb-dap/tool/lldb-dap.cpp b/lldb/tools/lldb-dap/tool/lldb-dap.cpp index 8bba4162aa7bf..b74085f25f4e2 100644 --- a/lldb/tools/lldb-dap/tool/lldb-dap.cpp +++ b/lldb/tools/lldb-dap/tool/lldb-dap.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/Support/Threading.h" +#include "llvm/Support/WithColor.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -284,7 +285,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, }); std::condition_variable dap_sessions_condition; std::mutex dap_sessions_mutex; - std::map dap_sessions; + std::map dap_sessions; unsigned int clientCount = 0; auto handle = listener->Accept(g_loop, [=, &dap_sessions_condition, &dap_sessions_mutex, &dap_sessions, @@ -300,8 +301,10 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, std::thread client([=, &dap_sessions_condition, &dap_sessions_mutex, &dap_sessions]() { llvm::set_thread_name(client_name + ".runloop"); + MainLoop loop; Transport transport(client_name, log, io, io); - DAP dap(log, default_repl_mode, pre_init_commands, transport); + DAP dap(log, default_repl_mode, pre_init_commands, client_name, transport, + loop); if (auto Err = dap.ConfigureIO()) { llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), @@ -311,7 +314,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, { std::scoped_lock lock(dap_sessions_mutex); - dap_sessions[io.get()] = &dap; + dap_sessions[&loop] = &dap; } if (auto Err = dap.Loop()) { @@ -322,7 +325,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, DAP_LOG(log, "({0}) client disconnected", client_name); std::unique_lock lock(dap_sessions_mutex); - dap_sessions.erase(io.get()); + dap_sessions.erase(&loop); std::notify_all_at_thread_exit(dap_sessions_condition, std::move(lock)); }); client.detach(); @@ -344,13 +347,14 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, bool client_failed = false; { std::scoped_lock lock(dap_sessions_mutex); - for (auto [sock, dap] : dap_sessions) { + for (auto [loop, dap] : dap_sessions) { if (llvm::Error error = dap->Disconnect()) { client_failed = true; - llvm::errs() << "DAP client " << dap->transport.GetClientName() - << " disconnected failed: " - << llvm::toString(std::move(error)) << "\n"; + llvm::WithColor::error() << "DAP client disconnected failed: " + << llvm::toString(std::move(error)) << "\n"; } + loop->AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); } } @@ -550,8 +554,10 @@ int main(int argc, char *argv[]) { stdout_fd, File::eOpenOptionWriteOnly, NativeFile::Unowned); constexpr llvm::StringLiteral client_name = "stdio"; + MainLoop loop; Transport transport(client_name, log.get(), input, output); - DAP dap(log.get(), default_repl_mode, pre_init_commands, transport); + DAP dap(log.get(), default_repl_mode, pre_init_commands, client_name, + transport, loop); // stdout/stderr redirection to the IDE's console if (auto Err = dap.ConfigureIO(stdout, stderr)) { diff --git a/lldb/tools/lldb-dap/tsconfig.json b/lldb/tools/lldb-dap/tsconfig.json index 2092148888904..06a484a1fc263 100644 --- a/lldb/tools/lldb-dap/tsconfig.json +++ b/lldb/tools/lldb-dap/tsconfig.json @@ -1,5 +1,6 @@ { "compilerOptions": { + "moduleResolution": "node", "module": "commonjs", "outDir": "out", "rootDir": "src-ts", @@ -12,5 +13,6 @@ ], "exclude": [ "node_modules", + "src-ts/webview", ] } diff --git a/lldb/tools/lldb-mcp/CMakeLists.txt b/lldb/tools/lldb-mcp/CMakeLists.txt new file mode 100644 index 0000000000000..7fe3301ab3081 --- /dev/null +++ b/lldb/tools/lldb-mcp/CMakeLists.txt @@ -0,0 +1,33 @@ +add_lldb_tool(lldb-mcp + lldb-mcp.cpp + + LINK_COMPONENTS + Option + Support + LINK_LIBS + liblldb + lldbHost + lldbProtocolMCP + ) + +if(APPLE) + configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/lldb-mcp-Info.plist.in + ${CMAKE_CURRENT_BINARY_DIR}/lldb-mcp-Info.plist + ) + target_link_options(lldb-mcp + PRIVATE LINKER:-sectcreate,__TEXT,__info_plist,${CMAKE_CURRENT_BINARY_DIR}/lldb-mcp-Info.plist) +endif() + +if(LLDB_BUILD_FRAMEWORK) + # In the build-tree, we know the exact path to the framework directory. + # The installed framework can be in different locations. + lldb_setup_rpaths(lldb-mcp + BUILD_RPATH + "${LLDB_FRAMEWORK_ABSOLUTE_BUILD_DIR}" + INSTALL_RPATH + "@loader_path/../../../SharedFrameworks" + "@loader_path/../../System/Library/PrivateFrameworks" + "@loader_path/../../Library/PrivateFrameworks" + ) +endif() diff --git a/lldb/tools/lldb-mcp/lldb-mcp-Info.plist.in b/lldb/tools/lldb-mcp/lldb-mcp-Info.plist.in new file mode 100644 index 0000000000000..4dc3ddd912808 --- /dev/null +++ b/lldb/tools/lldb-mcp/lldb-mcp-Info.plist.in @@ -0,0 +1,21 @@ + + + + + CFBundleDevelopmentRegion + English + CFBundleIdentifier + com.apple.lldb-mcp + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + lldb-mcp + CFBundleVersion + ${LLDB_VERSION} + SecTaskAccess + + allowed + debug + + + diff --git a/lldb/tools/lldb-mcp/lldb-mcp.cpp b/lldb/tools/lldb-mcp/lldb-mcp.cpp new file mode 100644 index 0000000000000..1f82af94820da --- /dev/null +++ b/lldb/tools/lldb-mcp/lldb-mcp.cpp @@ -0,0 +1,84 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Host/Config.h" +#include "lldb/Host/File.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Server.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/Signals.h" +#include "llvm/Support/WithColor.h" + +#if defined(_WIN32) +#include +#endif + +using namespace lldb_protocol::mcp; + +using lldb_private::File; +using lldb_private::MainLoop; +using lldb_private::MainLoopBase; +using lldb_private::NativeFile; + +static constexpr llvm::StringLiteral kName = "lldb-mcp"; +static constexpr llvm::StringLiteral kVersion = "0.1.0"; + +int main(int argc, char *argv[]) { + llvm::InitLLVM IL(argc, argv, /*InstallPipeSignalExitHandler=*/false); +#if !defined(__APPLE__) + llvm::setBugReportMsg("PLEASE submit a bug report to " LLDB_BUG_REPORT_URL + " and include the crash backtrace.\n"); +#else + llvm::setBugReportMsg("PLEASE submit a bug report to " LLDB_BUG_REPORT_URL + " and include the crash report from " + "~/Library/Logs/DiagnosticReports/.\n"); +#endif + +#if defined(_WIN32) + // Windows opens stdout and stdin in text mode which converts \n to 13,10 + // while the value is just 10 on Darwin/Linux. Setting the file mode to + // binary fixes this. + int result = _setmode(fileno(stdout), _O_BINARY); + assert(result); + result = _setmode(fileno(stdin), _O_BINARY); + UNUSED_IF_ASSERT_DISABLED(result); + assert(result); +#endif + + lldb::IOObjectSP input = std::make_shared( + fileno(stdin), File::eOpenOptionReadOnly, NativeFile::Unowned); + + lldb::IOObjectSP output = std::make_shared( + fileno(stdout), File::eOpenOptionWriteOnly, NativeFile::Unowned); + + constexpr llvm::StringLiteral client_name = "stdio"; + static MainLoop loop; + + llvm::sys::SetInterruptFunction([]() { + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + }); + + auto transport_up = std::make_unique( + input, output, std::string(client_name), + [&](llvm::StringRef message) { llvm::errs() << message << '\n'; }); + + auto instance_up = std::make_unique( + std::string(kName), std::string(kVersion), std::move(transport_up), loop); + + if (llvm::Error error = instance_up->Run()) { + llvm::logAllUnhandledErrors(std::move(error), llvm::WithColor::error(), + "MCP error: "); + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} diff --git a/lldb/unittests/CMakeLists.txt b/lldb/unittests/CMakeLists.txt index 8a20839a37469..6efd0ca1a5b41 100644 --- a/lldb/unittests/CMakeLists.txt +++ b/lldb/unittests/CMakeLists.txt @@ -86,10 +86,6 @@ add_subdirectory(Utility) add_subdirectory(ValueObject) add_subdirectory(tools) -if(LLDB_ENABLE_PROTOCOL_SERVERS) - add_subdirectory(ProtocolServer) -endif() - if(LLDB_CAN_USE_DEBUGSERVER AND LLDB_TOOL_DEBUGSERVER_BUILD AND NOT LLDB_USE_SYSTEM_DEBUGSERVER) add_subdirectory(debugserver) endif() diff --git a/lldb/unittests/DAP/CMakeLists.txt b/lldb/unittests/DAP/CMakeLists.txt index 156cd625546bd..716159b454231 100644 --- a/lldb/unittests/DAP/CMakeLists.txt +++ b/lldb/unittests/DAP/CMakeLists.txt @@ -1,6 +1,7 @@ add_lldb_unittest(DAPTests DAPErrorTest.cpp DAPTest.cpp + DAPTypesTest.cpp FifoFilesTest.cpp Handler/DisconnectTest.cpp Handler/ContinueTest.cpp diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 40ffaf87c9c45..d5a9591ad0a43 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -9,11 +9,9 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" #include "TestBase.h" -#include "Transport.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" -#include -#include #include using namespace llvm; @@ -21,6 +19,7 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; +using namespace testing; class DAPTest : public TransportBase {}; @@ -29,11 +28,13 @@ TEST_F(DAPTest, SendProtocolMessages) { /*log=*/nullptr, /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/{}, - /*transport=*/*to_dap, + /*client_name=*/"test_client", + /*transport=*/*transport, + /*loop=*/loop, }; dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); - ASSERT_THAT_EXPECTED( - from_dap->Read(std::chrono::milliseconds(1)), - HasValue(testing::VariantWith(testing::FieldsAre( - /*event=*/"my-event", /*body=*/std::nullopt)))); + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_CALL(client, Received(IsEvent("my-event", std::nullopt))); + ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded()); } diff --git a/lldb/unittests/DAP/DAPTypesTest.cpp b/lldb/unittests/DAP/DAPTypesTest.cpp new file mode 100644 index 0000000000000..f398c54b724a0 --- /dev/null +++ b/lldb/unittests/DAP/DAPTypesTest.cpp @@ -0,0 +1,60 @@ +//===-- DAPTypesTest.cpp ----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Protocol/DAPTypes.h" +#include "TestingSupport/TestUtilities.h" +#include "lldb/lldb-enumerations.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" +#include + +using namespace llvm; +using namespace lldb; +using namespace lldb_dap; +using namespace lldb_dap::protocol; +using lldb_private::roundtripJSON; + +TEST(DAPTypesTest, SourceLLDBData) { + SourceLLDBData source_data; + source_data.persistenceData = + PersistenceData{"module_path123", "symbol_name456"}; + + llvm::Expected deserialized_data = roundtripJSON(source_data); + ASSERT_THAT_EXPECTED(deserialized_data, llvm::Succeeded()); + + EXPECT_EQ(source_data.persistenceData->module_path, + deserialized_data->persistenceData->module_path); + EXPECT_EQ(source_data.persistenceData->symbol_name, + deserialized_data->persistenceData->symbol_name); +} + +TEST(DAPTypesTest, DAPSymbol) { + Symbol symbol; + symbol.id = 42; + symbol.isDebug = true; + symbol.isExternal = false; + symbol.isSynthetic = true; + symbol.type = lldb::eSymbolTypeTrampoline; + symbol.fileAddress = 0x12345678; + symbol.loadAddress = 0x87654321; + symbol.size = 64; + symbol.name = "testSymbol"; + + llvm::Expected deserialized_symbol = roundtripJSON(symbol); + ASSERT_THAT_EXPECTED(deserialized_symbol, llvm::Succeeded()); + + EXPECT_EQ(symbol.id, deserialized_symbol->id); + EXPECT_EQ(symbol.isDebug, deserialized_symbol->isDebug); + EXPECT_EQ(symbol.isExternal, deserialized_symbol->isExternal); + EXPECT_EQ(symbol.isSynthetic, deserialized_symbol->isSynthetic); + EXPECT_EQ(symbol.type, deserialized_symbol->type); + EXPECT_EQ(symbol.fileAddress, deserialized_symbol->fileAddress); + EXPECT_EQ(symbol.loadAddress, deserialized_symbol->loadAddress); + EXPECT_EQ(symbol.size, deserialized_symbol->size); + EXPECT_EQ(symbol.name, deserialized_symbol->name); +} diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index 0546aeb154d50..c6ff1f90b01d5 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -23,18 +23,15 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; +using testing::_; class DisconnectRequestHandlerTest : public DAPTestBase {}; TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) { DisconnectRequestHandler handler(*dap); - EXPECT_FALSE(dap->disconnecting); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); - EXPECT_TRUE(dap->disconnecting); - std::vector messages = DrainOutput(); - EXPECT_THAT(messages, - testing::Contains(testing::VariantWith(testing::FieldsAre( - /*event=*/"terminated", /*body=*/testing::_)))); + EXPECT_CALL(client, Received(IsEvent("terminated", _))); + RunOnce(); } TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { @@ -47,17 +44,14 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { DisconnectRequestHandler handler(*dap); - EXPECT_FALSE(dap->disconnecting); dap->configuration.terminateCommands = {"?script print(1)", "script print(2)"}; EXPECT_EQ(dap->target.GetProcess().GetState(), lldb::eStateStopped); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); - EXPECT_TRUE(dap->disconnecting); - std::vector messages = DrainOutput(); - EXPECT_THAT(messages, testing::ElementsAre( - OutputMatcher("Running terminateCommands:\n"), - OutputMatcher("(lldb) script print(2)\n"), - OutputMatcher("2\n"), - testing::VariantWith(testing::FieldsAre( - /*event=*/"terminated", /*body=*/testing::_)))); + EXPECT_CALL(client, Received(Output("1\n"))); + EXPECT_CALL(client, Received(Output("2\n"))).Times(2); + EXPECT_CALL(client, Received(Output("(lldb) script print(2)\n"))); + EXPECT_CALL(client, Received(Output("Running terminateCommands:\n"))); + EXPECT_CALL(client, Received(IsEvent("terminated", _))); + RunOnce(); } diff --git a/lldb/unittests/DAP/ProtocolTypesTest.cpp b/lldb/unittests/DAP/ProtocolTypesTest.cpp index 4aab2dc223134..c5d47fcb08da4 100644 --- a/lldb/unittests/DAP/ProtocolTypesTest.cpp +++ b/lldb/unittests/DAP/ProtocolTypesTest.cpp @@ -1004,3 +1004,72 @@ TEST(ProtocolTypesTest, VariablesResponseBody) { ASSERT_THAT_EXPECTED(expected, llvm::Succeeded()); EXPECT_EQ(pp(*expected), pp(response)); } + +TEST(ProtocolTypesTest, CompletionItem) { + CompletionItem item; + item.label = "label"; + item.text = "text"; + item.sortText = "sortText"; + item.detail = "detail"; + item.type = eCompletionItemTypeConstructor; + item.start = 1; + item.length = 3; + item.selectionStart = 4; + item.selectionLength = 8; + + const StringRef json = R"({ + "detail": "detail", + "label": "label", + "length": 3, + "selectionLength": 8, + "selectionStart": 4, + "sortText": "sortText", + "start": 1, + "text": "text", + "type": "constructor" +})"; + + EXPECT_EQ(pp(Value(item)), json); + EXPECT_THAT_EXPECTED(json::parse(json), HasValue(Value(item))); +} + +TEST(ProtocolTypesTest, CompletionsArguments) { + llvm::Expected expected = + parse(R"({ + "column": 8, + "frameId": 7, + "line": 9, + "text": "abc" + })"); + ASSERT_THAT_EXPECTED(expected, llvm::Succeeded()); + EXPECT_EQ(expected->frameId, 7u); + EXPECT_EQ(expected->text, "abc"); + EXPECT_EQ(expected->column, 8); + EXPECT_EQ(expected->line, 9); + + // Check required keys. + EXPECT_THAT_EXPECTED(parse(R"({})"), + FailedWithMessage("missing value at (root).text")); + EXPECT_THAT_EXPECTED(parse(R"({"text":"abc"})"), + FailedWithMessage("missing value at (root).column")); +} + +TEST(ProtocolTypesTest, CompletionsResponseBody) { + CompletionItem item; + item.label = "label"; + item.text = "text"; + item.detail = "detail"; + CompletionsResponseBody response{{item}}; + + Expected expected = json::parse(R"({ + "targets": [ + { + "detail": "detail", + "label": "label", + "text": "text" + } + ] + })"); + ASSERT_THAT_EXPECTED(expected, llvm::Succeeded()); + EXPECT_EQ(pp(*expected), pp(response)); +} diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index d5d36158d68e0..03b41212083ac 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -7,17 +7,19 @@ //===----------------------------------------------------------------------===// #include "TestBase.h" -#include "Protocol/ProtocolBase.h" +#include "DAPLog.h" #include "TestingSupport/TestUtilities.h" #include "lldb/API/SBDefines.h" #include "lldb/API/SBStructuredData.h" -#include "lldb/Host/File.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Host/Pipe.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" +#include #include +#include using namespace llvm; using namespace lldb; @@ -25,36 +27,36 @@ using namespace lldb_dap; using namespace lldb_dap::protocol; using namespace lldb_dap_tests; using lldb_private::File; -using lldb_private::NativeFile; +using lldb_private::FileSpec; +using lldb_private::FileSystem; +using lldb_private::MainLoop; using lldb_private::Pipe; -void TransportBase::SetUp() { - PipePairTest::SetUp(); - to_dap = std::make_unique( - "to_dap", nullptr, - std::make_shared(input.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(output.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); - from_dap = std::make_unique( - "from_dap", nullptr, - std::make_shared(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); +Expected +TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) { + Expected dummy_file = FileSystem::Instance().Open( + FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); + if (!dummy_file) + return dummy_file.takeError(); + m_dummy_file = std::move(*dummy_file); + lldb_private::Status status; + auto handle = loop.RegisterReadObject( + m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); + if (status.Fail()) + return status.takeError(); + return handle; } void DAPTestBase::SetUp() { TransportBase::SetUp(); + std::error_code EC; + log = std::make_unique("-", EC); dap = std::make_unique( - /*log=*/nullptr, + /*log=*/log.get(), /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/std::vector(), - /*transport=*/*to_dap); + /*client_name=*/"test_client", + /*transport=*/*transport, /*loop=*/loop); } void DAPTestBase::TearDown() { @@ -70,7 +72,7 @@ void DAPTestBase::SetUpTestSuite() { } void DAPTestBase::TeatUpTestSuite() { SBDebugger::Terminate(); } -bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { +bool DAPTestBase::GetDebuggerSupportsTarget(StringRef platform) { EXPECT_TRUE(dap->debugger); lldb::SBStructuredData data = dap->debugger.GetBuildConfiguration() @@ -79,7 +81,7 @@ bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { for (size_t i = 0; i < data.GetSize(); i++) { char buf[100] = {0}; size_t size = data.GetItemAtIndex(i).GetStringValue(buf, sizeof(buf)); - if (llvm::StringRef(buf, size) == platform) + if (StringRef(buf, size) == platform) return true; } @@ -89,6 +91,24 @@ bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { void DAPTestBase::CreateDebugger() { dap->debugger = lldb::SBDebugger::Create(); ASSERT_TRUE(dap->debugger); + dap->target = dap->debugger.GetDummyTarget(); + + Expected dev_null = FileSystem::Instance().Open( + FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); + ASSERT_THAT_EXPECTED(dev_null, Succeeded()); + lldb::FileSP dev_null_sp = std::move(*dev_null); + + std::FILE *dev_null_stream = dev_null_sp->GetStream(); + ASSERT_THAT_ERROR(dap->ConfigureIO(dev_null_stream, dev_null_stream), + Succeeded()); + + dap->debugger.SetInputFile(dap->in); + auto out_fd = dap->out.GetWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(out_fd, Succeeded()); + dap->debugger.SetOutputFile(lldb::SBFile(*out_fd, "w", false)); + auto err_fd = dap->out.GetWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(err_fd, Succeeded()); + dap->debugger.SetErrorFile(lldb::SBFile(*err_fd, "w", false)); } void DAPTestBase::LoadCore() { @@ -112,18 +132,3 @@ void DAPTestBase::LoadCore() { SBProcess process = dap->target.LoadCore(this->core->TmpName.data()); ASSERT_TRUE(process); } - -std::vector DAPTestBase::DrainOutput() { - std::vector msgs; - output.CloseWriteFileDescriptor(); - while (true) { - Expected next = - from_dap->Read(std::chrono::milliseconds(1)); - if (!next) { - consumeError(next.takeError()); - break; - } - msgs.push_back(*next); - } - return msgs; -} diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index 50884b1d7feb9..c19eead4e37e7 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -8,35 +8,109 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" -#include "TestingSupport/Host/PipeTestUtilities.h" -#include "Transport.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" +#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/JSON.h" +#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include namespace lldb_dap_tests { +class TestTransport final + : public lldb_private::Transport { +public: + using Message = lldb_private::Transport::Message; + + TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) + : m_loop(loop), m_handler(handler) {} + + llvm::Error Send(const lldb_dap::protocol::Event &e) override { + m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { + this->m_handler.Received(e); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const lldb_dap::protocol::Request &r) override { + m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { + this->m_handler.Received(r); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const lldb_dap::protocol::Response &r) override { + m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { + this->m_handler.Received(r); + }); + return llvm::Error::success(); + } + + llvm::Expected + RegisterMessageHandler(lldb_private::MainLoop &loop, + MessageHandler &handler) override; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; + +private: + lldb_private::MainLoop &m_loop; + MessageHandler &m_handler; + lldb::FileSP m_dummy_file; +}; + /// A base class for tests that need transport configured for communicating DAP /// messages. -class TransportBase : public PipePairTest { +class TransportBase : public testing::Test { protected: - std::unique_ptr to_dap; - std::unique_ptr from_dap; + lldb_private::SubsystemRAII + subsystems; + lldb_private::MainLoop loop; + std::unique_ptr transport; + MockMessageHandler + client; - void SetUp() override; + void SetUp() override { + transport = std::make_unique(loop, client); + } }; +/// A matcher for a DAP event. +template +inline testing::Matcher +IsEvent(const M1 &m1, const M2 &m2) { + return testing::AllOf(testing::Field(&lldb_dap::protocol::Event::event, m1), + testing::Field(&lldb_dap::protocol::Event::body, m2)); +} + /// Matches an "output" event. -inline auto OutputMatcher(const llvm::StringRef output, - const llvm::StringRef category = "console") { - return testing::VariantWith(testing::FieldsAre( - /*event=*/"output", /*body=*/testing::Optional( - llvm::json::Object{{"category", category}, {"output", output}}))); +inline auto Output(llvm::StringRef o, llvm::StringRef cat = "console") { + return IsEvent("output", + testing::Optional(llvm::json::Value( + llvm::json::Object{{"category", cat}, {"output", o}}))); } /// A base class for tests that interact with a `lldb_dap::DAP` instance. class DAPTestBase : public TransportBase { protected: + std::unique_ptr log; std::unique_ptr dap; std::optional core; std::optional binary; @@ -53,9 +127,11 @@ class DAPTestBase : public TransportBase { void CreateDebugger(); void LoadCore(); - /// Closes the DAP output pipe and returns the remaining protocol messages in - /// the buffer. - std::vector DrainOutput(); + void RunOnce() { + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded()); + } }; } // namespace lldb_dap_tests diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 2f0846471688c..445674f402252 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -1,4 +1,4 @@ -//===-- JSONTransportTest.cpp ---------------------------------------------===// +//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,16 +7,143 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/JSONTransport.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/Host/PipeTestUtilities.h" #include "lldb/Host/File.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" +#include "lldb/Utility/Log.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include +#include +#include +#include using namespace llvm; using namespace lldb_private; +using testing::_; +using testing::HasSubstr; +using testing::InSequence; namespace { -template class JSONTransportTest : public PipePairTest { + +namespace test_protocol { + +struct Req { + std::string name; +}; +json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; } +bool fromJSON(const json::Value &V, Req &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("req", T.name); +} +bool operator==(const Req &a, const Req &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Req &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Req &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} + +struct Resp { + std::string name; +}; +json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; } +bool fromJSON(const json::Value &V, Resp &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("resp", T.name); +} +bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Resp &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} + +struct Evt { + std::string name; +}; +json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; } +bool fromJSON(const json::Value &V, Evt &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("evt", T.name); +} +bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Evt &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} + +using Message = std::variant; +json::Value toJSON(const Message &msg) { + return std::visit([](const auto &msg) { return toJSON(msg); }, msg); +} +bool fromJSON(const json::Value &V, Message &msg, json::Path P) { + const json::Object *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + if (O->get("req")) { + Req R; + if (!fromJSON(V, R, P)) + return false; + + msg = std::move(R); + return true; + } + if (O->get("resp")) { + Resp R; + if (!fromJSON(V, R, P)) + return false; + + msg = std::move(R); + return true; + } + if (O->get("evt")) { + Evt E; + if (!fromJSON(V, E, P)) + return false; + + msg = std::move(E); + return true; + } + P.report("unknown message type"); + return false; +} + +} // namespace test_protocol + +template +class JSONTransportTest : public PipePairTest { + protected: - std::unique_ptr transport; + MockMessageHandler message_handler; + std::unique_ptr transport; + MainLoop loop; void SetUp() override { PipePairTest::SetUp(); @@ -28,79 +155,231 @@ template class JSONTransportTest : public PipePairTest { File::eOpenOptionWriteOnly, NativeFile::Unowned)); } + + /// Run the transport MainLoop and return any messages received. + Error + Run(bool close_input = true, + std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { + if (close_input) { + input.CloseWriteFileDescriptor(); + EXPECT_CALL(message_handler, OnClosed()).WillOnce([this]() { + loop.RequestTermination(); + }); + } + loop.AddCallback( + [](MainLoopBase &loop) { + loop.RequestTermination(); + FAIL() << "timeout"; + }, + timeout); + auto handle = transport->RegisterMessageHandler(loop, message_handler); + if (!handle) + return handle.takeError(); + + return loop.Run().takeError(); + } + + template void Write(Ts... args) { + std::string message; + for (const auto &arg : {args...}) + message += Encode(arg); + EXPECT_THAT_EXPECTED(input.Write(message.data(), message.size()), + Succeeded()); + } + + virtual std::string Encode(const json::Value &) = 0; }; -class HTTPDelimitedJSONTransportTest - : public JSONTransportTest { +class TestHTTPDelimitedJSONTransport final + : public HTTPDelimitedJSONTransport { public: - using JSONTransportTest::JSONTransportTest; + using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; }; -class JSONRPCTransportTest : public JSONTransportTest { +class HTTPDelimitedJSONTransportTest + : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; + + std::string Encode(const json::Value &V) override { + std::string msg; + raw_string_ostream OS(msg); + OS << formatv("{0}", V); + return formatv("Content-Length: {0}\r\nContent-type: " + "text/json\r\n\r\n{1}", + msg.size(), msg) + .str(); + } }; -struct JSONTestType { - std::string str; +class TestJSONRPCTransport final + : public JSONRPCTransport { +public: + using JSONRPCTransport::JSONRPCTransport; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; }; -llvm::json::Value toJSON(const JSONTestType &T) { - return llvm::json::Object{{"str", T.str}}; -} +class JSONRPCTransportTest + : public JSONTransportTest { +public: + using JSONTransportTest::JSONTransportTest; + + std::string Encode(const json::Value &V) override { + std::string msg; + raw_string_ostream OS(msg); + OS << formatv("{0}\n", V); + return msg; + } +}; -bool fromJSON(const llvm::json::Value &V, JSONTestType &T, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("str", T.str); -} } // namespace +// Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. +#ifndef _WIN32 +using namespace test_protocol; + TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { - std::string malformed_header = "COnTent-LenGth: -1{}\r\n\r\nnotjosn"; + std::string malformed_header = + "COnTent-LenGth: -1\r\nContent-Type: text/json\r\n\r\nnotjosn"; ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - FailedWithMessage( - "expected 'Content-Length: ' and got 'COnTent-LenGth: '")); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + FailedWithMessage("invalid content length: -1")); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - std::string json = R"json({"str": "foo"})json"; - std::string message = - formatv("Content-Length: {0}\r\n\r\n{1}", json.size(), json).str(); - ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), - Succeeded()); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + Write(Req{"foo"}); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } -TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { +TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { + InSequence seq; + Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); + EXPECT_CALL(message_handler, Received(Req{"one"})); + EXPECT_CALL(message_handler, Received(Evt{"two"})); + EXPECT_CALL(message_handler, Received(Resp{"three"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { + std::string long_str = std::string( + HTTPDelimitedJSONTransport::kReadBufferSize * 2, 'x'); + Write(Req{long_str}); + EXPECT_CALL(message_handler, Received(Req{long_str})); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { + std::string message = Encode(Req{"foo"}); + auto split_at = message.size() / 2; + std::string part1 = message.substr(0, split_at); + std::string part2 = message.substr(split_at); + + EXPECT_CALL(message_handler, Received(Req{"foo"})); + + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); + ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - Failed()); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { + std::string message = Encode(Req{"foo"}); + auto split_at = message.size() / 2; + std::string part1 = message.substr(0, split_at); + std::string part2 = message.substr(split_at); + + EXPECT_CALL(message_handler, Received(Req{"foo"})); + + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + + // Run the main loop once for the initial read. + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); + + // zero-byte write. + ASSERT_THAT_EXPECTED(input.Write(part1.data(), 0), + Succeeded()); // zero-byte write. + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); + + // Write the remaining part of the message. + ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { + ASSERT_THAT_ERROR(Run(), Succeeded()); } +TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { + std::string json = R"json({"str": "foo"})json"; + std::string message = + formatv("Content-Length: {0}\r\nContent-type: text/json\r\n\r\n{1}", + json.size(), json) + .str(); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + // The error should indicate that there are unhandled contents. + ASSERT_THAT_ERROR(std::move(err), + Failed()); + }); + + // Write an incomplete message and close the handle. + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), + Succeeded()); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { - transport = std::make_unique(nullptr, nullptr); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - Failed()); + transport = + std::make_unique(nullptr, nullptr); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), + FailedWithMessage("IO object is not valid.")); } TEST_F(HTTPDelimitedJSONTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" - R"json({"str":"foo"})json")); + R"({"req":"foo"})" + "Content-Length: 14\r\n\r\n" + R"({"resp":"bar"})" + "Content-Length: 13\r\n\r\n" + R"({"evt":"baz"})")); } TEST_F(JSONRPCTransportTest, MalformedRequests) { @@ -108,74 +387,94 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - llvm::Failed()); + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + FailedWithMessage(HasSubstr("Invalid JSON value"))); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, Read) { - std::string json = R"json({"str": "foo"})json"; - std::string message = formatv("{0}\n", json).str(); - ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), - Succeeded()); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + Write(Message{Req{"foo"}}); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } -TEST_F(JSONRPCTransportTest, ReadWithEOF) { +TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) { + InSequence seq; + Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); + EXPECT_CALL(message_handler, Received(Req{"one"})); + EXPECT_CALL(message_handler, Received(Evt{"two"})); + EXPECT_CALL(message_handler, Received(Resp{"three"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { + // Use a string longer than the chunk size to ensure we split the message + // across the chunk boundary. + std::string long_str = + std::string(JSONTransport::kReadBufferSize * 2, 'x'); + Write(Req{long_str}); + EXPECT_CALL(message_handler, Received(Req{long_str})); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(JSONRPCTransportTest, ReadPartialMessage) { + std::string message = R"({"req": "foo"})" + "\n"; + std::string part1 = message.substr(0, 7); + std::string part2 = message.substr(7); + + EXPECT_CALL(message_handler, Received(Req{"foo"})); + + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), Succeeded()); + + ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - Failed()); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(JSONRPCTransportTest, ReadWithEOF) { + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { + std::string message = R"json({"req": "foo")json"; + // Write an incomplete message and close the handle. + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), + Succeeded()); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + Failed()); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"json({"str":"foo"})json" + ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})" + "\n" + R"({"resp":"bar"})" + "\n" + R"({"evt":"baz"})" "\n")); } TEST_F(JSONRPCTransportTest, InvalidTransport) { - transport = std::make_unique(nullptr, nullptr); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - Failed()); -} - -#ifndef _WIN32 -TEST_F(HTTPDelimitedJSONTransportTest, ReadWithTimeout) { - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - Failed()); + transport = std::make_unique(nullptr, nullptr); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), + FailedWithMessage("IO object is not valid.")); } -TEST_F(JSONRPCTransportTest, ReadWithTimeout) { - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - Failed()); -} - -// Windows CRT _read checks that the file descriptor is valid and calls a -// handler if not. This handler is normally a breakpoint, which looks like a -// crash when not handled by a debugger. -// https://learn.microsoft.com/en-us/%20cpp/c-runtime-library/reference/read?view=msvc-170 -TEST_F(HTTPDelimitedJSONTransportTest, ReadAfterClosed) { - input.CloseReadFileDescriptor(); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - llvm::Failed()); -} - -TEST_F(JSONRPCTransportTest, ReadAfterClosed) { - input.CloseReadFileDescriptor(); - ASSERT_THAT_EXPECTED( - transport->Read(std::chrono::milliseconds(1)), - llvm::Failed()); -} #endif diff --git a/lldb/unittests/Protocol/CMakeLists.txt b/lldb/unittests/Protocol/CMakeLists.txt index bbac69611e011..f877517ea233d 100644 --- a/lldb/unittests/Protocol/CMakeLists.txt +++ b/lldb/unittests/Protocol/CMakeLists.txt @@ -1,5 +1,6 @@ add_lldb_unittest(ProtocolTests ProtocolMCPTest.cpp + ProtocolMCPServerTest.cpp LINK_LIBS lldbHost diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp new file mode 100644 index 0000000000000..9fa446133d46f --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -0,0 +1,304 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ProtocolMCPTestUtilities.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" +#include "TestingSupport/Host/PipeTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" +#include "lldb/Host/Socket.h" +#include "lldb/Protocol/MCP/MCPError.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Resource.h" +#include "lldb/Protocol/MCP/Server.h" +#include "lldb/Protocol/MCP/Tool.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/JSON.h" +#include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include +#include + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; +using namespace lldb_protocol::mcp; + +namespace { +class TestMCPTransport final : public MCPTransport { +public: + TestMCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + : lldb_protocol::mcp::MCPTransport(in, out, "unittest") {} + + using MCPTransport::Write; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; +}; + +class TestServer : public Server { +public: + using Server::Server; +}; + +/// Test tool that returns it argument as text. +class TestTool : public Tool { +public: + using Tool::Tool; + + llvm::Expected Call(const ToolArguments &args) override { + std::string argument; + if (const json::Object *args_obj = + std::get(args).getAsObject()) { + if (const json::Value *s = args_obj->get("arguments")) { + argument = s->getAsString().value_or(""); + } + } + + CallToolResult text_result; + text_result.content.emplace_back(TextContent{{argument}}); + return text_result; + } +}; + +class TestResourceProvider : public ResourceProvider { + using ResourceProvider::ResourceProvider; + + std::vector GetResources() const override { + std::vector resources; + + Resource resource; + resource.uri = "lldb://foo/bar"; + resource.name = "name"; + resource.description = "description"; + resource.mimeType = "application/json"; + + resources.push_back(resource); + return resources; + } + + llvm::Expected + ReadResource(llvm::StringRef uri) const override { + if (uri != "lldb://foo/bar") + return llvm::make_error(uri.str()); + + TextResourceContents contents; + contents.uri = "lldb://foo/bar"; + contents.mimeType = "application/json"; + contents.text = "foobar"; + + ReadResourceResult result; + result.contents.push_back(contents); + return result; + } +}; + +/// Test tool that returns an error. +class ErrorTool : public Tool { +public: + using Tool::Tool; + + llvm::Expected Call(const ToolArguments &args) override { + return llvm::createStringError("error"); + } +}; + +/// Test tool that fails but doesn't return an error. +class FailTool : public Tool { +public: + using Tool::Tool; + + llvm::Expected Call(const ToolArguments &args) override { + CallToolResult text_result; + text_result.content.emplace_back(TextContent{{"failed"}}); + text_result.isError = true; + return text_result; + } +}; + +class ProtocolServerMCPTest : public PipePairTest { +public: + SubsystemRAII subsystems; + + std::unique_ptr transport_up; + std::unique_ptr server_up; + MainLoop loop; + MockMessageHandler message_handler; + + llvm::Error Write(llvm::StringRef message) { + llvm::Expected value = json::parse(message); + if (!value) + return value.takeError(); + return transport_up->Write(*value); + } + + llvm::Error Write(json::Value value) { return transport_up->Write(value); } + + /// Run the transport MainLoop and return any messages received. + llvm::Error + Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) { + loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, + timeout); + auto handle = transport_up->RegisterMessageHandler(loop, message_handler); + if (!handle) + return handle.takeError(); + + return server_up->Run(); + } + + void SetUp() override { + PipePairTest::SetUp(); + + transport_up = std::make_unique( + std::make_shared(input.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(output.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)); + + server_up = std::make_unique( + "lldb-mcp", "0.1.0", + std::make_unique( + std::make_shared(output.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(input.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)), + loop); + } +}; + +template +Request make_request(StringLiteral method, T &¶ms, Id id = 1) { + return Request{id, method.str(), toJSON(std::forward(params))}; +} + +template Response make_response(T &&result, Id id = 1) { + return Response{id, std::forward(result)}; +} + +} // namespace + +TEST_F(ProtocolServerMCPTest, Initialization) { + Request request = make_request( + "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/{}, + /*clientInfo=*/{"lldb-unit", "0.1.0"}}); + Response response = make_response( + InitializeResult{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/{/*supportsToolsList=*/true}, + /*serverInfo=*/{"lldb-mcp", "0.1.0"}}); + + ASSERT_THAT_ERROR(Write(request), Succeeded()); + EXPECT_CALL(message_handler, Received(response)); + EXPECT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(ProtocolServerMCPTest, ToolsList) { + server_up->AddTool(std::make_unique("test", "test tool")); + + Request request = make_request("tools/list", Void{}, /*id=*/"one"); + + ToolDefinition test_tool; + test_tool.name = "test"; + test_tool.description = "test tool"; + test_tool.inputSchema = json::Object{{"type", "object"}}; + + Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one"); + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(response)); + EXPECT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(ProtocolServerMCPTest, ResourcesList) { + server_up->AddResourceProvider(std::make_unique()); + + Request request = make_request("resources/list", Void{}); + Response response = make_response(ListResourcesResult{ + {{/*uri=*/"lldb://foo/bar", /*name=*/"name", + /*description=*/"description", /*mimeType=*/"application/json"}}}); + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(response)); + EXPECT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(ProtocolServerMCPTest, ToolsCall) { + server_up->AddTool(std::make_unique("test", "test tool")); + + Request request = make_request( + "tools/call", CallToolParams{/*name=*/"test", /*arguments=*/json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }}); + Response response = make_response(CallToolResult{{{/*text=*/"foo"}}}); + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(response)); + EXPECT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(ProtocolServerMCPTest, ToolsCallError) { + server_up->AddTool(std::make_unique("error", "error tool")); + + Request request = make_request( + "tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }}); + Response response = + make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError, + /*message=*/"error"}); + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(response)); + EXPECT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(ProtocolServerMCPTest, ToolsCallFail) { + server_up->AddTool(std::make_unique("fail", "fail tool")); + + Request request = make_request( + "tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }}); + Response response = + make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true}); + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(response)); + EXPECT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(ProtocolServerMCPTest, NotificationInitialized) { + bool handler_called = false; + std::condition_variable cv; + + server_up->AddNotificationHandler( + "notifications/initialized", + [&](const Notification ¬ification) { handler_called = true; }); + llvm::StringLiteral request = + R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_TRUE(handler_called); +} diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index ea19922522ffe..396e361e873fe 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "ProtocolMCPTestUtilities.h" #include "TestingSupport/TestUtilities.h" #include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Testing/Support/Error.h" @@ -54,31 +55,16 @@ TEST(ProtocolMCPTest, Notification) { EXPECT_EQ(notification.params, deserialized_notification->params); } -TEST(ProtocolMCPTest, ToolCapability) { - ToolCapability tool_capability; - tool_capability.listChanged = true; +TEST(ProtocolMCPTest, ServerCapabilities) { + ServerCapabilities capabilities; + capabilities.supportsToolsList = true; - llvm::Expected deserialized_tool_capability = - roundtripJSON(tool_capability); - ASSERT_THAT_EXPECTED(deserialized_tool_capability, llvm::Succeeded()); - - EXPECT_EQ(tool_capability.listChanged, - deserialized_tool_capability->listChanged); -} - -TEST(ProtocolMCPTest, Capabilities) { - ToolCapability tool_capability; - tool_capability.listChanged = true; - - Capabilities capabilities; - capabilities.tools = tool_capability; - - llvm::Expected deserialized_capabilities = + llvm::Expected deserialized_capabilities = roundtripJSON(capabilities); ASSERT_THAT_EXPECTED(deserialized_capabilities, llvm::Succeeded()); - EXPECT_EQ(capabilities.tools.listChanged, - deserialized_capabilities->tools.listChanged); + EXPECT_EQ(capabilities.supportsToolsList, + deserialized_capabilities->supportsToolsList); } TEST(ProtocolMCPTest, TextContent) { @@ -92,18 +78,18 @@ TEST(ProtocolMCPTest, TextContent) { EXPECT_EQ(text_content.text, deserialized_text_content->text); } -TEST(ProtocolMCPTest, TextResult) { +TEST(ProtocolMCPTest, CallToolResult) { TextContent text_content1; text_content1.text = "Text 1"; TextContent text_content2; text_content2.text = "Text 2"; - TextResult text_result; + CallToolResult text_result; text_result.content = {text_content1, text_content2}; text_result.isError = true; - llvm::Expected deserialized_text_result = + llvm::Expected deserialized_text_result = roundtripJSON(text_result); ASSERT_THAT_EXPECTED(deserialized_text_result, llvm::Succeeded()); @@ -237,13 +223,13 @@ TEST(ProtocolMCPTest, ResourceWithoutOptionals) { EXPECT_TRUE(deserialized_resource->mimeType.empty()); } -TEST(ProtocolMCPTest, ResourceContents) { - ResourceContents contents; +TEST(ProtocolMCPTest, TextResourceContents) { + TextResourceContents contents; contents.uri = "resource://example/content"; contents.text = "This is the content of the resource"; contents.mimeType = "text/plain"; - llvm::Expected deserialized_contents = + llvm::Expected deserialized_contents = roundtripJSON(contents); ASSERT_THAT_EXPECTED(deserialized_contents, llvm::Succeeded()); @@ -252,12 +238,12 @@ TEST(ProtocolMCPTest, ResourceContents) { EXPECT_EQ(contents.mimeType, deserialized_contents->mimeType); } -TEST(ProtocolMCPTest, ResourceContentsWithoutMimeType) { - ResourceContents contents; +TEST(ProtocolMCPTest, TextResourceContentsWithoutMimeType) { + TextResourceContents contents; contents.uri = "resource://example/content-no-mime"; contents.text = "Content without mime type specified"; - llvm::Expected deserialized_contents = + llvm::Expected deserialized_contents = roundtripJSON(contents); ASSERT_THAT_EXPECTED(deserialized_contents, llvm::Succeeded()); @@ -266,21 +252,22 @@ TEST(ProtocolMCPTest, ResourceContentsWithoutMimeType) { EXPECT_TRUE(deserialized_contents->mimeType.empty()); } -TEST(ProtocolMCPTest, ResourceResult) { - ResourceContents contents1; +TEST(ProtocolMCPTest, ReadResourceResult) { + TextResourceContents contents1; contents1.uri = "resource://example/content1"; contents1.text = "First resource content"; contents1.mimeType = "text/plain"; - ResourceContents contents2; + TextResourceContents contents2; contents2.uri = "resource://example/content2"; contents2.text = "Second resource content"; contents2.mimeType = "application/json"; - ResourceResult result; + ReadResourceResult result; result.contents = {contents1, contents2}; - llvm::Expected deserialized_result = roundtripJSON(result); + llvm::Expected deserialized_result = + roundtripJSON(result); ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); ASSERT_EQ(result.contents.size(), deserialized_result->contents.size()); @@ -296,10 +283,11 @@ TEST(ProtocolMCPTest, ResourceResult) { deserialized_result->contents[1].mimeType); } -TEST(ProtocolMCPTest, ResourceResultEmpty) { - ResourceResult result; +TEST(ProtocolMCPTest, ReadResourceResultEmpty) { + ReadResourceResult result; - llvm::Expected deserialized_result = roundtripJSON(result); + llvm::Expected deserialized_result = + roundtripJSON(result); ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); EXPECT_TRUE(deserialized_result->contents.empty()); diff --git a/lldb/unittests/Protocol/ProtocolMCPTestUtilities.h b/lldb/unittests/Protocol/ProtocolMCPTestUtilities.h new file mode 100644 index 0000000000000..f8a14f4be03c9 --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPTestUtilities.h @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_UNITTESTS_PROTOCOL_PROTOCOLMCPTESTUTILITIES_H +#define LLDB_UNITTESTS_PROTOCOL_PROTOCOLMCPTESTUTILITIES_H + +#include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/JSON.h" // IWYU pragma: keep +#include "gtest/gtest.h" // IWYU pragma: keep +#include +#include + +namespace lldb_protocol::mcp { + +inline void PrintTo(const Request &req, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(req)).str(); +} + +inline void PrintTo(const Response &resp, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(resp)).str(); +} + +inline void PrintTo(const Notification ¬e, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(note)).str(); +} + +inline void PrintTo(const Message &message, std::ostream *os) { + return std::visit([os](auto &&message) { return PrintTo(message, os); }, + message); +} + +} // namespace lldb_protocol::mcp + +#endif diff --git a/lldb/unittests/ProtocolServer/CMakeLists.txt b/lldb/unittests/ProtocolServer/CMakeLists.txt deleted file mode 100644 index 6117430b35bf0..0000000000000 --- a/lldb/unittests/ProtocolServer/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_lldb_unittest(ProtocolServerTests - ProtocolMCPServerTest.cpp - - LINK_LIBS - lldbCore - lldbUtility - lldbHost - lldbPluginPlatformMacOSX - lldbPluginProtocolServerMCP - LLVMTestingSupport - ) diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp deleted file mode 100644 index 7890d3f69b9e1..0000000000000 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ /dev/null @@ -1,325 +0,0 @@ -//===-- ProtocolServerMCPTest.cpp -----------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" -#include "Plugins/Protocol/MCP/ProtocolServerMCP.h" -#include "TestingSupport/Host/SocketTestUtilities.h" -#include "TestingSupport/SubsystemRAII.h" -#include "lldb/Core/Debugger.h" -#include "lldb/Core/ProtocolServer.h" -#include "lldb/Host/FileSystem.h" -#include "lldb/Host/HostInfo.h" -#include "lldb/Host/JSONTransport.h" -#include "lldb/Host/Socket.h" -#include "lldb/Protocol/MCP/MCPError.h" -#include "lldb/Protocol/MCP/Protocol.h" -#include "llvm/Testing/Support/Error.h" -#include "gtest/gtest.h" - -using namespace llvm; -using namespace lldb; -using namespace lldb_private; -using namespace lldb_protocol::mcp; - -namespace { -class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { -public: - using ProtocolServerMCP::AddNotificationHandler; - using ProtocolServerMCP::AddRequestHandler; - using ProtocolServerMCP::AddResourceProvider; - using ProtocolServerMCP::AddTool; - using ProtocolServerMCP::GetSocket; - using ProtocolServerMCP::ProtocolServerMCP; -}; - -class TestJSONTransport : public lldb_private::JSONRPCTransport { -public: - using JSONRPCTransport::JSONRPCTransport; - using JSONRPCTransport::ReadImpl; - using JSONRPCTransport::WriteImpl; -}; - -/// Test tool that returns it argument as text. -class TestTool : public Tool { -public: - using Tool::Tool; - - virtual llvm::Expected Call(const ToolArguments &args) override { - std::string argument; - if (const json::Object *args_obj = - std::get(args).getAsObject()) { - if (const json::Value *s = args_obj->get("arguments")) { - argument = s->getAsString().value_or(""); - } - } - - TextResult text_result; - text_result.content.emplace_back(TextContent{{argument}}); - return text_result; - } -}; - -class TestResourceProvider : public ResourceProvider { - using ResourceProvider::ResourceProvider; - - virtual std::vector GetResources() const override { - std::vector resources; - - Resource resource; - resource.uri = "lldb://foo/bar"; - resource.name = "name"; - resource.description = "description"; - resource.mimeType = "application/json"; - - resources.push_back(resource); - return resources; - } - - virtual llvm::Expected - ReadResource(llvm::StringRef uri) const override { - if (uri != "lldb://foo/bar") - return llvm::make_error(uri.str()); - - ResourceContents contents; - contents.uri = "lldb://foo/bar"; - contents.mimeType = "application/json"; - contents.text = "foobar"; - - ResourceResult result; - result.contents.push_back(contents); - return result; - } -}; - -/// Test tool that returns an error. -class ErrorTool : public Tool { -public: - using Tool::Tool; - - virtual llvm::Expected Call(const ToolArguments &args) override { - return llvm::createStringError("error"); - } -}; - -/// Test tool that fails but doesn't return an error. -class FailTool : public Tool { -public: - using Tool::Tool; - - virtual llvm::Expected Call(const ToolArguments &args) override { - TextResult text_result; - text_result.content.emplace_back(TextContent{{"failed"}}); - text_result.isError = true; - return text_result; - } -}; - -class ProtocolServerMCPTest : public ::testing::Test { -public: - SubsystemRAII subsystems; - DebuggerSP m_debugger_sp; - - lldb::IOObjectSP m_io_sp; - std::unique_ptr m_transport_up; - std::unique_ptr m_server_up; - - static constexpr llvm::StringLiteral k_localhost = "localhost"; - - llvm::Error Write(llvm::StringRef message) { - return m_transport_up->WriteImpl(llvm::formatv("{0}\n", message).str()); - } - - llvm::Expected Read() { - return m_transport_up->ReadImpl(std::chrono::milliseconds(100)); - } - - void SetUp() { - // Create a debugger. - ArchSpec arch("arm64-apple-macosx-"); - Platform::SetHostPlatform( - PlatformRemoteMacOSX::CreateInstance(true, &arch)); - m_debugger_sp = Debugger::CreateInstance(); - - // Create & start the server. - ProtocolServer::Connection connection; - connection.protocol = Socket::SocketProtocol::ProtocolTcp; - connection.name = llvm::formatv("{0}:0", k_localhost).str(); - m_server_up = std::make_unique(); - m_server_up->AddTool(std::make_unique("test", "test tool")); - m_server_up->AddResourceProvider(std::make_unique()); - ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); - - // Connect to the server over a TCP socket. - auto connect_socket_up = std::make_unique(true); - ASSERT_THAT_ERROR(connect_socket_up - ->Connect(llvm::formatv("{0}:{1}", k_localhost, - static_cast( - m_server_up->GetSocket()) - ->GetLocalPortNumber()) - .str()) - .ToError(), - llvm::Succeeded()); - - // Set up JSON transport for the client. - m_io_sp = std::move(connect_socket_up); - m_transport_up = std::make_unique(m_io_sp, m_io_sp); - } - - void TearDown() { - // Stop the server. - ASSERT_THAT_ERROR(m_server_up->Stop(), llvm::Succeeded()); - } -}; - -} // namespace - -TEST_F(ProtocolServerMCPTest, Initialization) { - llvm::StringLiteral request = - R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; - llvm::StringLiteral response = - R"json( {"id":0,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - - llvm::Expected response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); -} - -TEST_F(ProtocolServerMCPTest, ToolsList) { - llvm::StringLiteral request = - R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; - llvm::StringLiteral response = - R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - - llvm::Expected response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); -} - -TEST_F(ProtocolServerMCPTest, ResourcesList) { - llvm::StringLiteral request = - R"json({"method":"resources/list","params":{},"jsonrpc":"2.0","id":2})json"; - llvm::StringLiteral response = - R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - - llvm::Expected response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); -} - -TEST_F(ProtocolServerMCPTest, ToolsCall) { - llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; - llvm::StringLiteral response = - R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - - llvm::Expected response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); -} - -TEST_F(ProtocolServerMCPTest, ToolsCallError) { - m_server_up->AddTool(std::make_unique("error", "error tool")); - - llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; - llvm::StringLiteral response = - R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - - llvm::Expected response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); -} - -TEST_F(ProtocolServerMCPTest, ToolsCallFail) { - m_server_up->AddTool(std::make_unique("fail", "fail tool")); - - llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; - llvm::StringLiteral response = - R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - - llvm::Expected response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); -} - -TEST_F(ProtocolServerMCPTest, NotificationInitialized) { - bool handler_called = false; - std::condition_variable cv; - std::mutex mutex; - - m_server_up->AddNotificationHandler( - "notifications/initialized", [&](const Notification ¬ification) { - { - std::lock_guard lock(mutex); - handler_called = true; - } - cv.notify_all(); - }); - llvm::StringLiteral request = - R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - - std::unique_lock lock(mutex); - cv.wait(lock, [&] { return handler_called; }); -} diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h new file mode 100644 index 0000000000000..5a9eb8e59f2b6 --- /dev/null +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H +#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H + +#include "lldb/Host/JSONTransport.h" +#include "gmock/gmock.h" + +template +class MockMessageHandler final + : public lldb_private::Transport::MessageHandler { +public: + MOCK_METHOD(void, Received, (const Evt &), (override)); + MOCK_METHOD(void, Received, (const Req &), (override)); + MOCK_METHOD(void, Received, (const Resp &), (override)); + MOCK_METHOD(void, OnError, (llvm::Error), (override)); + MOCK_METHOD(void, OnClosed, (), (override)); +}; + +#endif