Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions src/websocket/interactive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ pub struct InteractiveMode {
status: String,
messages: Vec<MessageEntry>,
format_json: bool,
color_json: bool,
}

#[derive(Debug, Eq, PartialEq)]
Expand All @@ -152,6 +153,10 @@ pub enum InputAction {

impl InteractiveMode {
pub fn new(cols: usize, format_json: bool) -> Self {
Self::new_with_color(cols, format_json, false)
}

pub fn new_with_color(cols: usize, format_json: bool, color_json: bool) -> Self {
Self {
editor: LineEditor::default(),
rows: 0,
Expand All @@ -160,6 +165,7 @@ impl InteractiveMode {
status: "connected".to_string(),
messages: Vec::new(),
format_json,
color_json,
}
}

Expand Down Expand Up @@ -448,10 +454,14 @@ impl InteractiveMode {
pub fn format_message(&self, data: &[u8]) -> Result<String, FetchError> {
if self.format_json
&& serde_json::from_slice::<serde_json::Value>(data).is_ok()
&& let Ok(formatted) = json::format_json_line(data, false)
&& let Ok(formatted) = json::format_json_line(data, self.color_json)
{
let text = String::from_utf8_lossy(&formatted);
return Ok(sanitize_message_text(text.trim_end_matches('\n')));
return Ok(if self.color_json {
text.trim_end_matches('\n').to_string()
} else {
sanitize_message_text(text.trim_end_matches('\n'))
});
}
Ok(sanitize_message_text(&String::from_utf8_lossy(data)))
}
Expand Down Expand Up @@ -644,6 +654,7 @@ pub async fn run_terminal<S>(
stream: S,
initial_message: Option<&[u8]>,
format_json: bool,
color_json: bool,
rows: usize,
cols: usize,
) -> Result<(), FetchError>
Expand All @@ -655,7 +666,7 @@ where
let (input_tx, mut input_rx) = tokio_mpsc::channel(STDIN_CHAN_BUF);
spawn_stdin_reader(input_tx);

let mut mode = InteractiveMode::new(cols, format_json);
let mut mode = InteractiveMode::new_with_color(cols, format_json, color_json);
let (initial_row, mut pending) = detect_cursor_row_async(&mut input_rx, &mut stdout).await?;
mode.setup_screen(&mut stdout, rows, cols, initial_row)?;
stdout.flush()?;
Expand Down Expand Up @@ -1029,14 +1040,26 @@ pub fn wrap_display_lines(text: &str, width: usize) -> Vec<String> {

let mut line = String::new();
let mut line_width = 0;
for ch in part.chars() {
let mut index = 0;
while index < part.len() {
if let Some((sequence, next)) = ansi_csi_sequence(part, index) {
line.push_str(sequence);
index = next;
continue;
}

let ch = part[index..]
.chars()
.next()
.expect("index is inside string bounds");
let char_width = char_display_width(ch).max(1);
if line_width > 0 && line_width + char_width > width {
lines.push(std::mem::take(&mut line));
line_width = 0;
}
line.push(ch);
line_width += char_width;
index += ch.len_utf8();
}
lines.push(line);
}
Expand Down Expand Up @@ -1065,7 +1088,35 @@ pub fn fit_display_width(text: &str, width: usize) -> String {
}

fn display_width(text: &str) -> usize {
text.chars().map(|ch| char_display_width(ch).max(1)).sum()
let mut width = 0;
let mut index = 0;
while index < text.len() {
if let Some((_, next)) = ansi_csi_sequence(text, index) {
index = next;
continue;
}
let ch = text[index..]
.chars()
.next()
.expect("index is inside string bounds");
width += char_display_width(ch).max(1);
index += ch.len_utf8();
}
width
}

fn ansi_csi_sequence(text: &str, start: usize) -> Option<(&str, usize)> {
let bytes = text.as_bytes();
if bytes.get(start) != Some(&b'\x1b') || bytes.get(start + 1) != Some(&b'[') {
return None;
}

for index in start + 2..bytes.len() {
if (0x40..=0x7e).contains(&bytes[index]) {
return Some((&text[start..=index], index + 1));
}
}
None
}

fn char_display_width(ch: char) -> usize {
Expand Down Expand Up @@ -1282,6 +1333,12 @@ mod tests {
4,
vec!["日本", "語"],
),
(
"ansi sgr sequences do not count toward width",
"\x1b[34mabc\x1b[0mdef",
3,
vec!["\x1b[34mabc\x1b[0m", "def"],
),
];

for (name, input, width, want) in cases {
Expand Down Expand Up @@ -1487,6 +1544,15 @@ mod tests {
assert_eq!(formatted, r#"{ "ok": true }"#);
}

#[test]
fn format_message_colors_json_when_enabled() {
let mode = InteractiveMode::new_with_color(80, true, true);
let formatted = mode.format_message(br#"{"ok":"yes"}"#).unwrap();

assert!(formatted.contains("\"\x1b[34m\x1b[1mok\x1b[0m\""));
assert!(formatted.contains("\"\x1b[32myes\x1b[0m\""));
}

#[test]
fn handle_input_submits_text_messages_on_enter() {
let mut mode = InteractiveMode::new(20, false);
Expand Down
33 changes: 25 additions & 8 deletions src/websocket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ pub async fn execute(cli: &Cli) -> Result<i32, FetchError> {
stream,
initial_message.as_deref(),
should_format_for_interactive(cli),
use_color(cli, io::stdout().is_terminal()),
size.rows,
size.cols,
)
Expand Down Expand Up @@ -305,9 +306,12 @@ async fn read_messages<S>(cli: &Cli, stream: &mut S) -> Result<(), FetchError>
where
S: futures_util::Stream<Item = Result<Message, WsError>> + Unpin,
{
let stdout_is_terminal = io::stdout().is_terminal();
while let Some(message) = stream.next().await {
match message.map_err(websocket_error)? {
Message::Text(text) => write_text_message(cli, text.as_str().as_bytes())?,
Message::Text(text) => {
write_text_message(cli, text.as_str().as_bytes(), stdout_is_terminal)?
}
Message::Binary(bytes) => write_binary_indicator(cli, bytes.len()),
Message::Close(_) => return Ok(()),
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {}
Expand All @@ -316,9 +320,9 @@ where
Ok(())
}

fn write_text_message(cli: &Cli, bytes: &[u8]) -> Result<(), FetchError> {
if should_format(cli)
&& let Ok(formatted) = json::format_json_line(bytes, use_color(cli))
fn write_text_message(cli: &Cli, bytes: &[u8], stdout_is_terminal: bool) -> Result<(), FetchError> {
if should_format(cli, stdout_is_terminal)
&& let Ok(formatted) = json::format_json_line(bytes, use_color(cli, stdout_is_terminal))
{
print!("{}", String::from_utf8_lossy(&formatted));
return Ok(());
Expand All @@ -327,20 +331,20 @@ fn write_text_message(cli: &Cli, bytes: &[u8]) -> Result<(), FetchError> {
Ok(())
}

fn should_format(cli: &Cli) -> bool {
fn should_format(cli: &Cli, stdout_is_terminal: bool) -> bool {
match cli.format.as_deref() {
Some("off") => false,
Some("on") => true,
_ => io::stdout().is_terminal(),
_ => stdout_is_terminal,
}
}

fn should_format_for_interactive(cli: &Cli) -> bool {
!matches!(cli.format.as_deref(), Some("off"))
}

fn use_color(cli: &Cli) -> bool {
cli.color.as_deref() == Some("on")
fn use_color(cli: &Cli, stdout_is_terminal: bool) -> bool {
core::color_enabled(cli.color.as_deref(), stdout_is_terminal)
}

fn write_binary_indicator(cli: &Cli, len: usize) {
Expand Down Expand Up @@ -422,4 +426,17 @@ mod tests {
"--ws-interactive on requires stdin, stdout, and stderr to be terminals"
);
}

#[test]
fn websocket_json_color_matches_core_auto_policy() {
let default_cli = Cli::try_parse_from(["fetch", "ws://example.com"]).unwrap();
assert!(use_color(&default_cli, true));
assert!(!use_color(&default_cli, false));

let on_cli = Cli::try_parse_from(["fetch", "--color", "on", "ws://example.com"]).unwrap();
assert!(use_color(&on_cli, false));

let off_cli = Cli::try_parse_from(["fetch", "--color", "off", "ws://example.com"]).unwrap();
assert!(!use_color(&off_cli, true));
}
}
Loading