diff --git a/proxy/src/decompile.rs b/proxy/src/decompile.rs new file mode 100644 index 0000000..1d9e6b6 --- /dev/null +++ b/proxy/src/decompile.rs @@ -0,0 +1,143 @@ +use serde_json::{json, Value}; +use std::{ + collections::{hash_map::DefaultHasher, HashMap}, + env, fs, + hash::{Hash, Hasher}, + io::Write, + path::{Path, PathBuf}, + sync::{mpsc, Arc, Mutex}, +}; + +use crate::{lsp::encode_lsp, lsp_error, lsp_warn}; + +const DECOMPILED_DIR: &str = "jdtls-decompiled"; + +/// Convert a `PathBuf` to a proper `file://` URI. +/// +/// On Unix the path already starts with `/`, so `file://` + path gives us +/// the correct `file:///…` form with no extra work. +/// +/// On Windows we must replace `\` with `/` and prepend `file:///` before the +/// drive letter so that we get `file:///C:/…` instead of `file://C:\…`. +#[cfg(unix)] +fn path_to_file_uri(path: &Path) -> String { + format!("file://{}", path.display()) +} + +#[cfg(windows)] +fn path_to_file_uri(path: &Path) -> String { + let s = path.display().to_string().replace('\\', "/"); + format!("file:///{s}") +} + +fn cache_dir() -> PathBuf { + env::temp_dir().join(DECOMPILED_DIR) +} + +fn cache_path(uri: &str) -> PathBuf { + let mut hasher = DefaultHasher::new(); + uri.hash(&mut hasher); + let hex = format!("{:016x}", hasher.finish()); + + // jdt://contents/java.base/java.util/ArrayList.java?=.../%3Cjava.util%28ArrayList.class + // The class name is between the last %28 (URL-encoded '(') and .class at the end + let name = uri + .rsplit_once("%28") + .and_then(|(_, rest)| rest.strip_suffix(".class")) + .or_else(|| { + uri.split('?') + .next() + .and_then(|path| path.rsplit('/').next()) + .and_then(|seg| seg.strip_suffix(".java").or(seg.strip_suffix(".class"))) + }) + .unwrap_or("Decompiled"); + + cache_dir().join(format!("{name}-{hex}.java")) +} + +/// Send `java/classFileContents` to JDTLS and wait for the response. 
+fn fetch_class_contents<W: Write>( + uri: &str, + writer: &Arc<Mutex<W>>, + pending: &Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>>, + request_id: Value, +) -> Option<String> { + let (tx, rx) = mpsc::channel(); + pending.lock().unwrap().insert(request_id.clone(), tx); + + let req = encode_lsp(&json!({ + "jsonrpc": "2.0", + "id": request_id, + "method": "java/classFileContents", + "params": { "uri": uri } + })); + { + let mut w = writer.lock().unwrap(); + let _ = w.write_all(req.as_bytes()); + let _ = w.flush(); + } + + match rx.recv_timeout(std::time::Duration::from_secs(10)) { + Ok(resp) => { + let content = resp.get("result")?.as_str()?; + Some(content.to_string()) + } + Err(_) => { + lsp_warn!("[decompile] Timed out fetching class contents for {uri}"); + None + } + } +} + +fn resolve_jdt_uri<W: Write>( + uri: &str, + writer: &Arc<Mutex<W>>, + pending: &Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>>, + request_id: Value, +) -> Option<String> { + let path = cache_path(uri); + if path.exists() { + return Some(path_to_file_uri(&path)); + } + + let content = fetch_class_contents(uri, writer, pending, request_id)?; + let _ = fs::create_dir_all(cache_dir()); + match fs::write(&path, &content) { + Ok(_) => Some(path_to_file_uri(&path)), + Err(e) => { + lsp_error!("[decompile] Failed to write {}: {e}", path.display()); + None + } + } +} + +/// Rewrite any `jdt://` URIs in a definition/typeDefinition/implementation response. +/// Returns `true` if any URI was rewritten. 
+pub fn rewrite_jdt_locations<W: Write>( + msg: &mut Value, + writer: &Arc<Mutex<W>>, + pending: &Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>>, + next_id: &mut impl FnMut() -> Value, +) -> bool { + let results = match msg.get_mut("result") { + Some(Value::Array(arr)) => arr.iter_mut().collect::<Vec<_>>(), + Some(obj @ Value::Object(_)) => vec![obj], + _ => return false, + }; + + let mut rewritten = false; + for loc in results { + for key in &["uri", "targetUri"] { + if let Some(Value::String(uri)) = loc.get(key) { + if uri.starts_with("jdt://") { + let jdt_uri = uri.clone(); + if let Some(file_uri) = resolve_jdt_uri(&jdt_uri, writer, pending, next_id()) { + loc[*key] = Value::String(file_uri); + rewritten = true; + } + } + } + } + } + rewritten +} diff --git a/proxy/src/lsp.rs b/proxy/src/lsp.rs index ac35ae8..eeb343e 100644 --- a/proxy/src/lsp.rs +++ b/proxy/src/lsp.rs @@ -1,5 +1,5 @@ use serde::Serialize; -use std::io::{self, Read}; +use std::io::{self, Read, Write}; pub const CONTENT_LENGTH: &str = "Content-Length"; pub const HEADER_SEP: &[u8] = b"\r\n\r\n"; @@ -58,3 +58,17 @@ pub fn encode_lsp(value: &impl Serialize) -> String { let json = serde_json::to_string(value).unwrap(); format!("{CONTENT_LENGTH}: {}\r\n\r\n{json}", json.len()) } + +/// Write raw LSP bytes to a writer, flushing afterward. +pub fn write_raw(w: &mut impl Write, raw: &[u8]) { + let _ = w.write_all(raw); + let _ = w.flush(); +} + +/// Encode a value as an LSP message and write it to stdout. 
+pub fn write_to_stdout(value: &impl Serialize) { + let out = encode_lsp(value); + let mut w = io::stdout().lock(); + let _ = w.write_all(out.as_bytes()); + let _ = w.flush(); +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index b6ab51b..becfbdd 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -1,16 +1,18 @@ mod completions; +mod decompile; mod http; mod log; mod lsp; mod platform; use completions::{should_sort_completions, sort_completions_by_param_count}; +use decompile::rewrite_jdt_locations; use http::handle_http; -use lsp::{encode_lsp, parse_lsp_content, LspReader}; +use lsp::{parse_lsp_content, write_raw, write_to_stdout, LspReader}; use platform::spawn_parent_monitor; use serde_json::Value; use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, env, fs, io::{self, BufReader, Write}, net::TcpListener, @@ -88,15 +90,33 @@ fn main() { let id_counter = Arc::new(AtomicU64::new(1)); - // --- Thread 1: Zed stdin -> JDTLS stdin (passthrough) --- + // Track definition/typeDefinition/implementation request IDs for jdt:// rewriting + let definition_ids: Arc<Mutex<HashSet<Value>>> = Arc::new(Mutex::new(HashSet::new())); + + // --- Thread 1: Zed stdin -> JDTLS stdin (track definition requests) --- let stdin_writer = Arc::clone(&child_stdin); let alive_stdin = Arc::clone(&alive); + let def_ids_in = Arc::clone(&definition_ids); thread::spawn(move || { let stdin = io::stdin().lock(); let mut reader = LspReader::new(stdin); while alive_stdin.load(Ordering::Relaxed) { match reader.read_message() { Ok(Some(raw)) => { + if let Some(msg) = parse_lsp_content(&raw) { + if let Some(method) = msg.get("method").and_then(|m| m.as_str()) { + if matches!( + method, + "textDocument/definition" + | "textDocument/typeDefinition" + | "textDocument/implementation" + ) { + if let Some(id) = msg.get("id").cloned() { + def_ids_in.lock().unwrap().insert(id); + } + } + } + } let mut w = stdin_writer.lock().unwrap(); + if w.write_all(&raw).is_err() || w.flush().is_err() { break; @@ -108,19 
+128,21 @@ fn main() { alive_stdin.store(false, Ordering::Relaxed); }); - // --- Thread 2: JDTLS stdout -> modify completions -> Zed stdout / resolve pending --- + // --- Thread 2: JDTLS stdout -> rewrite jdt:// URIs, modify completions -> Zed stdout / resolve pending --- let pending_out = Arc::clone(&pending); let alive_out = Arc::clone(&alive); + let def_ids_out = Arc::clone(&definition_ids); + let decompile_writer = Arc::clone(&child_stdin); + let decompile_pending = Arc::clone(&pending); + let decompile_counter = Arc::clone(&id_counter); + let decompile_proxy_id = proxy_id.clone(); thread::spawn(move || { let mut reader = LspReader::new(BufReader::new(child_stdout)); - let stdout = io::stdout(); while alive_out.load(Ordering::Relaxed) { match reader.read_message() { Ok(Some(raw)) => { let Some(mut msg) = parse_lsp_content(&raw) else { - let mut w = stdout.lock(); - let _ = w.write_all(&raw); - let _ = w.flush(); + write_raw(&mut io::stdout().lock(), &raw); continue; }; @@ -132,20 +154,36 @@ fn main() { } } + // Rewrite jdt:// URIs in definition responses + // Spawns a thread so this loop stays unblocked and can + // route the java/classFileContents response back via `pending`. 
+ if let Some(id) = msg.get("id").cloned() { + if def_ids_out.lock().unwrap().remove(&id) { + let writer = Arc::clone(&decompile_writer); + let pending = Arc::clone(&decompile_pending); + let pid = decompile_proxy_id.clone(); + let counter = Arc::clone(&decompile_counter); + thread::spawn(move || { + let mut next_id = move || { + let seq = counter.fetch_add(1, Ordering::Relaxed); + Value::String(format!("{pid}-decompile-{seq}")) + }; + rewrite_jdt_locations(&mut msg, &writer, &pending, &mut next_id); + write_to_stdout(&msg); + }); + continue; + } + } + // Sort completion responses by param count if should_sort_completions(&msg) { sort_completions_by_param_count(&mut msg); - let out = encode_lsp(&msg); - let mut w = stdout.lock(); - let _ = w.write_all(out.as_bytes()); - let _ = w.flush(); + write_to_stdout(&msg); continue; } // Passthrough - let mut w = stdout.lock(); - let _ = w.write_all(&raw); - let _ = w.flush(); + write_raw(&mut io::stdout().lock(), &raw); } Ok(None) | Err(_) => break, } diff --git a/src/java.rs b/src/java.rs index e74afdb..9a75d9c 100644 --- a/src/java.rs +++ b/src/java.rs @@ -380,21 +380,36 @@ impl Extension for Java { })?; } - let options = LspSettings::for_worktree(language_server_id.as_ref(), worktree) + let mut options = LspSettings::for_worktree(language_server_id.as_ref(), worktree) .map(|lsp_settings| lsp_settings.initialization_options) - .map_err(|err| format!("Failed to get LSP settings for worktree: {err}"))?; + .map_err(|err| format!("Failed to get LSP settings for worktree: {err}"))? 
+ .unwrap_or_else(|| json!({})); + + // Inject extendedClientCapabilities defaults if not already set by the user + let caps = options + .as_object_mut() + .unwrap() + .entry("extendedClientCapabilities") + .or_insert_with(|| json!({})); + let caps_obj = caps.as_object_mut().unwrap(); + caps_obj + .entry("classFileContentsSupport") + .or_insert(json!(true)); + caps_obj + .entry("resolveAdditionalTextEditsSupport") + .or_insert(json!(true)); if self.debugger().is_ok_and(|v| v.loaded()) { return Ok(Some( self.debugger()? - .inject_plugin_into_options(options) + .inject_plugin_into_options(Some(options)) .map_err(|err| { format!("Failed to inject debugger plugin into options: {err}") })?, )); } - Ok(options) + Ok(Some(options)) } fn language_server_workspace_configuration(