forked from mongodb/mongo-python-driver
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
154 lines (134 loc) · 4.99 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations
import argparse
import dataclasses
import logging
import os
import shlex
import subprocess
import sys
from pathlib import Path
from typing import Any
# Directory that contains this file.
HERE = Path(__file__).absolute().parent
# Two directory levels above HERE.
# NOTE(review): presumably the repository root -- confirm against the checkout layout.
ROOT = HERE.parent.parent
# Value of the DRIVERS_TOOLS environment variable ("" when unset), with the
# OS-specific path separator replaced by "/".
DRIVERS_TOOLS = os.environ.get("DRIVERS_TOOLS", "").replace(os.sep, "/")
# Output path that create_archive() passes to `git archive -o`.
TMP_DRIVER_FILE = "/tmp/mongo-python-driver.tgz" # noqa: S108
# Logger shared by the helpers in this module.
LOGGER = logging.getLogger("test")
# Configure root logging at import time: INFO level, "<LEVEL>   <message>" format.
logging.basicConfig(level=logging.INFO, format="%(levelname)-8s %(message)s")
# Shell file that write_env() appends `export NAME="value"` lines to.
ENV_FILE = HERE / "test-env.sh"
# "windows" when os.name is "nt", otherwise the lower-cased sys.platform value.
PLATFORM = "windows" if os.name == "nt" else sys.platform.lower()
@dataclasses.dataclass
class Distro:
    """Plain record holding three string fields: name, version id and architecture."""

    # NOTE(review): presumably describes the host OS distribution; nothing in the
    # visible part of this file constructs it -- confirm against the callers.
    name: str
    version_id: str
    arch: str
# Map the test name to a test suite.
# The keys are the values accepted for the `test_name` positional argument in
# get_test_options().
# NOTE(review): several names map to an empty suite string; how an empty value
# is interpreted is decided by the consumers of this map -- confirm there.
TEST_SUITE_MAP = {
    "atlas_connect": "atlas_connect",
    "auth_aws": "auth_aws",
    "auth_oidc": "auth_oidc",
    "data_lake": "data_lake",
    "default": "",
    "default_async": "default_async",
    "default_sync": "default",
    "encryption": "encryption",
    "enterprise_auth": "auth",
    "index_management": "index_management",
    "kms": "kms",
    "load_balancer": "load_balancer",
    "mockupdb": "mockupdb",
    "pyopenssl": "",
    "ocsp": "ocsp",
    "perf": "perf",
    "serverless": "",
}

# Tests that require a sub test suite.
# get_test_options() raises ValueError when one of these is selected without a
# sub_test_name (only when sub test names are enabled for that call).
SUB_TEST_REQUIRED = ["auth_aws", "auth_oidc", "kms", "mod_wsgi"]

# Test names accepted in addition to the TEST_SUITE_MAP keys, but only when
# get_test_options() is called with require_sub_test_name=True.
EXTRA_TESTS = ["mod_wsgi"]
def get_test_options(
    description, require_sub_test_name=True, allow_extra_opts=False
) -> tuple[argparse.Namespace, list[str]]:
    """Parse the command line and apply environment-variable overrides.

    Args:
        description: Text shown in the ``--help`` output (rendered verbatim).
        require_sub_test_name: When true, the ``test_name`` choices also include
            ``EXTRA_TESTS`` and an optional ``sub_test_name`` positional is
            accepted; tests listed in ``SUB_TEST_REQUIRED`` then need one.
        allow_extra_opts: When true, unrecognised arguments are collected and
            returned instead of causing a parse error.

    Returns:
        A ``(options, extra_args)`` tuple. ``extra_args`` is always an empty
        list unless ``allow_extra_opts`` is true.

    Raises:
        ValueError: If the selected test needs a sub test name and none was given.
    """
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description=description
    )

    # The positional test name is registered once; only its choices and help
    # text depend on the calling mode.
    if require_sub_test_name:
        suite_choices = sorted(list(TEST_SUITE_MAP) + EXTRA_TESTS)
        suite_help = "The optional name of the test suite to set up, typically the same name as a pytest marker."
    else:
        suite_choices = sorted(TEST_SUITE_MAP)
        suite_help = "The optional name of the test suite to be run, which informs the server configuration."
    arg_parser.add_argument(
        "test_name", choices=suite_choices, nargs="?", default="default", help=suite_help
    )
    if require_sub_test_name:
        arg_parser.add_argument(
            "sub_test_name", nargs="?", help="The optional sub test name, for example 'azure'."
        )

    # Boolean switches shared by both modes.
    flag_specs = (
        (("--verbose", "-v"), "Whether to log at the DEBUG level"),
        (("--quiet", "-q"), "Whether to log at the WARNING level"),
        (("--auth",), "Whether to add authentication"),
        (("--ssl",), "Whether to add TLS configuration"),
    )
    for flag_names, flag_help in flag_specs:
        arg_parser.add_argument(*flag_names, action="store_true", help=flag_help)

    # Get the options.
    if allow_extra_opts:
        opts, extra_opts = arg_parser.parse_known_args()
    else:
        opts, extra_opts = arg_parser.parse_args(), []

    # --verbose wins when both logging switches are given.
    if opts.verbose:
        LOGGER.setLevel(logging.DEBUG)
    elif opts.quiet:
        LOGGER.setLevel(logging.WARNING)

    # Handle validation and environment variable overrides.
    suite = opts.test_name
    sub_suite = opts.sub_test_name if require_sub_test_name else ""
    if require_sub_test_name and not sub_suite and suite in SUB_TEST_REQUIRED:
        raise ValueError(f"Test '{suite}' requires a sub_test_name")

    if "auth" in suite or os.environ.get("AUTH") == "auth":
        # 'auth_aws ecs' shouldn't have extra auth set.
        opts.auth = not (suite == "auth_aws" and sub_suite == "ecs")

    if os.environ.get("SSL") == "ssl":
        opts.ssl = True

    return opts, extra_opts
def read_env(path: Path | str) -> dict[str, Any]:
    """Parse a shell-style env file of ``NAME=value`` lines into a dict.

    Blank lines, ``#`` comment lines and lines without ``=`` are ignored. A
    leading ``export `` on the name is dropped, and a value that is fully
    wrapped in a matching pair of single or double quotes is unwrapped.
    Everything after the first ``=`` belongs to the value, so values may
    themselves contain ``=``.

    Args:
        path: Location of the env file to read.

    Returns:
        A mapping of variable name to its (string) value, in file order. A
        name that appears more than once keeps its last value.
    """
    config: dict[str, Any] = {}
    with Path(path).open() as fid:
        for raw_line in fid:
            line = raw_line.strip()
            # A comment such as "# set FOO=1" contains "=" but is not an assignment.
            if not line or line.startswith("#") or "=" not in line:
                continue
            name, _, value = line.partition("=")
            # Drop "export " only as a prefix, never from the middle of a name.
            if name.startswith("export "):
                name = name[len("export "):]
            name = name.strip()
            # Unwrap only when the same quote char opens and closes the value;
            # otherwise slicing would silently eat a real trailing character.
            if len(value) >= 2 and value[0] in "\"'" and value[-1] == value[0]:
                value = value[1:-1]
            config[name] = value
    return config
def write_env(name: str, value: Any = "1") -> None:
    """Append an ``export NAME="value"`` line to ``ENV_FILE``.

    Args:
        name: Variable name to export.
        value: Value to store; it is converted with ``str()`` and any double
            quote characters are removed before it is wrapped in double quotes.
    """
    # newline="\n" keeps the file LF-terminated regardless of platform.
    with ENV_FILE.open("a", newline="\n") as env_file:
        # Remove any existing quote chars.
        unquoted = str(value).replace('"', "")
        env_file.write(f'export {name}="{unquoted}"\n')
def run_command(cmd: str | list[str], **kwargs: Any) -> None:
    """Run a command with ``subprocess.run`` and exit the process if it fails.

    Args:
        cmd: The command, either as one string or as a list of strings. The
            string form is tokenised with ``shlex.split``.
        **kwargs: Passed through to ``subprocess.run``. ``check`` defaults to
            ``True``; pass ``check=False`` to tolerate a non-zero exit status.

    Raises:
        SystemExit: With the command's return code, when ``check`` is true and
            the command exits with a non-zero status.
    """
    if isinstance(cmd, list):
        # NOTE: the list is flattened to one string and re-tokenised below, so
        # an item that contains whitespace ends up as several arguments. Kept
        # as-is because existing callers may rely on it.
        cmd = " ".join(cmd)
    LOGGER.info("Running command '%s'...", cmd)
    kwargs.setdefault("check", True)
    try:
        subprocess.run(shlex.split(cmd), **kwargs)  # noqa: PLW1510, S603
    except subprocess.CalledProcessError as e:
        # output/stderr are None unless the caller asked for them to be
        # captured; skip them then instead of logging a literal "None".
        for captured in (e.output, e.stderr):
            if captured:
                LOGGER.error(captured)
        LOGGER.error(str(e))
        sys.exit(e.returncode)
    LOGGER.info("Running command '%s'... done.", cmd)
def create_archive() -> None:
    """Commit the working tree in ``ROOT`` and archive ``HEAD`` to ``TMP_DRIVER_FILE``.

    Runs three git commands in order from ``ROOT``: stage everything, commit,
    then ``git archive``. The commit step runs with ``check=False`` so a
    non-zero exit from it does not abort the archive step.
    """
    steps = (
        ("git add .", {}),
        ('git commit -m "add files"', {"check": False}),
        (f"git archive -o {TMP_DRIVER_FILE} HEAD", {}),
    )
    for git_cmd, overrides in steps:
        run_command(git_cmd, cwd=ROOT, **overrides)