diff --git a/src/fromager/__main__.py b/src/fromager/__main__.py index 4738f8fa..76d141b7 100644 --- a/src/fromager/__main__.py +++ b/src/fromager/__main__.py @@ -134,7 +134,8 @@ "-c", "--constraints-file", type=str, - help="location of the constraints file", + multiple=True, + help="location of constraint file(s), may be repeated", ) @click.option( "--cleanup/--no-cleanup", @@ -177,7 +178,7 @@ def main( patches_dir: pathlib.Path, settings_file: pathlib.Path, settings_dir: pathlib.Path, - constraints_file: str, + constraints_file: tuple[str, ...], cleanup: bool, variant: str, jobs: int | None, @@ -247,7 +248,7 @@ def main( logger.info(f"variant: {variant}") logger.info(f"patches dir: {patches_dir}") logger.info(f"maximum concurrent jobs: {jobs}") - logger.info(f"constraints file: {constraints_file}") + logger.info(f"constraints file(s): {constraints_file}") logger.info(f"network isolation: {network_isolation}") if build_wheel_server_url: logger.info(f"external build wheel server: {build_wheel_server_url}") diff --git a/src/fromager/context.py b/src/fromager/context.py index 1bc543ed..7e9710ec 100644 --- a/src/fromager/context.py +++ b/src/fromager/context.py @@ -19,7 +19,6 @@ dependency_graph, external_commands, packagesettings, - request_session, ) if typing.TYPE_CHECKING: @@ -36,7 +35,7 @@ class WorkContext: def __init__( self, active_settings: packagesettings.Settings | None, - constraints_file: str | None, + constraints_file: str | tuple[str, ...] | None, patches_dir: pathlib.Path, sdists_repo: pathlib.Path, wheels_repo: pathlib.Path, @@ -59,13 +58,16 @@ def __init__( max_jobs=max_jobs, ) self.settings = active_settings - self.input_constraints_uri: str | None self.constraints = constraints.Constraints() - if constraints_file is not None: - self.input_constraints_uri = constraints_file - self.constraints.load_constraints_file(constraints_file) - else: - self.input_constraints_uri = None + self.input_constraints_uris: list[str] = [] + if constraints_file: + if isinstance(constraints_file, str): + files: tuple[str, ...] = (constraints_file,) + else: + files = constraints_file + for cf in files: + self.input_constraints_uris.append(cf) + self.constraints.load_constraints_file(cf) self.sdists_repo = pathlib.Path(sdists_repo).resolve() self.sdists_downloads = self.sdists_repo / "downloads" self.sdists_builds = self.sdists_repo / "builds" @@ -135,16 +137,16 @@ def pip_wheel_server_args(self) -> list[str]: @property def pip_constraint_args(self) -> list[str]: - if not self.input_constraints_uri: + if not self.input_constraints_uris: return [] - if self.input_constraints_uri.startswith(("https://", "http://", "file://")): - path_to_constraints_file = self.work_dir / "input-constraints.txt" - if not path_to_constraints_file.exists(): - response = request_session.session.get(self.input_constraints_uri) - path_to_constraints_file.write_text(response.text) - else: - path_to_constraints_file = pathlib.Path(self.input_constraints_uri) + path_to_constraints_file = self.work_dir / "input-constraints.txt" + lines: list[str] = [] + for constraint_name in self.constraints: + req = self.constraints.get_constraint(constraint_name) + if req is not None: + lines.append(f"{req}\n") + path_to_constraints_file.write_text("".join(lines), encoding="utf-8") path_to_constraints_file = path_to_constraints_file.absolute() return ["--constraint", os.fspath(path_to_constraints_file)] diff --git a/tests/test_cli.py b/tests/test_cli.py index d98845b7..add05aeb 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -96,6 +96,34 @@ def test_output_dir_overridden_by_explicit_flags( assert not (out / "sdists-repo").exists() +def test_multiple_constraints_files( + tmp_path: pathlib.Path, cli_runner: CliRunner +) -> None: + """Multiple -c flags are accepted and constraints are merged.""" + file1 = tmp_path / "base.txt" + file1.write_text("numpy>=1.24\n") + file2 = tmp_path / "extra.txt" + file2.write_text("numpy<2.0\n") + + out = tmp_path / "output" + out.mkdir() + + result = cli_runner.invoke( + fromager, + [ + "-O", + str(out), + "-c", + str(file1), + "-c", + str(file2), + "canonicalize", + "some-package", + ], + ) + assert result.exit_code == 0, result.output + + KNOWN_COMMANDS: set[str] = { "bootstrap", "bootstrap-parallel", diff --git a/tests/test_context.py b/tests/test_context.py index c4b19339..fe2da3db 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -10,7 +10,7 @@ def _make_context( tmp_path: pathlib.Path, - constraints_file: str | None = None, + constraints_file: str | tuple[str, ...] | None = None, wheel_server_url: str = "", cleanup: bool = True, ) -> context.WorkContext: @@ -43,16 +43,40 @@ def _all_setup_dirs(ctx: context.WorkContext) -> list[pathlib.Path]: def test_pip_constraints_args(tmp_path: pathlib.Path) -> None: constraints_file = tmp_path / "constraints.txt" - constraints_file.write_text("\n") # the file has to exist + constraints_file.write_text("numpy>=1.24\n") ctx = _make_context(tmp_path, constraints_file=str(constraints_file)) ctx.setup() - assert ["--constraint", os.fspath(constraints_file)] == ctx.pip_constraint_args + merged_path = ctx.work_dir / "input-constraints.txt" + assert ["--constraint", os.fspath(merged_path)] == ctx.pip_constraint_args + assert merged_path.exists() + assert "numpy>=1.24" in merged_path.read_text() ctx = _make_context(tmp_path) ctx.setup() assert [] == ctx.pip_constraint_args +def test_pip_constraints_args_multiple_files(tmp_path: pathlib.Path) -> None: + file1 = tmp_path / "base.txt" + file1.write_text("numpy>=1.24\n") + file2 = tmp_path / "security.txt" + file2.write_text("numpy<2.0\nrequests>=2.28\n") + + ctx = _make_context(tmp_path, constraints_file=(str(file1), str(file2))) + ctx.setup() + merged_path = ctx.work_dir / "input-constraints.txt" + assert ["--constraint", os.fspath(merged_path)] == ctx.pip_constraint_args + content = merged_path.read_text() + lines = [line for line in content.splitlines() if line.strip()] + numpy_lines = [line for line in lines if line.startswith("numpy")] + assert len(numpy_lines) == 1 + assert ">=1.24" in numpy_lines[0] + assert "<2.0" in numpy_lines[0] + requests_lines = [line for line in lines if line.startswith("requests")] + assert len(requests_lines) == 1 + assert ">=2.28" in requests_lines[0] + + def test_setup_creates_directories(tmp_path: pathlib.Path) -> None: ctx = _make_context(tmp_path) ctx.setup()