Skip to content

Commit

Permalink
Cancelable Initialization
Browse files Browse the repository at this point in the history
This commit provides additional initialization methods to Connection in
order to support CTRL + C sigterm handling.
  • Loading branch information
schrieveslaach committed Nov 21, 2023
1 parent 4513651 commit cdddd07
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 10 deletions.
24 changes: 23 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lib/lsp-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ crossbeam-channel = "0.5.6"

[dev-dependencies]
lsp-types = "=0.94"
ctrlc = "3.4.1"
169 changes: 160 additions & 9 deletions lib/lsp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::{
net::{TcpListener, TcpStream, ToSocketAddrs},
};

use crossbeam_channel::{Receiver, Sender};
use crossbeam_channel::{Receiver, RecvTimeoutError, Sender};

pub use crate::{
error::{ExtractError, ProtocolError},
Expand Down Expand Up @@ -113,11 +113,62 @@ impl Connection {
/// }
/// ```
pub fn initialize_start(&self) -> Result<(RequestId, serde_json::Value), ProtocolError> {
loop {
break match self.receiver.recv() {
Ok(Message::Request(req)) if req.is_initialize() => Ok((req.id, req.params)),
self.initialize_start_while(|| true)
}

/// Starts the initialization process by waiting for an initialize as described in
/// [`Self::initialize_start`] as long as `running` returns
/// `true` while the return value can be changed through a sig handler such as `CTRL + C`.
///
/// # Example
///
/// ```rust
/// use std::sync::atomic::{AtomicBool, Ordering};
/// use std::sync::Arc;
/// # use std::error::Error;
/// # use lsp_types::{ClientCapabilities, InitializeParams, ServerCapabilities};
/// # use lsp_server::{Connection, Message, Request, RequestId, Response};
/// # fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
/// let running = Arc::new(AtomicBool::new(true));
/// # running.store(true, Ordering::SeqCst);
/// let r = running.clone();
///
/// ctrlc::set_handler(move || {
/// r.store(false, Ordering::SeqCst);
/// }).expect("Error setting Ctrl-C handler");
///
/// let (connection, io_threads) = Connection::stdio();
///
/// let res = connection.initialize_start_while(|| running.load(Ordering::SeqCst));
/// # assert!(res.is_err());
///
/// # Ok(())
/// # }
/// ```
pub fn initialize_start_while<C>(
&self,
running: C,
) -> Result<(RequestId, serde_json::Value), ProtocolError>
where
C: Fn() -> bool,
{
while running() {
let msg = match self.receiver.recv_timeout(std::time::Duration::from_secs(1)) {
Ok(msg) => msg,
Err(RecvTimeoutError::Timeout) => {
continue;
}
Err(e) => {
return Err(ProtocolError(format!(
"expected initialize request, got error: {e}"
)))
}
};

match msg {
Message::Request(req) if req.is_initialize() => return Ok((req.id, req.params)),
// Respond to non-initialize requests with ServerNotInitialized
Ok(Message::Request(req)) => {
Message::Request(req) => {
let resp = Response::new_err(
req.id.clone(),
ErrorCode::ServerNotInitialized as i32,
Expand All @@ -126,15 +177,18 @@ impl Connection {
self.sender.send(resp.into()).unwrap();
continue;
}
Ok(Message::Notification(n)) if !n.is_exit() => {
Message::Notification(n) if !n.is_exit() => {
continue;
}
Ok(msg) => Err(ProtocolError(format!("expected initialize request, got {msg:?}"))),
Err(e) => {
Err(ProtocolError(format!("expected initialize request, got error: {e}")))
msg => {
return Err(ProtocolError(format!("expected initialize request, got {msg:?}")));
}
};
}

return Err(ProtocolError(String::from(
"Initialization has been aborted during initialization",
)));
}

/// Finishes the initialization process by sending an `InitializeResult` to the client
Expand All @@ -156,6 +210,51 @@ impl Connection {
}
}

/// Finishes the initialization process as described in [`Self::initialize_finish`] as
/// long as `running` returns `true` while the return value can be changed through a sig
/// handler such as `CTRL + C`.
pub fn initialize_finish_while<C>(
&self,
initialize_id: RequestId,
initialize_result: serde_json::Value,
running: C,
) -> Result<(), ProtocolError>
where
C: Fn() -> bool,
{
let resp = Response::new_ok(initialize_id, initialize_result);
self.sender.send(resp.into()).unwrap();

while running() {
let msg = match self.receiver.recv_timeout(std::time::Duration::from_secs(1)) {
Ok(msg) => msg,
Err(RecvTimeoutError::Timeout) => {
continue;
}
Err(e) => {
return Err(ProtocolError(format!(
"expected initialized notification, got error: {e}",
)));
}
};

match msg {
Message::Notification(n) if n.is_initialized() => {
return Ok(());
}
msg => {
return Err(ProtocolError(format!(
r#"expected initialized notification, got: {msg:?}"#
)));
}
}
}

return Err(ProtocolError(String::from(
"Initialization has been aborted during initialization",
)));
}

/// Initialize the connection. Sends the server capabilities
/// to the client and returns the serialized client capabilities
/// on success. If more fine-grained initialization is required use
Expand Down Expand Up @@ -198,6 +297,58 @@ impl Connection {
Ok(params)
}

/// Initialize the connection as described in [`Self::initialize`] as long as `running` returns
/// `true` while the return value can be changed through a sig handler such as `CTRL + C`.
///
/// # Example
///
/// ```rust
/// use std::sync::atomic::{AtomicBool, Ordering};
/// use std::sync::Arc;
/// # use std::error::Error;
/// # use lsp_types::ServerCapabilities;
/// # use lsp_server::{Connection, Message, Request, RequestId, Response};
///
/// # fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
/// let running = Arc::new(AtomicBool::new(true));
/// # running.store(true, Ordering::SeqCst);
/// let r = running.clone();
///
/// ctrlc::set_handler(move || {
/// r.store(false, Ordering::SeqCst);
/// }).expect("Error setting Ctrl-C handler");
///
/// let (connection, io_threads) = Connection::stdio();
///
/// let server_capabilities = serde_json::to_value(&ServerCapabilities::default()).unwrap();
/// let initialization_params = connection.initialize_while(
/// server_capabilities,
/// || running.load(Ordering::SeqCst)
/// );
///
/// # assert!(initialization_params.is_err());
/// # Ok(())
/// # }
/// ```
pub fn initialize_while<C>(
&self,
server_capabilities: serde_json::Value,
running: C,
) -> Result<serde_json::Value, ProtocolError>
where
C: Fn() -> bool,
{
let (id, params) = self.initialize_start_while(&running)?;

let initialize_data = serde_json::json!({
"capabilities": server_capabilities,
});

self.initialize_finish_while(id, initialize_data, running)?;

Ok(params)
}

/// If `req` is `Shutdown`, respond to it and return `true`, otherwise return `false`
pub fn handle_shutdown(&self, req: &Request) -> Result<bool, ProtocolError> {
if !req.is_shutdown() {
Expand Down

0 comments on commit cdddd07

Please sign in to comment.