Skip to content

Commit

Permalink
Make CLI arguments accessible in zxpy programs (#52)
Browse files Browse the repository at this point in the history
* Fix sys.argv passed to tests

* Improve test assertion

* Add $ arg support

* Normalize /bin/sh as first argument of script_args

* Add type def
  • Loading branch information
tusharsadhwani committed Feb 21, 2023
1 parent 0d999ad commit 5b0db6b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 5 deletions.
8 changes: 6 additions & 2 deletions tests/test_files/argv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import sys

assert len(sys.argv) == 1
assert sys.argv[0].endswith("argv.py")
assert len(sys.argv) == 3
assert sys.argv[1] == "foobar"
assert sys.argv[2] == "baz"

out = ~"echo $1 and $2"
assert out == "foobar and baz\n"
6 changes: 6 additions & 0 deletions tests/test_files/injection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
x = ~"uname -p"
print(x in ("arm\n", "x86_64\n"))

command = "uname -p"
_, _, rc = ~f"{command}" # This doesn't work
print(rc)
60 changes: 59 additions & 1 deletion tests/zxpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def test_prints(capsys: pytest.CaptureFixture[str]) -> None:

def test_argv() -> None:
test_file = "./tests/test_files/argv.py"
subprocess.run(["zxpy", test_file])
returncode = subprocess.check_call(["zxpy", test_file, "--", "foobar", "baz"])
assert returncode == 0


def test_raise() -> None:
Expand Down Expand Up @@ -114,3 +115,60 @@ def f(n):
assert stderr == b'\n'
outlines = [line for line in stdout.decode().splitlines() if line.startswith('>>>')]
assert outlines == [">>> hi", ">>> 10", ">>> ... ... ... ... >>> 8", ">>> "]


@pytest.mark.parametrize(
("input", "index", "output"),
(
("echo 'hello world' hi", 0, False),
("echo 'hello world' hi", 4, False),
("echo 'hello world' hi", 5, True),
("echo 'hello world' hi", 6, True),
("echo 'hello world' hi", 16, True),
("echo 'hello world' hi", 17, True),
("echo 'hello world' hi", 18, False),
("echo 'hello world' hi", 21, False),
('abc "def\'ghi" jkl \'mnop\'', 5, False),
('abc "def\'ghi" jkl \'mnop\'', 8, False),
('abc "def\'ghi" jkl \'mnop\'', 10, False),
('abc "def\'ghi" jkl \'mnop\'', 14, False),
('abc "def\'ghi" jkl \'mnop\'', 17, False),
('abc "def\'ghi" jkl \'mnop\'', 18, True),
('abc "def\'ghi" jkl \'mnop\'', 21, True),
("'a' 'b' c 'de' 'fg' h", 1, True),
("'a' 'b' c 'de' 'fg' h", 3, False),
("'a' 'b' c 'de' 'fg' h", 6, True),
("'a' 'b' c 'de' 'fg' h", 10, False),
("'a' 'b' c 'de' 'fg' h", 14, True),
("'a' 'b' c 'de' 'fg' h", 16, False),
("'a' 'b' c 'de' 'fg' h", 19, True),
("'a' 'b' c 'de' 'fg' h", 22, False),
("a \"b'c'd'e\" '\"' '\"abc'", 1, False),
("a \"b'c'd'e\" '\"' '\"abc'", 2, False),
("a \"b'c'd'e\" '\"' '\"abc'", 4, False),
("a \"b'c'd'e\" '\"' '\"abc'", 6, False),
("a \"b'c'd'e\" '\"' '\"abc'", 8, False),
("a \"b'c'd'e\" '\"' '\"abc'", 10, False),
("a \"b'c'd'e\" '\"' '\"abc'", 12, True),
("a \"b'c'd'e\" '\"' '\"abc'", 13, True),
("a \"b'c'd'e\" '\"' '\"abc'", 14, True),
("a \"b'c'd'e\" '\"' '\"abc'", 15, False),
("a \"b'c'd'e\" '\"' '\"abc'", 16, True),
("a \"b'c'd'e\" '\"' '\"abc'", 17, True),
("a \"b'c'd'e\" '\"' '\"abc'", 18, True),
("a \"b'c'd'e\" '\"' '\"abc'", 20, True),
),
)
def test_is_inside_single_quotes(input, index, output) -> None:
assert zx.is_inside_single_quotes(input, index) == output


def test_shell_injection():
"""Test injecting commands or shell args like `$0` into shell strings."""
file = "./tests/test_files/injection.py"
output = subprocess.check_output(["zxpy", file, "--", "abc"]).decode()
assert output == (
"True\n" # uname -p worked as a string
"127\n" # uname -p inside f-string got quoted
)
# TODO: $1 injection test
78 changes: 76 additions & 2 deletions zx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import codecs
import contextlib
import inspect
import pipes
import re
import shlex
import subprocess
import sys
Expand Down Expand Up @@ -64,10 +66,21 @@ def cli() -> None:
)
parser.add_argument('filename', help='Name of file to run', nargs='?')

# Everything passed after a `--` is arguments to be used by the script itself.
script_args = ['/bin/sh']
try:
separator_index = sys.argv.index('--')
script_args.extend(sys.argv[separator_index + 1 :])
# Remove everything after `--` so that argparse passes
sys.argv = sys.argv[:separator_index]
except ValueError:
# `--` not present in command, so no extra script args
pass

args = parser.parse_args(namespace=ZxpyArgs())

# Remove zxpy executable from argv
del sys.argv[0]
# Once arg parsing is done, replace argv with script args
sys.argv = script_args

if args.filename is None:
setup_zxpy_repl()
Expand All @@ -91,9 +104,70 @@ def cli() -> None:
install()


def is_inside_single_quotes(string: str, index: int) -> bool:
"""Returns True if the given index is inside single quotes in a shell command."""
quote_index = string.find("'")
if quote_index == -1:
# No single quotes
return False

if index < quote_index:
# We're before the start of the single quotes
return False

double_quote_index = string.find('"')
if double_quote_index >= 0 and double_quote_index < quote_index:
next_double_quote = string.find('"', double_quote_index + 1)
if next_double_quote == -1:
# Double quote opened but never closed
return False

# Single quotes didn't start and we passed the index
if next_double_quote >= index:
return False

# Ignore all single quotes inside double quotes.
index -= next_double_quote + 1
rest = string[next_double_quote + 1 :]
return is_inside_single_quotes(rest, index)

next_quote = string.find("'", quote_index + 1)
if next_quote >= index:
# We're inside single quotes
return True

index -= next_quote + 1
rest = string[next_quote + 1 :]
return is_inside_single_quotes(rest, index)


@contextlib.contextmanager
def create_shell_process(command: str) -> Generator[IO[bytes], None, None]:
"""Creates a shell process, yielding its stdout to read data from."""
# shell argument support, i.e. $0, $1 etc.

dollar_indices = [index for index, char in enumerate(command) if char == '$']
for dollar_index in reversed(dollar_indices):
if (
dollar_index >= 0
and dollar_index + 1 < len(command)
and command[dollar_index + 1].isdigit()
and not is_inside_single_quotes(command, dollar_index)
):
end_index = dollar_index + 1
while end_index + 1 < len(command) and command[end_index + 1].isdigit():
end_index += 1

number = int(command[dollar_index + 1 : end_index + 1])

# Get argument number from sys.argv
if number < len(sys.argv):
replacement = sys.argv[number]
else:
replacement = ""

command = command[:dollar_index] + replacement + command[end_index + 1 :]

process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
Expand Down

0 comments on commit 5b0db6b

Please sign in to comment.