diff --git a/Cargo.lock b/Cargo.lock index 967d755..8a68334 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -281,6 +281,7 @@ dependencies = [ "dirs", "futures", "glob", + "lazy_static", "miette", "os_pipe", "parking_lot", @@ -604,6 +605,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.158" diff --git a/crates/deno_task_shell/Cargo.toml b/crates/deno_task_shell/Cargo.toml index 9a19ff3..8a3f909 100644 --- a/crates/deno_task_shell/Cargo.toml +++ b/crates/deno_task_shell/Cargo.toml @@ -29,6 +29,7 @@ pest_derive = "2.7.12" dirs = "5.0.1" pest_ascii_tree = { git = "https://github.com/prsabahrami/pest_ascii_tree.git", branch = "master" } miette = "7.2.0" +lazy_static = "1.4.0" [dev-dependencies] tempfile = "3.12.0" diff --git a/crates/deno_task_shell/src/grammar.pest b/crates/deno_task_shell/src/grammar.pest index 8a98850..c3e081f 100644 --- a/crates/deno_task_shell/src/grammar.pest +++ b/crates/deno_task_shell/src/grammar.pest @@ -3,6 +3,8 @@ // Whitespace and comments WHITESPACE = _{ " " | "\t" | ("\\" ~ WHITESPACE* ~ NEWLINE) } COMMENT = _{ "#" ~ (!NEWLINE ~ ANY)* } +NUMBER = @{ INT ~ ("." ~ ASCII_DIGIT*)? ~ (^"e" ~ INT)? } +INT = { ("+" | "-")? ~ ASCII_DIGIT+ } // Basic tokens QUOTED_WORD = { DOUBLE_QUOTED | SINGLE_QUOTED } @@ -11,6 +13,7 @@ UNQUOTED_PENDING_WORD = ${ (TILDE_PREFIX ~ (!(OPERATOR | WHITESPACE | NEWLINE) ~ ( EXIT_STATUS | UNQUOTED_ESCAPE_CHAR | + "$" ~ ARITHMETIC_EXPRESSION | SUB_COMMAND | ("$" ~ "{" ~ VARIABLE ~ "}" | "$" ~ VARIABLE) | UNQUOTED_CHAR | @@ -20,6 +23,7 @@ UNQUOTED_PENDING_WORD = ${ (!(OPERATOR | WHITESPACE | NEWLINE) ~ ( EXIT_STATUS | UNQUOTED_ESCAPE_CHAR | + "$" ~ ARITHMETIC_EXPRESSION | SUB_COMMAND | ("$" ~ "{" ~ VARIABLE ~ "}" | "$" ~ VARIABLE) | UNQUOTED_CHAR | @@ -69,7 +73,7 @@ QUOTED_ESCAPE_CHAR = ${ "\\" ~ "$" | "$" ~ !"(" ~ !"{" ~ !VARIABLE | "\\" ~ ("`" UNQUOTED_CHAR = ${ ("\\" ~ " ") | !("]]" | "[[" | "(" | ")" | "<" | ">" | "|" | "&" | ";" | "\"" | "'" | "$") ~ ANY } QUOTED_CHAR = ${ !"\"" ~ ANY } -VARIABLE = ${ (ASCII_ALPHANUMERIC | "_")+ } +VARIABLE = ${ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } SUB_COMMAND = { "$(" ~ complete_command ~ ")"} DOUBLE_QUOTED = @{ "\"" ~ QUOTED_PENDING_WORD ~ "\"" } @@ -148,6 +152,7 @@ command = !{ compound_command = { brace_group | + ARITHMETIC_EXPRESSION | subshell | for_clause | case_clause | @@ -156,7 +161,86 @@ compound_command = { until_clause } -subshell = !{ "(" ~ compound_list ~ ")" } +ARITHMETIC_EXPRESSION = !{ "((" ~ arithmetic_sequence ~ "))" } +arithmetic_sequence = !{ arithmetic_expr ~ ("," ~ arithmetic_expr)* } +arithmetic_expr = { parentheses_expr | variable_assignment | triple_conditional_expr | binary_arithmetic_expr | binary_conditional_expression | unary_arithmetic_expr | VARIABLE | NUMBER } +parentheses_expr = !{ "(" ~ arithmetic_sequence ~ ")" } + +variable_assignment = !{ + VARIABLE ~ assignment_operator ~ arithmetic_expr +} + +triple_conditional_expr = !{ + (parentheses_expr | variable_assignment | binary_arithmetic_expr | binary_conditional_expression | unary_arithmetic_expr | VARIABLE | NUMBER) ~ + "?" ~ (parentheses_expr | variable_assignment | binary_arithmetic_expr | binary_conditional_expression | unary_arithmetic_expr | VARIABLE | NUMBER) ~ + ":" ~ (parentheses_expr | variable_assignment | binary_arithmetic_expr | binary_conditional_expression | unary_arithmetic_expr | VARIABLE | NUMBER) +} + +binary_arithmetic_expr = _{ + (parentheses_expr | binary_conditional_expression | unary_arithmetic_expr | variable_assignment | VARIABLE | NUMBER) ~ + (binary_arithmetic_op ~ + (parentheses_expr | variable_assignment | binary_conditional_expression | unary_arithmetic_expr | VARIABLE | NUMBER) + )+ +} + +binary_arithmetic_op = _{ + add | subtract | power | multiply | divide | modulo | left_shift | right_shift | + bitwise_and | bitwise_xor | bitwise_or | logical_and | logical_or +} + +add = { "+" } +subtract = { "-" } +multiply = { "*" } +divide = { "/" } +modulo = { "%" } +power = { "**" } +left_shift = { "<<" } +right_shift = { ">>" } +bitwise_and = { "&" } +bitwise_xor = { "^" } +bitwise_or = { "|" } +logical_and = { "&&" } +logical_or = { "||" } + +unary_arithmetic_expr = !{ + (unary_arithmetic_op | post_arithmetic_op) ~ (parentheses_expr | VARIABLE | NUMBER) | + (parentheses_expr | VARIABLE | NUMBER) ~ post_arithmetic_op +} + +unary_arithmetic_op = _{ + unary_plus | unary_minus | logical_not | bitwise_not +} + +unary_plus = { "+" } +unary_minus = { "-" } +logical_not = { "!" } +bitwise_not = { "~" } + +post_arithmetic_op = !{ + increment | decrement +} + +increment = { "++" } +decrement = { "--" } + +assignment_operator = _{ + assign | multiply_assign | divide_assign | modulo_assign | add_assign | subtract_assign | + left_shift_assign | right_shift_assign | bitwise_and_assign | bitwise_xor_assign | bitwise_or_assign +} + +assign = { "=" } +multiply_assign = { "*=" } +divide_assign = { "/=" } +modulo_assign = { "%=" } +add_assign = { "+=" } +subtract_assign = { "-=" } +left_shift_assign = { "<<=" } +right_shift_assign = { ">>=" } +bitwise_and_assign = { "&=" } +bitwise_xor_assign = { "^=" } +bitwise_or_assign = { "|=" } + +subshell = !{ "(" ~ compound_list ~ ")" } compound_list = !{ (newline_list? ~ term ~ separator?)+ } term = !{ and_or ~ (separator ~ and_or)* } @@ -232,16 +316,16 @@ string_conditional_op = !{ binary_conditional_expression = !{ UNQUOTED_PENDING_WORD ~ ( - binary_string_conditional_op | - binary_arithmetic_conditional_op + binary_bash_conditional_op | + binary_posix_conditional_op ) ~ UNQUOTED_PENDING_WORD } -binary_string_conditional_op = !{ +binary_bash_conditional_op = !{ "==" | "=" | "!=" | "<" | ">" } -binary_arithmetic_conditional_op = !{ +binary_posix_conditional_op = !{ "-eq" | "-ne" | "-lt" | "-le" | "-gt" | "-ge" } diff --git a/crates/deno_task_shell/src/parser.rs b/crates/deno_task_shell/src/parser.rs index 16f3c31..c438ece 100644 --- a/crates/deno_task_shell/src/parser.rs +++ b/crates/deno_task_shell/src/parser.rs @@ -1,7 +1,9 @@ // Copyright 2018-2024 the Deno authors. MIT license. +use lazy_static::lazy_static; use miette::{miette, Context, Result}; use pest::iterators::Pair; +use pest::pratt_parser::{Assoc, Op, PrattParser}; use pest::Parser; use pest_derive::Parser; use thiserror::Error; @@ -109,16 +111,6 @@ pub struct BooleanList { pub next: Sequence, } -#[cfg_attr(feature = "serialization", derive(serde::Serialize))] -#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Error)] -pub enum PipeSequenceOperator { - #[error("Stdout pipe operator")] - Stdout, - #[error("Stdout and stderr pipe operator")] - StdoutStderr, -} - #[cfg_attr(feature = "serialization", derive(serde::Serialize))] #[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] #[derive(Debug, Clone, PartialEq, Eq, Error)] @@ -138,6 +130,16 @@ impl From for Sequence { } } +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Error)] +pub enum PipeSequenceOperator { + #[error("Stdout pipe operator")] + Stdout, + #[error("Stdout and stderr pipe operator")] + StdoutStderr, +} + #[cfg_attr(feature = "serialization", derive(serde::Serialize))] #[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] #[derive(Debug, Clone, PartialEq, Eq, Error)] @@ -160,6 +162,8 @@ pub enum CommandInner { Subshell(Box), #[error("Invalid if command")] If(IfClause), + #[error("Invalid arithmetic expression")] + ArithmeticExpression(Arithmetic), } impl From for Sequence { @@ -378,10 +382,120 @@ pub enum WordPart { Quoted(Vec), #[error("Invalid tilde prefix")] Tilde(TildePrefix), + #[error("Invalid arithmetic expression")] + Arithmetic(Arithmetic), #[error("Invalid exit status")] ExitStatus, } +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[error("Invalid arithmetic sequence")] +pub struct Arithmetic { + pub parts: Vec, +} +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[error("Invalid arithmetic part")] +pub enum ArithmeticPart { + #[error("Invalid parentheses expression")] + ParenthesesExpr(Box), + #[error("Invalid variable assignment")] + VariableAssignment { + name: String, + op: AssignmentOp, + value: Box, + }, + #[error("Invalid triple conditional expression")] + TripleConditionalExpr { + condition: Box, + true_expr: Box, + false_expr: Box, + }, + #[error("Invalid binary arithmetic expression")] + BinaryArithmeticExpr { + left: Box, + operator: BinaryArithmeticOp, + right: Box, + }, + #[error("Invalid binary conditional expression")] + BinaryConditionalExpr { + left: Box, + operator: BinaryOp, + right: Box, + }, + #[error("Invalid unary arithmetic expression")] + UnaryArithmeticExpr { + operator: UnaryArithmeticOp, + operand: Box, + }, + #[error("Invalid post arithmetic expression")] + PostArithmeticExpr { + operand: Box, + operator: PostArithmeticOp, + }, + #[error("Invalid variable")] + Variable(String), + #[error("Invalid number")] + Number(String), +} + +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, Copy, Ord)] +pub enum BinaryArithmeticOp { + Add, // + + Subtract, // - + Multiply, // * + Divide, // / + Modulo, // % + Power, // ** + LeftShift, // << + RightShift, // >> + BitwiseAnd, // & + BitwiseXor, // ^ + BitwiseOr, // | + LogicalAnd, // && + LogicalOr, // || +} + +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum AssignmentOp { + Assign, // = + MultiplyAssign, // *= + DivideAssign, // /= + ModuloAssign, // %= + AddAssign, // += + SubtractAssign, // -= + LeftShiftAssign, // <<= + RightShiftAssign, // >>= + BitwiseAndAssign, // &= + BitwiseXorAssign, // ^= + BitwiseOrAssign, // |= +} + +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum UnaryArithmeticOp { + Plus, // + + Minus, // - + LogicalNot, // ! + BitwiseNot, // ~ +} + +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PostArithmeticOp { + Increment, // ++ + Decrement, // -- +} + #[cfg_attr(feature = "serialization", derive(serde::Serialize))] #[cfg_attr( feature = "serialization", @@ -449,6 +563,41 @@ pub enum RedirectOpOutput { Append, } +lazy_static! { + static ref ARITHMETIC_PARSER: PrattParser = { + use Assoc::*; + use Rule::*; + + PrattParser::new() + .op( + Op::infix(assign, Right) + | Op::infix(multiply_assign, Right) + | Op::infix(divide_assign, Right) + | Op::infix(modulo_assign, Right) + | Op::infix(add_assign, Right) + | Op::infix(subtract_assign, Right) + | Op::infix(left_shift_assign, Right) + | Op::infix(right_shift_assign, Right) + | Op::infix(bitwise_and_assign, Right) + | Op::infix(bitwise_xor_assign, Right) + | Op::infix(bitwise_or_assign, Right), + ) + .op(Op::infix(logical_or, Left)) + .op(Op::infix(logical_and, Left)) + .op(Op::infix(bitwise_or, Left)) + .op(Op::infix(bitwise_xor, Left)) + .op(Op::infix(bitwise_and, Left)) + .op(Op::infix(left_shift, Left) | Op::infix(right_shift, Left)) + .op(Op::infix(add, Left) | Op::infix(subtract, Left)) + .op( + Op::infix(multiply, Left) + | Op::infix(divide, Left) + | Op::infix(modulo, Left), + ) + .op(Op::infix(power, Right)) + }; +} + #[derive(Parser)] #[grammar = "grammar.pest"] struct ShellParser; @@ -792,6 +941,13 @@ fn parse_compound_command(pair: Pair) -> Result { Rule::until_clause => { Err(miette!("Unsupported compound command until_clause")) } + Rule::ARITHMETIC_EXPRESSION => { + let arithmetic_expression = parse_arithmetic_expression(inner)?; + Ok(Command { + inner: CommandInner::ArithmeticExpression(arithmetic_expression), + redirect: None, + }) + } _ => Err(miette!( "Unexpected rule in compound_command: {:?}", inner.as_rule() @@ -985,7 +1141,7 @@ fn parse_binary_conditional_expression(pair: Pair) -> Result { let right_word = parse_word(right)?; let op = match operator.as_rule() { - Rule::binary_string_conditional_op => match operator.as_str() { + Rule::binary_bash_conditional_op => match operator.as_str() { "==" => BinaryOp::Equal, "=" => BinaryOp::Equal, "!=" => BinaryOp::NotEqual, @@ -998,7 +1154,7 @@ fn parse_binary_conditional_expression(pair: Pair) -> Result { )) } }, - Rule::binary_arithmetic_conditional_op => match operator.as_str() { + Rule::binary_posix_conditional_op => match operator.as_str() { "-eq" => BinaryOp::Equal, "-ne" => BinaryOp::NotEqual, "-lt" => BinaryOp::LessThan, @@ -1085,6 +1241,10 @@ fn parse_word(pair: Pair) -> Result { let tilde_prefix = parse_tilde_prefix(part)?; parts.push(tilde_prefix); } + Rule::ARITHMETIC_EXPRESSION => { + let arithmetic_expression = parse_arithmetic_expression(part)?; + parts.push(WordPart::Arithmetic(arithmetic_expression)); + } _ => { return Err(miette!( "Unexpected rule in UNQUOTED_PENDING_WORD: {:?}", @@ -1130,6 +1290,10 @@ fn parse_word(pair: Pair) -> Result { let tilde_prefix = parse_tilde_prefix(part)?; parts.push(tilde_prefix); } + Rule::ARITHMETIC_EXPRESSION => { + let arithmetic_expression = parse_arithmetic_expression(part)?; + parts.push(WordPart::Arithmetic(arithmetic_expression)); + } _ => { return Err(miette!( "Unexpected rule in FILE_NAME_PENDING_WORD: {:?}", @@ -1151,6 +1315,164 @@ fn parse_word(pair: Pair) -> Result { } } +fn parse_arithmetic_expression(pair: Pair) -> Result { + assert!(pair.as_rule() == Rule::ARITHMETIC_EXPRESSION); + let inner = pair.into_inner().next().unwrap(); + let parts = parse_arithmetic_sequence(inner)?; + Ok(Arithmetic { parts }) +} + +fn parse_arithmetic_sequence(pair: Pair) -> Result> { + assert!(pair.as_rule() == Rule::arithmetic_sequence); + let mut parts = Vec::new(); + for expr in pair.into_inner() { + parts.push(parse_arithmetic_expr(expr)?); + } + Ok(parts) +} + +fn parse_arithmetic_expr(pair: Pair) -> Result { + ARITHMETIC_PARSER + .map_primary(|primary| match primary.as_rule() { + Rule::parentheses_expr => { + let inner = primary.into_inner().next().unwrap(); + let parts = parse_arithmetic_sequence(inner)?; + Ok(ArithmeticPart::ParenthesesExpr(Box::new(Arithmetic { + parts, + }))) + } + Rule::variable_assignment => { + let mut inner = primary.into_inner(); + let name = inner.next().unwrap().as_str().to_string(); + let op = inner.next().unwrap(); + + let value = parse_arithmetic_expr(inner.next().unwrap())?; + Ok(ArithmeticPart::VariableAssignment { + name, + op: match op.as_rule() { + Rule::assign => AssignmentOp::Assign, + Rule::multiply_assign => AssignmentOp::MultiplyAssign, + Rule::divide_assign => AssignmentOp::DivideAssign, + Rule::modulo_assign => AssignmentOp::ModuloAssign, + Rule::add_assign => AssignmentOp::AddAssign, + Rule::subtract_assign => AssignmentOp::SubtractAssign, + Rule::left_shift_assign => AssignmentOp::LeftShiftAssign, + Rule::right_shift_assign => AssignmentOp::RightShiftAssign, + _ => { + return Err(miette!( + "Unexpected assignment operator: {:?}", + op.as_rule() + )); + } + }, + value: Box::new(value), + }) + } + Rule::triple_conditional_expr => { + let mut inner = primary.into_inner(); + let condition = parse_arithmetic_expr(inner.next().unwrap())?; + let true_expr = parse_arithmetic_expr(inner.next().unwrap())?; + let false_expr = parse_arithmetic_expr(inner.next().unwrap())?; + Ok(ArithmeticPart::TripleConditionalExpr { + condition: Box::new(condition), + true_expr: Box::new(true_expr), + false_expr: Box::new(false_expr), + }) + } + Rule::unary_arithmetic_expr => parse_unary_arithmetic_expr(primary), + Rule::VARIABLE => { + Ok(ArithmeticPart::Variable(primary.as_str().to_string())) + } + Rule::NUMBER => Ok(ArithmeticPart::Number(primary.as_str().to_string())), + _ => Err(miette!( + "Unexpected rule in arithmetic expression: {:?}", + primary.as_rule() + )), + }) + .map_infix(|lhs, op, rhs| { + let operator = match op.as_rule() { + Rule::add => BinaryArithmeticOp::Add, + Rule::subtract => BinaryArithmeticOp::Subtract, + Rule::multiply => BinaryArithmeticOp::Multiply, + Rule::divide => BinaryArithmeticOp::Divide, + Rule::modulo => BinaryArithmeticOp::Modulo, + Rule::power => BinaryArithmeticOp::Power, + Rule::left_shift => BinaryArithmeticOp::LeftShift, + Rule::right_shift => BinaryArithmeticOp::RightShift, + Rule::bitwise_and => BinaryArithmeticOp::BitwiseAnd, + Rule::bitwise_xor => BinaryArithmeticOp::BitwiseXor, + Rule::bitwise_or => BinaryArithmeticOp::BitwiseOr, + Rule::logical_and => BinaryArithmeticOp::LogicalAnd, + Rule::logical_or => BinaryArithmeticOp::LogicalOr, + _ => { + return Err(miette!("Unexpected infix operator: {:?}", op.as_rule())) + } + }; + Ok(ArithmeticPart::BinaryArithmeticExpr { + left: Box::new(lhs?), + operator, + right: Box::new(rhs?), + }) + }) + .parse(pair.into_inner()) +} + +fn parse_unary_arithmetic_expr(pair: Pair) -> Result { + let mut inner = pair.into_inner(); + let first = inner.next().unwrap(); + + match first.as_rule() { + Rule::unary_arithmetic_op => { + let op = parse_unary_arithmetic_op(first)?; + let operand = parse_arithmetic_expr(inner.next().unwrap())?; + Ok(ArithmeticPart::UnaryArithmeticExpr { + operator: op, + operand: Box::new(operand), + }) + } + Rule::post_arithmetic_op => { + let operand = parse_arithmetic_expr(inner.next().unwrap())?; + let op = parse_post_arithmetic_op(first)?; + Ok(ArithmeticPart::PostArithmeticExpr { + operand: Box::new(operand), + operator: op, + }) + } + _ => { + let operand = parse_arithmetic_expr(first)?; + let op = parse_post_arithmetic_op(inner.next().unwrap())?; + Ok(ArithmeticPart::PostArithmeticExpr { + operand: Box::new(operand), + operator: op, + }) + } + } +} + +fn parse_unary_arithmetic_op(pair: Pair) -> Result { + match pair.as_str() { + "+" => Ok(UnaryArithmeticOp::Plus), + "-" => Ok(UnaryArithmeticOp::Minus), + "!" => Ok(UnaryArithmeticOp::LogicalNot), + "~" => Ok(UnaryArithmeticOp::BitwiseNot), + _ => Err(miette!( + "Invalid unary arithmetic operator: {}", + pair.as_str() + )), + } +} + +fn parse_post_arithmetic_op(pair: Pair) -> Result { + match pair.as_str() { + "++" => Ok(PostArithmeticOp::Increment), + "--" => Ok(PostArithmeticOp::Decrement), + _ => Err(miette!( + "Invalid post arithmetic operator: {}", + pair.as_str() + )), + } +} + fn parse_tilde_prefix(pair: Pair) -> Result { let tilde_prefix_str = pair.as_str(); let user = if tilde_prefix_str.len() > 1 { @@ -1384,7 +1706,6 @@ mod test { .map_err(|e| miette!(e.to_string()))? .next() .unwrap(); - // println!("pairs: {:?}", pairs); parse_complete_command(pairs) }; diff --git a/crates/deno_task_shell/src/shell/command.rs b/crates/deno_task_shell/src/shell/command.rs index 462852d..b2a319c 100644 --- a/crates/deno_task_shell/src/shell/command.rs +++ b/crates/deno_task_shell/src/shell/command.rs @@ -194,19 +194,23 @@ async fn parse_shebang_args( crate::parser::CommandInner::Simple(cmd) => cmd, crate::parser::CommandInner::Subshell(_) => return err_unsupported(text), crate::parser::CommandInner::If(_) => return err_unsupported(text), + crate::parser::CommandInner::ArithmeticExpression(_) => { + return err_unsupported(text) + } }; if !cmd.env_vars.is_empty() { return err_unsupported(text); } - super::execute::evaluate_args( + let result = super::execute::evaluate_args( cmd.args, &context.state, context.stdin.clone(), context.stderr.clone(), ) .await - .map_err(|e| miette!(e.to_string())) + .map_err(|e| miette!(e.to_string()))?; + Ok(result.value) } /// Errors for executable commands. diff --git a/crates/deno_task_shell/src/shell/execute.rs b/crates/deno_task_shell/src/shell/execute.rs index 13de4a4..f2accae 100644 --- a/crates/deno_task_shell/src/shell/execute.rs +++ b/crates/deno_task_shell/src/shell/execute.rs @@ -5,6 +5,7 @@ use std::path::Path; use std::rc::Rc; use anyhow::Context; +use anyhow::Error; use futures::future; use futures::future::LocalBoxFuture; use futures::FutureExt; @@ -12,6 +13,7 @@ use thiserror::Error; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; +use crate::parser::AssignmentOp; use crate::parser::BinaryOp; use crate::parser::Condition; use crate::parser::ConditionInner; @@ -23,6 +25,8 @@ use crate::parser::UnaryOp; use crate::shell::commands::ShellCommand; use crate::shell::commands::ShellCommandContext; use crate::shell::types::pipe; +use crate::shell::types::ArithmeticResult; +use crate::shell::types::ArithmeticValue; use crate::shell::types::EnvChange; use crate::shell::types::ExecuteResult; use crate::shell::types::FutureExecuteResult; @@ -30,6 +34,9 @@ use crate::shell::types::ShellPipeReader; use crate::shell::types::ShellPipeWriter; use crate::shell::types::ShellState; +use crate::parser::Arithmetic; +use crate::parser::ArithmeticPart; +use crate::parser::BinaryArithmeticOp; use crate::parser::Command; use crate::parser::CommandInner; use crate::parser::IfClause; @@ -43,10 +50,10 @@ use crate::parser::RedirectOp; use crate::parser::Sequence; use crate::parser::SequentialList; use crate::parser::SimpleCommand; +use crate::parser::UnaryArithmeticOp; use crate::parser::Word; use crate::parser::WordPart; -// use crate::parser::ElsePart; -// use crate::parser::ElifClause; +use crate::shell::types::WordEvalResult; use super::command::execute_unresolved_command_name; use super::command::UnresolvedCommandName; @@ -428,21 +435,21 @@ async fn resolve_redirect_word_pipe( } }; // edge case that's not supported - if words.is_empty() { + if words.value.is_empty() { let _ = stderr.write_line("redirect path must be 1 argument, but found 0"); return Err(ExecuteResult::from_exit_code(1)); - } else if words.len() > 1 { + } else if words.value.len() > 1 { let _ = stderr.write_line(&format!( concat!( "redirect path must be 1 argument, but found {0} ({1}). ", "Did you mean to quote it (ex. \"{1}\")?" ), - words.len(), + words.value.len(), words.join(" ") )); return Err(ExecuteResult::from_exit_code(1)); } - let output_path = &words[0]; + let output_path = &words.value[0]; match &redirect_op { RedirectOp::Input(RedirectOpInput::Redirect) => { @@ -480,7 +487,7 @@ async fn execute_command( stdout: ShellPipeWriter, mut stderr: ShellPipeWriter, ) -> ExecuteResult { - let (stdin, stdout, stderr) = if let Some(redirect) = &command.redirect { + let (stdin, stdout, mut stderr) = if let Some(redirect) = &command.redirect { let pipe = match resolve_redirect_pipe( redirect, &state, @@ -528,6 +535,212 @@ async fn execute_command( CommandInner::If(if_clause) => { execute_if_clause(if_clause, state, stdin, stdout, stderr).await } + CommandInner::ArithmeticExpression(arithmetic) => { + match execute_arithmetic_expression(arithmetic, state).await { + Ok(result) => ExecuteResult::Continue(0, result.changes, Vec::new()), + Err(e) => { + let _ = stderr.write_line(&e.to_string()); + ExecuteResult::Continue(2, Vec::new(), Vec::new()) + } + } + } + } +} + +async fn execute_arithmetic_expression( + arithmetic: Arithmetic, + mut state: ShellState, +) -> Result { + evaluate_arithmetic(&arithmetic, &mut state).await +} + +async fn evaluate_arithmetic( + arithmetic: &Arithmetic, + state: &mut ShellState, +) -> Result { + let mut result = ArithmeticResult::new(ArithmeticValue::Integer(0)); + for part in &arithmetic.parts { + result = Box::pin(evaluate_arithmetic_part(part, state)).await?; + } + Ok(result) +} + +async fn evaluate_arithmetic_part( + part: &ArithmeticPart, + state: &mut ShellState, +) -> Result { + match part { + ArithmeticPart::ParenthesesExpr(expr) => { + Box::pin(evaluate_arithmetic(expr, state)).await + } + ArithmeticPart::VariableAssignment { name, op, value } => { + let val = Box::pin(evaluate_arithmetic_part(value, state)).await?; + let applied_value = match op { + AssignmentOp::Assign => val.clone(), + _ => { + let var = state + .get_var(name) + .ok_or_else(|| anyhow::anyhow!("Undefined variable: {}", name))?; + let parsed_var = var.parse::().map_err(|e| { + anyhow::anyhow!("Failed to parse variable '{}': {}", name, e) + })?; + match op { + AssignmentOp::MultiplyAssign => val.checked_mul(&parsed_var), + AssignmentOp::DivideAssign => val.checked_div(&parsed_var), + AssignmentOp::ModuloAssign => val.checked_rem(&parsed_var), + AssignmentOp::AddAssign => val.checked_add(&parsed_var), + AssignmentOp::SubtractAssign => val.checked_sub(&parsed_var), + AssignmentOp::LeftShiftAssign => val.checked_shl(&parsed_var), + AssignmentOp::RightShiftAssign => val.checked_shr(&parsed_var), + AssignmentOp::BitwiseAndAssign => val.checked_and(&parsed_var), + AssignmentOp::BitwiseXorAssign => val.checked_xor(&parsed_var), + AssignmentOp::BitwiseOrAssign => val.checked_or(&parsed_var), + _ => unreachable!(), + }? + } + }; + state.apply_env_var(name, &applied_value.to_string()); + Ok( + applied_value + .clone() + .with_changes(vec![EnvChange::SetShellVar( + name.clone(), + applied_value.to_string(), + )]), + ) + } + ArithmeticPart::TripleConditionalExpr { + condition, + true_expr, + false_expr, + } => { + let cond = Box::pin(evaluate_arithmetic_part(condition, state)).await?; + if cond.is_zero() { + Box::pin(evaluate_arithmetic_part(true_expr, state)).await + } else { + Box::pin(evaluate_arithmetic_part(false_expr, state)).await + } + } + ArithmeticPart::BinaryArithmeticExpr { + left, + operator, + right, + } => { + let lhs = Box::pin(evaluate_arithmetic_part(left, state)).await?; + let rhs = Box::pin(evaluate_arithmetic_part(right, state)).await?; + apply_binary_op(lhs, *operator, rhs) + } + ArithmeticPart::BinaryConditionalExpr { + left, + operator, + right, + } => { + let lhs = Box::pin(evaluate_arithmetic_part(left, state)).await?; + let rhs = Box::pin(evaluate_arithmetic_part(right, state)).await?; + apply_conditional_binary_op(lhs, operator, rhs) + } + ArithmeticPart::UnaryArithmeticExpr { operator, operand } => { + let val = Box::pin(evaluate_arithmetic_part(operand, state)).await?; + apply_unary_op(*operator, val) + } + ArithmeticPart::PostArithmeticExpr { operand, .. } => { + let val = Box::pin(evaluate_arithmetic_part(operand, state)).await?; + Ok(val) + } + ArithmeticPart::Variable(name) => state + .get_var(name) + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| { + anyhow::anyhow!("Undefined or non-integer variable: {}", name) + }), + ArithmeticPart::Number(num_str) => num_str + .parse::() + .map_err(|e| anyhow::anyhow!(e.to_string())), + } +} + +fn apply_binary_op( + lhs: ArithmeticResult, + op: BinaryArithmeticOp, + rhs: ArithmeticResult, +) -> Result { + match op { + BinaryArithmeticOp::Add => lhs.checked_add(&rhs), + BinaryArithmeticOp::Subtract => lhs.checked_sub(&rhs), + BinaryArithmeticOp::Multiply => lhs.checked_mul(&rhs), + BinaryArithmeticOp::Divide => lhs.checked_div(&rhs), + BinaryArithmeticOp::Modulo => lhs.checked_rem(&rhs), + BinaryArithmeticOp::Power => lhs.checked_pow(&rhs), + BinaryArithmeticOp::LeftShift => lhs.checked_shl(&rhs), + BinaryArithmeticOp::RightShift => lhs.checked_shr(&rhs), + BinaryArithmeticOp::BitwiseAnd => lhs.checked_and(&rhs), + BinaryArithmeticOp::BitwiseXor => lhs.checked_xor(&rhs), + BinaryArithmeticOp::BitwiseOr => lhs.checked_or(&rhs), + BinaryArithmeticOp::LogicalAnd => Ok(if lhs.is_zero() && rhs.is_zero() { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + }), + BinaryArithmeticOp::LogicalOr => Ok(if !lhs.is_zero() || !rhs.is_zero() { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + }), + } +} + +fn apply_conditional_binary_op( + lhs: ArithmeticResult, + op: &BinaryOp, + rhs: ArithmeticResult, +) -> Result { + match op { + BinaryOp::Equal => Ok(if lhs == rhs { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + }), + BinaryOp::NotEqual => Ok(if lhs != rhs { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + }), + BinaryOp::LessThan => Ok(if lhs < rhs { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + }), + BinaryOp::LessThanOrEqual => Ok(if lhs <= rhs { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + }), + BinaryOp::GreaterThan => Ok(if lhs > rhs { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + }), + BinaryOp::GreaterThanOrEqual => Ok(if lhs >= rhs { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + }), + } +} + +fn apply_unary_op( + op: UnaryArithmeticOp, + val: ArithmeticResult, +) -> Result { + match op { + UnaryArithmeticOp::Plus => Ok(val), + UnaryArithmeticOp::Minus => val.checked_neg(), + UnaryArithmeticOp::LogicalNot => Ok(if val.is_zero() { + ArithmeticResult::new(ArithmeticValue::Integer(1)) + } else { + ArithmeticResult::new(ArithmeticValue::Integer(0)) + }), + UnaryArithmeticOp::BitwiseNot => val.checked_not(), } } @@ -751,8 +964,8 @@ async fn execute_simple_command( ) -> ExecuteResult { let args = evaluate_args(command.args, &state, stdin.clone(), stderr.clone()).await; - let args = match args { - Ok(args) => args, + let (args, changes) = match args { + Ok(args) => (args.value, args.changes), Err(err) => { return err.into_exit_code(&mut stderr); } @@ -769,7 +982,15 @@ async fn execute_simple_command( }; state.apply_env_var(&env_var.name, &value); } - execute_command_args(args, state, stdin, stdout, stderr).await + let result = execute_command_args(args, state, stdin, stdout, stderr).await; + match result { + ExecuteResult::Exit(code, handles) => ExecuteResult::Exit(code, handles), + ExecuteResult::Continue(code, env_changes, handles) => { + let mut combined_changes = env_changes.clone(); + combined_changes.extend(changes); + ExecuteResult::Continue(code, combined_changes, handles) + } + } } fn execute_command_args( @@ -843,8 +1064,8 @@ pub async fn evaluate_args( state: &ShellState, stdin: ShellPipeReader, stderr: ShellPipeWriter, -) -> Result, EvaluateWordTextError> { - let mut result = Vec::new(); +) -> Result { + let mut result = WordEvalResult::new(Vec::new(), Vec::new()); for arg in args { let parts = evaluate_word_parts( arg.into_parts(), @@ -902,7 +1123,7 @@ fn evaluate_word_parts( state: &ShellState, stdin: ShellPipeReader, stderr: ShellPipeWriter, -) -> LocalBoxFuture, EvaluateWordTextError>> { +) -> LocalBoxFuture> { #[derive(Debug)] enum TextPart { Quoted(String), @@ -931,7 +1152,7 @@ fn evaluate_word_parts( state: &ShellState, text_parts: Vec, is_quoted: bool, - ) -> Result, EvaluateWordTextError> { + ) -> Result { if !is_quoted && text_parts .iter() @@ -1001,13 +1222,16 @@ fn evaluate_word_parts( }) .collect::>() }; - Ok(paths) + Ok(WordEvalResult::new(paths, Vec::new())) } } Err(err) => Err(EvaluateWordTextError::InvalidPattern { pattern, err }), } } else { - Ok(vec![text_parts_to_string(text_parts)]) + Ok(WordEvalResult { + value: vec![text_parts_to_string(text_parts)], + changes: Vec::new(), + }) } } @@ -1017,10 +1241,12 @@ fn evaluate_word_parts( state: &ShellState, stdin: ShellPipeReader, stderr: ShellPipeWriter, - ) -> LocalBoxFuture, EvaluateWordTextError>> { + ) -> LocalBoxFuture> { // recursive async, so requires boxing + let mut changes: Vec = Vec::new(); + async move { - let mut result = Vec::new(); + let mut result = WordEvalResult::new(Vec::new(), Vec::new()); let mut current_text = Vec::new(); for part in parts { let evaluation_result_text = match part { @@ -1067,6 +1293,13 @@ fn evaluate_word_parts( } continue; } + WordPart::Arithmetic(arithmetic) => { + let arithmetic_result = + execute_arithmetic_expression(arithmetic, state.clone()).await?; + current_text.push(TextPart::Text(arithmetic_result.to_string())); + changes.extend(arithmetic_result.changes); + continue; + } WordPart::ExitStatus => { let exit_code = state.last_command_exit_code(); current_text.push(TextPart::Text(exit_code.to_string())); diff --git a/crates/deno_task_shell/src/shell/types.rs b/crates/deno_task_shell/src/shell/types.rs index 5189e15..bf079d9 100644 --- a/crates/deno_task_shell/src/shell/types.rs +++ b/crates/deno_task_shell/src/shell/types.rs @@ -2,13 +2,17 @@ use std::borrow::Cow; use std::collections::HashMap; +use std::fmt; +use std::fmt::Display; use std::fs; use std::io::Read; use std::io::Write; use std::path::Path; use std::path::PathBuf; use std::rc::Rc; +use std::str::FromStr; +use anyhow::Error; use anyhow::Result; use futures::future::LocalBoxFuture; use tokio::task::JoinHandle; @@ -280,7 +284,7 @@ impl ShellState { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone, PartialOrd)] pub enum EnvChange { /// `export ENV_VAR=VALUE` SetEnvVar(String, String), @@ -527,3 +531,585 @@ pub fn pipe() -> (ShellPipeReader, ShellPipeWriter) { ShellPipeWriter::OsPipe(writer), ) } + +#[derive(Debug, Clone, PartialEq, PartialOrd, thiserror::Error)] +pub struct ArithmeticResult { + pub value: ArithmeticValue, + pub changes: Vec, +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, thiserror::Error)] +pub enum ArithmeticValue { + Float(f64), + Integer(i64), +} + +impl Display for ArithmeticResult { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +impl Display for ArithmeticValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ArithmeticValue::Float(val) => write!(f, "{}", val), + ArithmeticValue::Integer(val) => write!(f, "{}", val), + } + } +} + +impl ArithmeticResult { + pub fn new(value: ArithmeticValue) -> Self { + ArithmeticResult { + value, + changes: Vec::new(), + } + } + + pub fn is_zero(&self) -> bool { + match &self.value { + ArithmeticValue::Integer(val) => *val == 0, + ArithmeticValue::Float(val) => *val == 0.0, + } + } + + pub fn checked_add( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => lhs + .checked_add(*rhs) + .map(ArithmeticValue::Integer) + .ok_or_else(|| { + anyhow::anyhow!("Integer overflow: {} + {}", lhs, rhs) + })?, + (ArithmeticValue::Float(lhs), ArithmeticValue::Float(rhs)) => { + let sum = lhs + rhs; + if sum.is_finite() { + ArithmeticValue::Float(sum) + } else { + return Err(anyhow::anyhow!("Float overflow: {} + {}", lhs, rhs)); + } + } + (ArithmeticValue::Integer(lhs), ArithmeticValue::Float(rhs)) + | (ArithmeticValue::Float(rhs), ArithmeticValue::Integer(lhs)) => { + let sum = *lhs as f64 + rhs; + if sum.is_finite() { + ArithmeticValue::Float(sum) + } else { + return Err(anyhow::anyhow!("Float overflow: {} + {}", lhs, rhs)); + } + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_sub( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => lhs + .checked_sub(*rhs) + .map(ArithmeticValue::Integer) + .ok_or_else(|| { + anyhow::anyhow!("Integer overflow: {} - {}", lhs, rhs) + })?, + (ArithmeticValue::Float(lhs), ArithmeticValue::Float(rhs)) => { + let diff = lhs - rhs; + if diff.is_finite() { + ArithmeticValue::Float(diff) + } else { + return Err(anyhow::anyhow!("Float overflow: {} - {}", lhs, rhs)); + } + } + (ArithmeticValue::Integer(lhs), ArithmeticValue::Float(rhs)) => { + let diff = *lhs as f64 - rhs; + if diff.is_finite() { + ArithmeticValue::Float(diff) + } else { + return Err(anyhow::anyhow!("Float overflow: {} - {}", lhs, rhs)); + } + } + (ArithmeticValue::Float(lhs), ArithmeticValue::Integer(rhs)) => { + let diff = lhs - *rhs as f64; + if diff.is_finite() { + ArithmeticValue::Float(diff) + } else { + return Err(anyhow::anyhow!("Float overflow: {} - {}", lhs, rhs)); + } + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_mul( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => lhs + .checked_mul(*rhs) + .map(ArithmeticValue::Integer) + .ok_or_else(|| { + anyhow::anyhow!("Integer overflow: {} * {}", lhs, rhs) + })?, + (ArithmeticValue::Float(lhs), ArithmeticValue::Float(rhs)) => { + let product = lhs * rhs; + if product.is_finite() { + ArithmeticValue::Float(product) + } else { + return Err(anyhow::anyhow!("Float overflow: {} * {}", lhs, rhs)); + } + } + (ArithmeticValue::Integer(lhs), ArithmeticValue::Float(rhs)) + | (ArithmeticValue::Float(rhs), ArithmeticValue::Integer(lhs)) => { + let product = *lhs as f64 * rhs; + if product.is_finite() { + ArithmeticValue::Float(product) + } else { + return Err(anyhow::anyhow!("Float overflow: {} * {}", lhs, rhs)); + } + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_div( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => { + if *rhs == 0 { + return Err(anyhow::anyhow!("Division by zero: {} / {}", lhs, rhs)); + } + lhs + .checked_div(*rhs) + .map(ArithmeticValue::Integer) + .ok_or_else(|| { + anyhow::anyhow!("Integer overflow: {} / {}", lhs, rhs) + })? + } + (ArithmeticValue::Float(lhs), ArithmeticValue::Float(rhs)) => { + if *rhs == 0.0 { + return Err(anyhow::anyhow!("Division by zero: {} / {}", lhs, rhs)); + } + let quotient = lhs / rhs; + if quotient.is_finite() { + ArithmeticValue::Float(quotient) + } else { + return Err(anyhow::anyhow!("Float overflow: {} / {}", lhs, rhs)); + } + } + (ArithmeticValue::Integer(lhs), ArithmeticValue::Float(rhs)) => { + if *rhs == 0.0 { + return Err(anyhow::anyhow!("Division by zero: {} / {}", lhs, rhs)); + } + let quotient = *lhs as f64 / rhs; + if quotient.is_finite() { + ArithmeticValue::Float(quotient) + } else { + return Err(anyhow::anyhow!("Float overflow: {} / {}", lhs, rhs)); + } + } + (ArithmeticValue::Float(lhs), ArithmeticValue::Integer(rhs)) => { + if *rhs == 0 { + return Err(anyhow::anyhow!("Division by zero: {} / {}", lhs, rhs)); + } + let quotient = lhs / *rhs as f64; + if quotient.is_finite() { + ArithmeticValue::Float(quotient) + } else { + return Err(anyhow::anyhow!("Float overflow: {} / {}", lhs, rhs)); + } + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_rem( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => { + if *rhs == 0 { + return Err(anyhow::anyhow!("Modulo by zero: {} % {}", lhs, rhs)); + } + lhs + .checked_rem(*rhs) + .map(ArithmeticValue::Integer) + .ok_or_else(|| { + anyhow::anyhow!("Integer overflow: {} % {}", lhs, rhs) + })? + } + (ArithmeticValue::Float(lhs), ArithmeticValue::Float(rhs)) => { + if *rhs == 0.0 { + return Err(anyhow::anyhow!("Modulo by zero: {} % {}", lhs, rhs)); + } + let remainder = lhs % rhs; + if remainder.is_finite() { + ArithmeticValue::Float(remainder) + } else { + return Err(anyhow::anyhow!("Float overflow: {} % {}", lhs, rhs)); + } + } + (ArithmeticValue::Integer(lhs), ArithmeticValue::Float(rhs)) => { + if *rhs == 0.0 { + return Err(anyhow::anyhow!("Modulo by zero: {} % {}", lhs, rhs)); + } + let remainder = *lhs as f64 % rhs; + if remainder.is_finite() { + ArithmeticValue::Float(remainder) + } else { + return Err(anyhow::anyhow!("Float overflow: {} % {}", lhs, rhs)); + } + } + (ArithmeticValue::Float(lhs), ArithmeticValue::Integer(rhs)) => { + if *rhs == 0 { + return Err(anyhow::anyhow!("Modulo by zero: {} % {}", lhs, rhs)); + } + let remainder = lhs % *rhs as f64; + if remainder.is_finite() { + ArithmeticValue::Float(remainder) + } else { + return Err(anyhow::anyhow!("Float overflow: {} % {}", lhs, rhs)); + } + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_pow( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => { + if *rhs < 0 { + let result = (*lhs as f64).powf(*rhs as f64); + if result.is_finite() { + ArithmeticValue::Float(result) + } else { + return Err(anyhow::anyhow!("Float overflow: {} ** {}", lhs, rhs)); + } + } else { + lhs + .checked_pow(*rhs as u32) + .map(ArithmeticValue::Integer) + .ok_or_else(|| { + anyhow::anyhow!("Integer overflow: {} ** {}", lhs, rhs) + })? + } + } + (ArithmeticValue::Float(lhs), ArithmeticValue::Float(rhs)) => { + let result = lhs.powf(*rhs); + if result.is_finite() { + ArithmeticValue::Float(result) + } else { + return Err(anyhow::anyhow!("Float overflow: {} ** {}", lhs, rhs)); + } + } + (ArithmeticValue::Integer(lhs), ArithmeticValue::Float(rhs)) => { + let result = (*lhs as f64).powf(*rhs); + if result.is_finite() { + ArithmeticValue::Float(result) + } else { + return Err(anyhow::anyhow!("Float overflow: {} ** {}", lhs, rhs)); + } + } + (ArithmeticValue::Float(lhs), ArithmeticValue::Integer(rhs)) => { + let result = lhs.powf(*rhs as f64); + if result.is_finite() { + ArithmeticValue::Float(result) + } else { + return Err(anyhow::anyhow!("Float overflow: {} ** {}", lhs, rhs)); + } + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_neg(&self) -> Result { + let result = match &self.value { + ArithmeticValue::Integer(val) => val + .checked_neg() + .map(ArithmeticValue::Integer) + .ok_or_else(|| anyhow::anyhow!("Integer overflow: -{}", val))?, + ArithmeticValue::Float(val) => { + let result = -val; + if result.is_finite() { + ArithmeticValue::Float(result) + } else { + return Err(anyhow::anyhow!("Float overflow: -{}", val)); + } + } + }; + + Ok(ArithmeticResult { + value: result, + changes: self.changes.clone(), + }) + } + + pub fn checked_not(&self) -> Result { + let result = match &self.value { + ArithmeticValue::Integer(val) => ArithmeticValue::Integer(!val), + ArithmeticValue::Float(_) => { + return Err(anyhow::anyhow!( + "Invalid arithmetic result type for bitwise NOT: {}", + self + )) + } + }; + + Ok(ArithmeticResult { + value: result, + changes: self.changes.clone(), + }) + } + + pub fn checked_shl( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => { + if *rhs < 0 { + return Err(anyhow::anyhow!( + "Negative shift amount: {} << {}", + lhs, + rhs + )); + } + lhs + .checked_shl(*rhs as u32) + .map(ArithmeticValue::Integer) + .ok_or_else(|| { + anyhow::anyhow!("Integer overflow: {} << {}", lhs, rhs) + })? + } + _ => { + return Err(anyhow::anyhow!( + "Invalid arithmetic result types for left shift: {} << {}", + self, + other + )) + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_shr( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => { + if *rhs < 0 { + return Err(anyhow::anyhow!( + "Negative shift amount: {} >> {}", + lhs, + rhs + )); + } + lhs + .checked_shr(*rhs as u32) + .map(ArithmeticValue::Integer) + .ok_or_else(|| { + anyhow::anyhow!("Integer underflow: {} >> {}", lhs, rhs) + })? + } + _ => { + return Err(anyhow::anyhow!( + "Invalid arithmetic result types for right shift: {} >> {}", + self, + other + )) + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_and( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => { + ArithmeticValue::Integer(lhs & rhs) + } + _ => { + return Err(anyhow::anyhow!( + "Invalid arithmetic result types for bitwise AND: {} & {}", + self, + other + )) + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_or( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => { + ArithmeticValue::Integer(lhs | rhs) + } + _ => { + return Err(anyhow::anyhow!( + "Invalid arithmetic result types for bitwise OR: {} | {}", + self, + other + )) + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn checked_xor( + &self, + other: &ArithmeticResult, + ) -> Result { + let result = match (&self.value, &other.value) { + (ArithmeticValue::Integer(lhs), ArithmeticValue::Integer(rhs)) => { + ArithmeticValue::Integer(lhs ^ rhs) + } + _ => { + return Err(anyhow::anyhow!( + "Invalid arithmetic result types for bitwise XOR: {} ^ {}", + self, + other + )) + } + }; + + let mut changes = self.changes.clone(); + changes.extend(other.changes.clone()); + + Ok(ArithmeticResult { + value: result, + changes, + }) + } + + pub fn with_changes(mut self, changes: Vec) -> Self { + self.changes = changes; + self + } +} + +impl From for ArithmeticResult { + fn from(value: String) -> Self { + if let Ok(int_val) = value.parse::() { + ArithmeticResult::new(ArithmeticValue::Integer(int_val)) + } else if let Ok(float_val) = value.parse::() { + ArithmeticResult::new(ArithmeticValue::Float(float_val)) + } else { + panic!("Invalid arithmetic result: {}", value); + } + } +} + +impl FromStr for ArithmeticResult { + type Err = String; + + fn from_str(s: &str) -> Result { + Ok(s.to_string().into()) + } +} + +pub struct WordEvalResult { + pub value: Vec, + pub changes: Vec, +} + +impl WordEvalResult { + pub fn new(value: Vec, changes: Vec) -> Self { + WordEvalResult { value, changes } + } + + pub fn extend(&mut self, other: WordEvalResult) { + self.value.extend(other.value); + self.changes.extend(other.changes); + } + + pub fn join(&self, sep: &str) -> String { + self.value.join(sep) + } +} diff --git a/crates/tests/src/lib.rs b/crates/tests/src/lib.rs index ec20a31..7eb8367 100644 --- a/crates/tests/src/lib.rs +++ b/crates/tests/src/lib.rs @@ -813,6 +813,63 @@ async fn which() { .await; } +#[tokio::test] +async fn arithmetic() { + TestBuilder::new() + .command("echo $((1 + 2 * 3 + (4 / 5)))") + .assert_stdout("7\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=1, b=2))") + .assert_stdout("2\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((a=1, b=2, a+b))") + .assert_stdout("3\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((1 + 2))") + .assert_stdout("3\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((5 * 4))") + .assert_stdout("20\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((10 / 3))") + .assert_stdout("3\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((2 ** 3))") + .assert_stdout("8\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((2 << 3))") + .assert_stdout("16\n") + .run() + .await; + + TestBuilder::new() + .command("echo $((2 << 3))") + .assert_stdout("16\n") + .run() + .await; +} + #[cfg(test)] fn no_such_file_error_text() -> &'static str { if cfg!(windows) { diff --git a/scripts/arithmetic.sh b/scripts/arithmetic.sh new file mode 100644 index 0000000..50c0752 --- /dev/null +++ b/scripts/arithmetic.sh @@ -0,0 +1 @@ +echo $((2 ** 3)) \ No newline at end of file