diff --git a/patchwork/app.py b/patchwork/app.py index 4149959a6..1b737519d 100644 --- a/patchwork/app.py +++ b/patchwork/app.py @@ -59,6 +59,7 @@ def list_option_callback(ctx: click.Context, param: click.Parameter, value: str def find_patchflow(possible_module_paths: Iterable[str], patchflow: str) -> Any | None: + allowed_modules = {"allowed_module1", "allowed_module2", "allowed_module3"} for module_path in possible_module_paths: try: spec = importlib.util.spec_from_file_location("custom_module", module_path) @@ -72,9 +73,12 @@ def find_patchflow(possible_module_paths: Iterable[str], patchflow: str) -> Any logger.debug(f"Patchflow {patchflow} not found as a file/directory in {module_path}") try: - module = importlib.import_module(module_path) - logger.info(f"Patchflow {patchflow} loaded from {module_path}") - return getattr(module, patchflow) + if module_path in allowed_modules: + module = importlib.import_module(module_path) + logger.info(f"Patchflow {patchflow} loaded from {module_path}") + return getattr(module, patchflow) + else: + logger.debug(f"Module path {module_path} is not in the allowed modules list") except ModuleNotFoundError: logger.debug(f"Patchflow {patchflow} not found as a module in {module_path}") except AttributeError: diff --git a/patchwork/common/tools/bash_tool.py b/patchwork/common/tools/bash_tool.py index 8440f179a..82735076e 100644 --- a/patchwork/common/tools/bash_tool.py +++ b/patchwork/common/tools/bash_tool.py @@ -1,5 +1,6 @@ from __future__ import annotations +import shlex import subprocess from pathlib import Path @@ -44,8 +45,9 @@ def execute( return f"Error: `command` parameter must be set and cannot be empty" try: + args = shlex.split(command) result = subprocess.run( - command, shell=True, cwd=self.path, capture_output=True, text=True, timeout=60 # Add timeout for safety + args, shell=False, cwd=self.path, capture_output=True, text=True, timeout=60 # Add timeout for safety ) return result.stdout if result.returncode == 0 else f"Error: {result.stderr}" except subprocess.TimeoutExpired: diff --git a/patchwork/common/tools/csvkit_tool.py b/patchwork/common/tools/csvkit_tool.py index a1ef8dc59..b56fca712 100644 --- a/patchwork/common/tools/csvkit_tool.py +++ b/patchwork/common/tools/csvkit_tool.py @@ -118,9 +118,8 @@ def execute(self, files: list[str], query: str) -> str: if db_path.is_file(): with sqlite3.connect(str(db_path)) as conn: for file in files: - res = conn.execute( - f"SELECT 1 from {file.removesuffix('.csv')}", - ) + table_name = file.removesuffix('.csv') + res = conn.execute("SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) if res.fetchone() is None: files_to_insert.append(file) else: diff --git a/patchwork/common/utils/dependency.py b/patchwork/common/utils/dependency.py index 27b89bfed..c620fded8 100644 --- a/patchwork/common/utils/dependency.py +++ b/patchwork/common/utils/dependency.py @@ -6,9 +6,12 @@ "notification": ["slack_sdk"], } +__ALLOWED_MODULES = set(module for group in __DEPENDENCY_GROUPS.values() for module in group) @lru_cache(maxsize=None) def import_with_dependency_group(name): + if name not in __ALLOWED_MODULES: + raise ImportError(f"Module {name} is not allowed to be imported.") try: return importlib.import_module(name) except ImportError: diff --git a/patchwork/common/utils/step_typing.py b/patchwork/common/utils/step_typing.py index d349f7fc1..05a0ea722 100644 --- a/patchwork/common/utils/step_typing.py +++ b/patchwork/common/utils/step_typing.py @@ -106,8 +106,12 @@ def validate_step_type_config_with_inputs( def validate_step_with_inputs(input_keys: Set[str], step: Type[Step]) -> Tuple[Set[str], Dict[str, str]]: + module_whitelist = {"allowed.module1", "allowed.module2"} # Add the appropriate valid modules here. module_path, _, _ = step.__module__.rpartition(".") step_name = step.__name__ + if f"{module_path}.typed" not in module_whitelist: + raise ValueError(f"Importing from {module_path}.typed is not allowed.") + type_module = importlib.import_module(f"{module_path}.typed") step_input_model = getattr(type_module, f"{step_name}Inputs", __NOT_GIVEN) step_output_model = getattr(type_module, f"{step_name}Outputs", __NOT_GIVEN) diff --git a/patchwork/steps/CallShell/CallShell.py b/patchwork/steps/CallShell/CallShell.py index 98ee55a74..e907cf126 100644 --- a/patchwork/steps/CallShell/CallShell.py +++ b/patchwork/steps/CallShell/CallShell.py @@ -46,7 +46,7 @@ def __parse_env_text(env_text: str) -> dict[str, str]: return env def run(self) -> dict: - p = subprocess.run(self.script, shell=True, capture_output=True, text=True, cwd=self.working_dir, env=self.env) + p = subprocess.run(shlex.split(self.script), shell=False, capture_output=True, text=True, cwd=self.working_dir, env=self.env) try: p.check_returncode() except subprocess.CalledProcessError as e: @@ -57,3 +57,4 @@ def run(self) -> dict: logger.info(f"stdout: \n{p.stdout}") logger.info(f"stderr:\n{p.stderr}") return dict(stdout_output=p.stdout, stderr_output=p.stderr) +