Permalink
Browse files

allow specifying import paths, fixes #1 (use -p .)

  • Loading branch information...
1 parent c16baf9 commit 1a76a06bcac9f8af4e55f1da851239962074b1bb @timbertson committed Jun 11, 2012
Showing with 48 additions and 0 deletions.
  1. +5 −0 piep/main.py
  2. +22 −0 test/test_env.py
  3. +21 −0 test/test_helper.py
View
@@ -53,6 +53,7 @@ def run(argv=None):
p.add_option('-m', '--import', action='append', dest='imports', default=[], metavar='MODULE', help='add a module to global scope (may be given multiple times)')
p.add_option('-f', '--file', action='append', dest='files', default=[], metavar='FILE', help='add another input stream (available as f[n])')
p.add_option('-i', '--input', dest='input', help='use a named file (instead of stdin)')
+ p.add_option('-p', '--path', action='append', dest='import_paths', default=[], help='add a location to the import path (the same as $PYTHONPATH / sys.path)')
opts, args = p.parse_args(argv)
DEBUG = opts.debug
@@ -91,6 +92,10 @@ def make_stream(f):
pp = make_stream(input_file)
globs = builtins.copy()
globs['pp'] = pp
+ for path in opts.import_paths:
+ path = os.path.abspath(path)
+ if path not in sys.path:
+ sys.path.insert(0, path)
for import_mod in opts.imports:
import_node = ast.Import(names=[ast.alias(name=import_mod, asname=None)])
code = compile(ast.fix_missing_locations(ast.Module(body=[import_node])), 'import %s' % (import_mod,), 'exec')
View
@@ -0,0 +1,22 @@
+from test.test_helper import run, temp_cwd
+from unittest import TestCase
+import subprocess
+
+
+class TestModuleImporting(TestCase):
+ def test_modules_are_not_importable_from_cwd_by_default(self):
+ # for the same reason $PATH does not include
+ # "." - it could be an attack vector
+ with temp_cwd():
+ with open("mymod.py", 'w') as f:
+ f.write("def up(s): return s.upper()")
+ self.assertRaises(ImportError,
+ lambda: run('-m', 'mymod', 'mymod.up(p)', ['a']))
+
+ def test_modules_can_be_imported_from_cwd(self):
+ with temp_cwd():
+ with open("mymod.py", 'w') as f:
+ f.write("def up(s): return s.upper()")
+ self.assertEqual(
+ run('-p', '.', '-m', 'mymod', 'mymod.up(p)', ['a']), ['A'])
+
View
@@ -1,5 +1,9 @@
+import os
import sys
import itertools
+import contextlib
+import shutil
+import tempfile
from piep import main
@@ -13,3 +17,20 @@ def run(*args):
finally:
sys.stdin = old_stdin
+@contextlib.contextmanager
+def cwd(path):
+ old_cwd = os.getcwd()
+ os.chdir(path)
+ try:
+ yield
+ finally:
+ os.chdir(old_cwd)
+
+@contextlib.contextmanager
+def temp_cwd():
+ path = tempfile.mkdtemp()
+ try:
+ with cwd(path):
+ yield
+ finally:
+ shutil.rmtree(path)

0 comments on commit 1a76a06

Please sign in to comment.