diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a426c30..fe83c05 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -150,7 +150,7 @@ jobs: GH_REPOSITORY: ${{ github.repository }} GH_LABELS: ${{ format('ci-storage-test-{0}-{1}', github.run_id, github.run_attempt) }} TZ: America/Los_Angeles - FORWARD_HOST: host.docker.internal + FORWARD_HOST: "host.docker.internal:42 host.docker.internal:4242" # Test the job with ci-storage-test tag which is initially queued, but then is # picked up by the ci-runner container booted in the previous job. In the end, diff --git a/.vscode/settings.json b/.vscode/settings.json index 35a9f65..41f88d4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -16,5 +16,9 @@ "strerror", "tmpfs", "topo" - ] + ], + "python.languageServer": "Default", + "cursorpyright.analysis.typeCheckingMode": "basic", + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true } diff --git a/docker/ci-runner/Dockerfile b/docker/ci-runner/Dockerfile index f06ff45..db93aed 100644 --- a/docker/ci-runner/Dockerfile +++ b/docker/ci-runner/Dockerfile @@ -26,7 +26,7 @@ RUN true \ # https://forums.docker.com/t/etc-init-d-docker-62-ulimit-error-setting-limit-invalid-argument-problem/139424 RUN true \ && install -m 0755 -d /etc/apt/keyrings \ - && curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc \ + && curl -fsSL --retry 3 --retry-all-errors https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc \ && chmod a+r /etc/apt/keyrings/docker.asc \ && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list \ && apt-get update -y \ @@ -54,8 +54,8 @@ RUN true \ aarch64|arm64) arch=linux-arm64 ;; \ *) echo >&2 "unsupported architecture: $arch"; exit 1 ;; \ esac \ - && runner_version=$(curl --silent 
"https://api.github.com/repos/actions/runner/releases/latest" | jq -r ".tag_name[1:]") \ - && curl --no-progress-meter -L https://github.com/actions/runner/releases/download/v$runner_version/actions-runner-$arch-$runner_version.tar.gz | tar xz \ + && runner_version=$(curl -fsSL --retry 3 --retry-all-errors "https://api.github.com/repos/actions/runner/releases/latest" | jq -r ".tag_name[1:]") \ + && curl -fsSL --retry 3 --retry-all-errors --no-progress-meter https://github.com/actions/runner/releases/download/v$runner_version/actions-runner-$arch-$runner_version.tar.gz | tar xz \ && date > .updated_at # Install OS dependencies needed by the action runner. diff --git a/docker/ci-runner/README.md b/docker/ci-runner/README.md index 47ef0e9..e156316 100644 --- a/docker/ci-runner/README.md +++ b/docker/ci-runner/README.md @@ -15,10 +15,15 @@ self-hosted runners as you want. An example scenario: - `GH_LABELS` (required): labels added to this runner, comma-separated - `TZ` (optional): timezone name - `FORWARD_HOST` (optional): some ports at localhost (provided in - FORWARD_PORTS) will be forwarded to this host + FORWARD_PORTS) will be forwarded to this host; can also be a + space-separated list of hosts, in which case the 1st host plays the role of + a primary and the rest are backups (1st backup server available receives + all traffic) - `FORWARD_PORTS` (optional): a space-delimited list of forwarded TCP or UDP ports; any port number may be suffixed with "/udp" to forward UDP, e.g. - "12345/udp" + "12345/udp"; if it's "12345/tcp-backup", then the primary and backup hosts + are flipped (i.e. 
the traffic is first sent to backup host and only then, + if it's not available, to the primary host) - `CI_STORAGE_HOST` (optional): the host which the initial ci-storage run will pull the data from; often times it is set to "127.0.0.1:10022" where 10022 is an example of SSH port forwarded via FORWARD_HOST/FORWARD_PORTS diff --git a/docker/ci-runner/guest/entrypoint.03-update.sh b/docker/ci-runner/guest/entrypoint.03-update.sh index 9362243..3810ed5 100644 --- a/docker/ci-runner/guest/entrypoint.03-update.sh +++ b/docker/ci-runner/guest/entrypoint.03-update.sh @@ -19,7 +19,7 @@ if [[ ! -f "$updated_at_file" || "$(find . -name "$updated_at_file" -mtime +21)" esac say "Getting the latest runner version using HEAD to avoid rate limiting (previously updated at $(cat $updated_at_file))..." - runner_location=$(curl --head -sS --fail https://github.com/actions/runner/releases/latest | sed 's/\r$//' | grep -i "location:") + runner_location=$(curl -fsSL --retry 3 --retry-all-errors --head https://github.com/actions/runner/releases/latest | sed 's/\r$//' | grep -i "location:") runner_version="${runner_location##*/tag/v}" if [[ "$runner_version" == *.*.* ]]; then @@ -32,7 +32,7 @@ if [[ ! -f "$updated_at_file" || "$(find . -name "$updated_at_file" -mtime +21)" if [[ ! -r "$path" ]]; then say "Downloading $url to $CACHE_DIR..." - curl --no-progress-meter -L "$url" > "$path.tmp" + curl -fsSL --retry 3 --retry-all-errors --no-progress-meter "$url" > "$path.tmp" mv -f "$path.tmp" "$path" else say "Using previously downloaded $path" diff --git a/docker/ci-runner/root/entrypoint.01-validate.sh b/docker/ci-runner/root/entrypoint.01-validate.sh index 1eedcf4..a62eee3 100644 --- a/docker/ci-runner/root/entrypoint.01-validate.sh +++ b/docker/ci-runner/root/entrypoint.01-validate.sh @@ -24,19 +24,19 @@ fi export TZ if [[ "${TZ:=}" != "" && ! 
"$TZ" =~ ^[-+_/a-zA-Z0-9]+$ ]]; then - say "If TZ is passed, it must be a valid TZ Idenfitier from https://en.wikipedia.org/wiki/List_of_tz_database_time_zones" + say "If TZ is passed, it must be a valid TZ Identifier from https://en.wikipedia.org/wiki/List_of_tz_database_time_zones" exit 1 fi export FORWARD_HOST -if [[ "${FORWARD_HOST:=}" != "" && ! "$FORWARD_HOST" =~ ^[-.[:alnum:]]+(:[0-9]+)?$ ]]; then - say "If FORWARD_HOST is passed, it must be a hostname." +if [[ "${FORWARD_HOST:=}" != "" && ! "$FORWARD_HOST" =~ ^([-.[:alnum:]]+(:[0-9]+)?[[:space:]]*)+$ ]]; then + say "If FORWARD_HOST is passed, it must be a hostname or a space-separated list of hostnames." exit 1 fi export FORWARD_PORTS -if [[ "${FORWARD_PORTS:=}" != "" && ! "$FORWARD_PORTS" =~ ^([[:space:]]*[0-9]+(/tcp|/udp)?[[:space:]]*)+$ ]]; then - echo 'If FORWARD_PORTS is passed, it must be in the form of (example): "123 456/udp 789/tcp".'; +if [[ "${FORWARD_PORTS:=}" != "" && ! "$FORWARD_PORTS" =~ ^([[:space:]]*[0-9]+(/tcp|/tcp-backup|/udp)?[[:space:]]*)+$ ]]; then + echo 'If FORWARD_PORTS is passed, it must be in the form of (example): "123 789/tcp 123/tcp-backup 456/udp".'; exit 1 fi diff --git a/docker/ci-runner/root/entrypoint.09-forward.sh b/docker/ci-runner/root/entrypoint.09-forward.sh index 2c8fb50..de66a0c 100644 --- a/docker/ci-runner/root/entrypoint.09-forward.sh +++ b/docker/ci-runner/root/entrypoint.09-forward.sh @@ -2,24 +2,51 @@ # # Sets up port forwarding to the storage host. # +# Format for each entry in FORWARD_PORTS: +# - 1234 (implies tcp) +# - 1234/udp +# - 1234/tcp +# - 1234/tcp-backup (flips primary server with backup in FORWARD_HOST list) +# set -u -e if [[ "$FORWARD_HOST" != "" && "$FORWARD_PORTS" != "" ]]; then - FORWARD_HOST="${FORWARD_HOST%%:*}" + # Remove port numbers from the FORWARD_HOST list, in case the client passed + # them. 
Sometimes, it's easier to erase the port numbers here than on the + # client's side, where FORWARD_HOST is passed as host:ignored_port from some + # other data source. + FORWARD_HOST=$(echo "$FORWARD_HOST" | sed -E 's/:[0-9]+//g') tcp_lines=() udp_lines=() for spec in $FORWARD_PORTS; do + hosts=$(echo "$FORWARD_HOST" | xargs) port=${spec%%/*} proto=${spec##*/} - [[ "$proto" == "$port" ]] && proto=tcp + if [[ "$proto" == "$port" ]]; then + proto=tcp + fi + if [[ "$proto" == "tcp-backup" ]]; then + proto="tcp" + hosts=$(echo "$FORWARD_HOST" | awk '{for(i=NF;i>0;i--) printf "%s ", $i; print ""}' | xargs) + fi if [[ "$proto" == udp ]]; then - udp_lines+=("127.0.0.1 $port/$proto $FORWARD_HOST $port/$proto") + # UDP forwarding doesn't support backup servers, so use the first host. + udp_lines+=("127.0.0.1 $port/$proto ${hosts%% *} $port/$proto") else tcp_lines+=("listen ${proto}_${port}") tcp_lines+=(" bind 127.0.0.1:$port") - # ipv4 is needed for e.g. host.docker.internal - tcp_lines+=(" server server1 $FORWARD_HOST:$port resolvers res resolve-prefer ipv4") + i=0 + for host in $hosts; do + # ipv4 is needed for e.g. 
host.docker.internal + tcp_line=" server server$i $host:$port resolvers res resolve-prefer ipv4 check inter 10s fall 6 rise 6" + if [[ $i == 0 ]]; then + tcp_lines+=("$tcp_line") + else + tcp_lines+=("$tcp_line backup") + fi + i=$((i+1)) + done tcp_lines+=(" mode $proto") fi done diff --git a/docker/ci-scaler/Dockerfile b/docker/ci-scaler/Dockerfile index 5391976..67af232 100644 --- a/docker/ci-scaler/Dockerfile +++ b/docker/ci-scaler/Dockerfile @@ -4,6 +4,10 @@ FROM $BASE_IMAGE ENV GH_TOKEN="" ENV ASGS="" ENV DOMAIN="" +ENV DYNAMODB_TABLE_PREFIX="" +ENV AWS_ENDPOINT_URL="" +ENV AWS_ACCESS_KEY_ID="" +ENV AWS_SECRET_ACCESS_KEY="" ENV TZ="" ENV DEBIAN_FRONTEND=noninteractive @@ -12,7 +16,7 @@ RUN true \ && apt-get update -y \ && apt-get install -y --no-install-recommends \ awscli jq rsync python3 python3-yaml rsyslog systemctl tzdata gosu less mc git curl wget pv psmisc unzip vim nano telnet net-tools apt-transport-https ca-certificates locales gnupg lsb-release \ - && curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | gpg --dearmor -o /usr/share/keyrings/githubcli-archive-keyring.gpg \ + && curl -fsSL --retry 3 --retry-all-errors https://cli.github.com/packages/githubcli-archive-keyring.gpg | gpg --dearmor -o /usr/share/keyrings/githubcli-archive-keyring.gpg \ && echo "deb [signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \ && apt-get update -y \ && apt-get install -y --no-install-recommends gh \ diff --git a/docker/ci-scaler/README.md b/docker/ci-scaler/README.md index e7c3ff5..708d1f2 100644 --- a/docker/ci-scaler/README.md +++ b/docker/ci-scaler/README.md @@ -18,6 +18,12 @@ To use: "{owner}/{repo}:{label}:{asg_name}" - `DOMAIN`: domain of API Gateway which listens for GitHub webhook requests via HTTPS and forwards all requests to this container's port 8088 + - `DYNAMODB_TABLE_PREFIX`: if set, use DynamoDB tables to store the state + across 
webhook requests; useful when running multiple instances of + ci-scaler + - `AWS_ENDPOINT_URL`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`: + optionally, you may pass these variables to access AWS API; used in + debugging mostly - `TZ` (optional): timezone name Example for docker compose: @@ -33,6 +39,7 @@ services: - GH_TOKEN - ASGS - DOMAIN + - DYNAMODB_TABLE_PREFIX - TZ ``` diff --git a/docker/ci-scaler/guest/entrypoint.99-run.sh b/docker/ci-scaler/guest/entrypoint.99-run.sh index a643f76..5047aea 100644 --- a/docker/ci-scaler/guest/entrypoint.99-run.sh +++ b/docker/ci-scaler/guest/entrypoint.99-run.sh @@ -7,7 +7,8 @@ set -u -e if [[ "$ASGS" != "" ]]; then exec python3 ./scaler/main.py \ --asgs="$ASGS" \ - --domain="$DOMAIN" + --domain="$DOMAIN" \ + --dynamodb-table-prefix="$DYNAMODB_TABLE_PREFIX" else exec sleep 1000000000 fi diff --git a/docker/ci-scaler/guest/scaler/api_aws.py b/docker/ci-scaler/guest/scaler/api_aws.py index 7ce58c2..d88755c 100644 --- a/docker/ci-scaler/guest/scaler/api_aws.py +++ b/docker/ci-scaler/guest/scaler/api_aws.py @@ -64,9 +64,19 @@ def aws( input: str | None = None, ) -> str | None: region = aws_region() + endpoint_url = os.environ.get("AWS_ENDPOINT_URL") + if args[0] == "dynamodb" and endpoint_url: + region = "us-east-1" if not region: return None - return check_output(["aws", f"--region={region}", *args], input=input) + cmd = [ + "aws", + f"--region={region}", + *([f"--endpoint-url={endpoint_url}"] if endpoint_url else ()), + *args, + ] + out = check_output(cmd, input=input) + return out def aws_json( diff --git a/docker/ci-scaler/guest/scaler/api_gh.py b/docker/ci-scaler/guest/scaler/api_gh.py index b846388..f526a5d 100644 --- a/docker/ci-scaler/guest/scaler/api_gh.py +++ b/docker/ci-scaler/guest/scaler/api_gh.py @@ -9,7 +9,7 @@ import traceback import yaml from helpers import Runner, RateLimits, check_output -from typing import Any, cast +from typing import Any, Literal, cast def gh( @@ -98,7 +98,7 @@ def 
gh_webhook_ensure_exists( url: str, secret: str, events: list[str], -): +) -> Literal["created", "already_exists"]: try: gh_api( "-XPOST", @@ -113,9 +113,11 @@ def gh_webhook_ensure_exists( "active": True, }, ) + return "created" except subprocess.CalledProcessError as e: - if "Hook already exists" not in e.stdout: - raise + if "Hook already exists" in e.stdout: + return "already_exists" + raise def gh_webhook_ensure_absent( diff --git a/docker/ci-scaler/guest/scaler/handler_idle_runners.py b/docker/ci-scaler/guest/scaler/handler_idle_runners.py index 5089c93..33bd294 100644 --- a/docker/ci-scaler/guest/scaler/handler_idle_runners.py +++ b/docker/ci-scaler/guest/scaler/handler_idle_runners.py @@ -12,10 +12,9 @@ AsgSpec, Runner, RunnersRegistry, - ExpiringDict, logged_result, ) -from typing import Literal +from storage import StorageFactory REVISIT_TERMINATED_INSTANCE_SEC = datetime.timedelta(minutes=10).total_seconds() @@ -26,12 +25,15 @@ def __init__( *, asg_spec: AsgSpec, max_idle_age_sec: int, + storage: StorageFactory, ): super().__init__(asg_spec=asg_spec) self.max_idle_age_sec = max_idle_age_sec self.idle_runners = RunnersRegistry() - self.terminated_instance_ids = ExpiringDict[str, Literal[True]]( + self.terminated_instance_ids = storage.create( + bool, ttl=REVISIT_TERMINATED_INSTANCE_SEC, + name="terminated-instance-ids", ) def handle(self, runners: list[Runner]) -> None: diff --git a/docker/ci-scaler/guest/scaler/handler_webhooks.py b/docker/ci-scaler/guest/scaler/handler_webhooks.py index 6fda27c..8b0df9d 100644 --- a/docker/ci-scaler/guest/scaler/handler_webhooks.py +++ b/docker/ci-scaler/guest/scaler/handler_webhooks.py @@ -12,7 +12,6 @@ gh_predict_workflow_labels, gh_webhook_ensure_absent, gh_webhook_ensure_exists, - gh_webhook_ping, ) from api_aws import ( DRY_RUN_MSG, @@ -20,20 +19,21 @@ aws_cloudwatch_put_metric_data, ) from helpers import ( - ExpiringDict, PostJsonHttpRequestHandler, AsgSpec, log, logged_result, ) -from typing import Any, Literal, 
cast +from storage import StorageFactory +from typing import Any, Literal +from zlib import crc32 DUPLICATED_EVENTS_TTL = 3600 JOB_TIMING_TTL = 3600 * 2 WORKFLOW_TTL = 3600 -WORKFLOW_RUN_EVENT = "workflow_run" -WORKFLOW_JOB_EVENT = "workflow_job" +WORKFLOW_RUN_EVENT = "workflow_run" # https://docs.github.com/en/webhooks/webhook-events-and-payloads#workflow_run +WORKFLOW_JOB_EVENT = "workflow_job" # https://docs.github.com/en/webhooks/webhook-events-and-payloads#workflow_job IGNORE_KEYS = [ "zen", "hook_id", @@ -44,42 +44,51 @@ "action", ] URL_PATH = "/ci-storage" -SERVICE_ACTION_INTERVAL_SEC = 10 +WEBHOOK_ENSURE_EXISTS_INTERVAL_SEC = 60 @dataclasses.dataclass class Webhook: url: str - last_delivery_at: int | None - - -@dataclasses.dataclass -class ServiceAction: - prev_at: int - iteration: int = 0 + ensure_exists_at: int @dataclasses.dataclass class JobTiming: job_id: int - queued_at: float | None = None - started_at: float | None = None - completed_at: float | None = None - bumped: set[str] = dataclasses.field(default_factory=set) + queued_at: int | None = None + started_at: int | None = None + completed_at: int | None = None + bumped: list[str] = dataclasses.field(default_factory=list) class HandlerWebhooks: - def __init__(self, *, domain: str, asg_specs: list[AsgSpec]): + def __init__( + self, + *, + domain: str, + asg_specs: list[AsgSpec], + storage: StorageFactory, + ): self.domain = domain self.asg_specs = asg_specs self.webhooks: dict[str, Webhook] = {} - self.service_action = ServiceAction(prev_at=int(time.time())) self.secret = gh_get_webhook_secret() - self.duplicated_events = ExpiringDict[tuple[int, str], float]( - ttl=DUPLICATED_EVENTS_TTL + self.duplicated_events = storage.create( + int, + ttl=DUPLICATED_EVENTS_TTL, + name="duplicated-events", + ) + self.job_timings = storage.create( + JobTiming, + ttl=JOB_TIMING_TTL, + name="job-timings", + ) + self.workflows = storage.create( + dict[str, Any], + ttl=WORKFLOW_TTL, + name="workflows", ) - 
self.job_timings = ExpiringDict[int, JobTiming](ttl=JOB_TIMING_TTL) - self.workflows = ExpiringDict[str, dict[str, Any]](ttl=WORKFLOW_TTL) this = self class RequestHandler(PostJsonHttpRequestHandler): @@ -91,7 +100,7 @@ def handle_POST_json(self, data: dict[str, Any], data_bytes: bytes): def __enter__(self): if not self.secret: return self - for repository in list(set(asg_spec.repository for asg_spec in self.asg_specs)): + for repository in set(asg_spec.repository for asg_spec in self.asg_specs): url = f"https://{self.domain}{URL_PATH}" with logged_result(doing=f"Registering webhook for {repository}: {url}"): gh_webhook_ensure_exists( @@ -100,7 +109,10 @@ def __enter__(self): secret=self.secret, events=[WORKFLOW_RUN_EVENT, WORKFLOW_JOB_EVENT], ) - self.webhooks[repository] = Webhook(url=url, last_delivery_at=None) + self.webhooks[repository] = Webhook( + url=url, + ensure_exists_at=int(time.time()), + ) return self def __exit__(self, *_: Any): @@ -112,20 +124,25 @@ def __exit__(self, *_: Any): gh_webhook_ensure_absent(repository=repository, url=webhook.url) def service_actions(self): - now = int(time.time()) - if now > self.service_action.prev_at + SERVICE_ACTION_INTERVAL_SEC: - i = self.service_action.iteration - self.service_action.iteration += 1 - self.service_action.prev_at = now - webhooks = [*self.webhooks.items()] - if webhooks: - repository, webhook = webhooks[i % len(webhooks)] - if webhook.last_delivery_at is None: - with logged_result( - swallow=True, - doing=f"Sending additional PING to webhook for {repository}: {webhook.url}", - ): - gh_webhook_ping(repository=repository, url=webhook.url) + if not self.secret: + return + # We re-register webhooks periodically, since they may be de-registered + # when one of the ci-scaler hosts terminates. In this case, all other + # remaining hosts (in the load balancing group) will recreate the + # webhooks eventually. 
+ for repository, webhook in self.webhooks.items(): + now = int(time.time()) + if now > webhook.ensure_exists_at + WEBHOOK_ENSURE_EXISTS_INTERVAL_SEC: + webhook.ensure_exists_at = now + with logged_result(swallow=True): + res = gh_webhook_ensure_exists( + repository=repository, + url=webhook.url, + secret=self.secret, + events=[WORKFLOW_RUN_EVENT, WORKFLOW_JOB_EVENT], + ) + if res == "created": + log(f"Re-created webhook for {repository}: {webhook.url}") def handle( self, @@ -133,41 +150,52 @@ def handle( data: dict[str, Any], data_bytes: bytes, ): - action = data.get("action") - run_payload = data.get(WORKFLOW_RUN_EVENT) - job_payload = data.get(WORKFLOW_JOB_EVENT) - # For local debugging only! Allows to simulate a webhook with just # querying an URL that includes the repo name and label: # - /workflow_run/owner/repo/label # - /workflow_job/owner/repo/label/{queued|in_progress|completed}/job_id + extra_debug_labels: dict[str, int] = {} if ( handler.client_address[0] == "127.0.0.1" - and not action - and not run_payload - and not job_payload + and not data.get("action") + and not data.get(WORKFLOW_RUN_EVENT) + and not data.get(WORKFLOW_JOB_EVENT) ): if match := re.match( rf"^/{WORKFLOW_RUN_EVENT}/([^/]+/[^/]+)/([^/]+)/?$", handler.path, ): - return self._handle_workflow_run_in_progress( - handler=handler, - repository=match.group(1), - labels={match.group(2): 1}, - ) + extra_debug_labels[match.group(2)] = 1 + data = { + "action": "in_progress", + "repository": { + "full_name": match.group(1), + }, + WORKFLOW_RUN_EVENT: { + "id": crc32(match.group(2).encode()), + "run_attempt": 1, + "name": "test", + "head_sha": "", + "path": "/.github/workflows/ci.yml", + }, + } elif match := re.match( rf"^/{WORKFLOW_JOB_EVENT}/([^/]+/[^/]+)/([^/]+)/([^/]+)/([^/]+)/?$", handler.path, ): - return self._handle_workflow_job_timing( - handler=handler, - repository=match.group(1), - labels={match.group(2): 1}, - action=cast(Any, match.group(3)), - job_id=int(match.group(4)), - 
name=None, - ) + data = { + "repository": { + "full_name": match.group(1), + }, + WORKFLOW_JOB_EVENT: { + "id": int(match.group(4)), + "run_attempt": 1, + "name": "test", + "labels": [match.group(2)], + }, + "action": match.group(3), + "job_id": int(match.group(4)), + } else: return handler.send_error( 404, @@ -177,17 +205,28 @@ def handle( + f"/{WORKFLOW_JOB_EVENT}/owner/repo/label/{'{queued|in_progress|completed}'}/job_id" + f", but got {handler.path}", ) + else: + assert self.secret + error = verify_signature( + secret=self.secret, + headers=handler.headers, + data_bytes=data_bytes, + ) + if error: + return handler.send_error(403, error) + action = data.get("action") repository: str | None = data.get("repository", {}).get("full_name", None) - if repository in self.webhooks: - self.webhooks[repository].last_delivery_at = int(time.time()) + keys = [k for k in data.keys() if k not in IGNORE_KEYS] + workflow_run = data.get(WORKFLOW_RUN_EVENT) + workflow_job = data.get(WORKFLOW_JOB_EVENT) name = ( - str(run_payload.get("name")) - if run_payload - else str(job_payload.get("name")) if job_payload else None + str(workflow_run.get("name")) + if workflow_run + else str(workflow_job.get("name")) if workflow_job else None ) - keys = [k for k in data.keys() if k not in IGNORE_KEYS] + if keys: handler.log_suffix = ( f"{{{','.join(keys)}}}" @@ -201,23 +240,16 @@ def handle( if not repository: return handler.send_json(202, message="ignoring event with no repository") - assert self.secret - error = verify_signature( - secret=self.secret, - headers=handler.headers, - data_bytes=data_bytes, - ) - if error: - return handler.send_error(403, error) - - if run_payload: + # This event is used for increasing the number of runners. 
+ if workflow_run: if action != "requested" and action != "in_progress": return handler.send_json( 202, message='ignoring action != ["requested", "in_progress"]', ) - event_key = (int(run_payload["id"]), str(run_payload["run_attempt"])) + event_key = f"{workflow_run['id']}:{workflow_run['run_attempt']}" + handler.log_suffix += f" id={event_key}" processed_at = self.duplicated_events.get(event_key) if processed_at: return handler.send_json( @@ -225,8 +257,8 @@ def handle( message=f"ignoring event that has already been processed at {time.ctime(processed_at)}", ) - head_sha = str(run_payload["head_sha"]) - path = str(run_payload["path"]) + head_sha = str(workflow_run["head_sha"]) + path = str(workflow_run["path"]) message = f"{repository}{event_key}: downloading {os.path.basename(path)} and parsing jobs list" try: cache_key = f"{repository}:{path}" @@ -240,7 +272,7 @@ def handle( self.workflows[cache_key] = workflow else: message += f" (cached)" - labels = gh_predict_workflow_labels( + labels = extra_debug_labels | gh_predict_workflow_labels( workflow=workflow, known_labels=[ asg_spec.label @@ -255,21 +287,23 @@ def handle( except Exception as e: return handler.send_error(500, f"{message} failed: {e}") - self.duplicated_events[event_key] = time.time() + self.duplicated_events[event_key] = int(time.time()) return self._handle_workflow_run_in_progress( handler=handler, repository=repository, labels=labels, ) - if job_payload: - if action != "queued" and action != "in_progress" and action != "completed": + # This event is only used for statistics about timing. 
+ if workflow_job: + allowed_actions = ["queued", "in_progress", "completed"] + if action not in allowed_actions: return handler.send_json( 202, - message='ignoring action != ["queued", "in_progress", "completed"]', + message=f"ignoring action != {allowed_actions}", ) - event_key = (int(job_payload["id"]), action) + event_key = f"{workflow_job['id']}:{workflow_job['run_attempt']}:{action}" processed_at = self.duplicated_events.get(event_key) if processed_at: return handler.send_json( @@ -277,16 +311,17 @@ def handle( message=f"ignoring event that has already been processed at {time.ctime(processed_at)}", ) - self.duplicated_events[event_key] = time.time() + self.duplicated_events[event_key] = int(time.time()) return self._handle_workflow_job_timing( handler=handler, repository=repository, - labels={label: 1 for label in job_payload["labels"]}, + labels={label: 1 for label in workflow_job["labels"]}, action=action, - job_id=int(job_payload["id"]), + job_id=int(workflow_job["id"]), name=name, ) + # Unrecognized event, skipping. 
return handler.send_json( 202, message=f"ignoring event with no {WORKFLOW_RUN_EVENT} and {WORKFLOW_JOB_EVENT}", @@ -345,10 +380,11 @@ def _handle_workflow_job_timing( message=f"ignoring event, since no matching auto-scaling group(s) found for repository {repository} and labels {[*labels.keys()]}", ) - timing = self.job_timings.get(job_id) or JobTiming(job_id=job_id) - self.job_timings[job_id] = timing + timing = self.job_timings.get(str(job_id)) + if timing is None: + timing = JobTiming(job_id=job_id) - now = time.time() + now = int(time.time()) if action == "queued": timing.queued_at = now elif action == "in_progress": @@ -380,7 +416,9 @@ def _handle_workflow_job_timing( for metric in timing.bumped: metrics.pop(metric, None) - timing.bumped.update(metrics.keys()) + timing.bumped.extend(metrics.keys()) + + self.job_timings[str(job_id)] = timing if metrics: job_name = ( @@ -416,7 +454,14 @@ def _handle_workflow_job_timing( return handler.send_json( 200, - message=f"processed event for job_id={job_id}: {asg_spec}", + message=( + f"logged timing event for job_id={job_id}: {asg_spec}, " + + ( + ", ".join(f"{k}:{v}" for k, v in metrics.items()) + if metrics + else "no metrics yet" + ) + ), ) diff --git a/docker/ci-scaler/guest/scaler/helpers.py b/docker/ci-scaler/guest/scaler/helpers.py index 7065294..53cca4e 100644 --- a/docker/ci-scaler/guest/scaler/helpers.py +++ b/docker/ci-scaler/guest/scaler/helpers.py @@ -13,7 +13,7 @@ from http import HTTPStatus from json import dumps, loads, JSONDecodeError, decoder from types import TracebackType -from typing import Any, Callable, Generic, Iterable, Literal, TypeVar +from typing import Any, Callable, Iterable, Literal C_RED = "\033[1;31m" C_GRAY = "\033[1;30m" @@ -315,48 +315,3 @@ def send_error( log(explain, error=True) -K = TypeVar("K") -V = TypeVar("V") - - -class ExpiringDict(Generic[K, V]): - def __init__(self, *, ttl: float): - self.ttl = ttl - self._store: dict[K, V] = {} - self._times: dict[K, float] = {} - - def 
_is_expired(self, key: K) -> bool: - return time.time() - self._times.get(key, 0) > self.ttl - - def _garbage_collect(self) -> None: - keys_to_delete = [key for key in self._store if self._is_expired(key)] - for key in keys_to_delete: - del self._store[key] - del self._times[key] - - def __setitem__(self, key: K, value: V): - self._garbage_collect() - self._store[key] = value - self._times[key] = time.time() - - def __getitem__(self, key: K) -> V: - if key not in self._store or self._is_expired(key): - raise KeyError(f"Key '{key}' not found or expired") - return self._store[key] - - def __delitem__(self, key: K): - if key in self._store: - del self._store[key] - del self._times[key] - - def __contains__(self, key: K): - return key in self._store and not self._is_expired(key) - - def __repr__(self): - return f"ExpiringDict({{k: v for k, v in self.store.items() if not self._is_expired(k)}})" - - def get(self, key: K, default: V | None = None) -> V | None: - try: - return self[key] - except KeyError: - return default diff --git a/docker/ci-scaler/guest/scaler/main.py b/docker/ci-scaler/guest/scaler/main.py index e5807cb..cef5b38 100755 --- a/docker/ci-scaler/guest/scaler/main.py +++ b/docker/ci-scaler/guest/scaler/main.py @@ -20,6 +20,7 @@ Runner, wrap_main, ) +from storage import StorageFactory def main(): @@ -81,9 +82,15 @@ def main(): parser.add_argument( "--max-offline-age-sec", type=int, - default=120, + default=1200, help="offline runners will be de-registered after this time", ) + parser.add_argument( + "--dynamodb-table-prefix", + type=str, + default=None, + help="if set, use DynamoDB shared storage with this table name prefix (useful when multiple ci-scaler instances are running)", + ) args = parser.parse_args() port = int(args.port) @@ -92,6 +99,7 @@ def main(): poll_interval_sec = int(args.poll_interval_sec) max_idle_age_sec = int(args.max_idle_age_sec) max_offline_age_sec = int(args.max_offline_age_sec) + storage = 
StorageFactory(dynamodb_table_prefix=args.dynamodb_table_prefix or None) handler_cloudwatch_rate_limits = HandlerCloudWatchRateLimits() handlers_asg: dict[AsgSpec, list[AsgHandler]] = {} @@ -102,6 +110,7 @@ def main(): HandlerIdleRunners( asg_spec=asg_spec, max_idle_age_sec=max_idle_age_sec, + storage=storage, ), HandlerOfflineRunners( asg_spec=asg_spec, @@ -140,7 +149,11 @@ def poll_thread(): handler_cloudwatch_rate_limits.handle() time.sleep(poll_interval_sec) - with HandlerWebhooks(domain=domain, asg_specs=asg_specs) as webhooks: + with HandlerWebhooks( + domain=domain, + asg_specs=asg_specs, + storage=storage, + ) as webhooks: with socketserver.TCPServer( ("", port), webhooks.RequestHandler, diff --git a/docker/ci-scaler/guest/scaler/storage.py b/docker/ci-scaler/guest/scaler/storage.py new file mode 100644 index 0000000..513cb30 --- /dev/null +++ b/docker/ci-scaler/guest/scaler/storage.py @@ -0,0 +1,219 @@ +import dataclasses +import math +import subprocess +import time +from api_aws import aws, aws_json +from json import dumps, loads +from typing import Any, Generic, Type, TypeVar + +V = TypeVar("V") + + +# +# A dict-like class in memory with TTL-based expiration. 
+# +class MemoryDict(Generic[V]): + def __init__(self, *, ttl: float): + self.ttl = ttl + self._store: dict[str, V] = {} + self._times: dict[str, float] = {} + + def _is_expired(self, key: str) -> bool: + return time.time() - self._times.get(key, 0) > self.ttl + + def _garbage_collect(self) -> None: + keys_to_delete = [key for key in self._store if self._is_expired(key)] + for key in keys_to_delete: + del self._store[key] + del self._times[key] + + def __setitem__(self, key: str, value: V): + self._garbage_collect() + self._store[key] = value + self._times[key] = time.time() + + def __getitem__(self, key: str) -> V: + if key not in self._store or self._is_expired(key): + raise KeyError(f"Key '{key}' not found or expired") + return self._store[key] + + def __delitem__(self, key: str): + if key in self._store: + del self._store[key] + del self._times[key] + + def __contains__(self, key: str): + return key in self._store and not self._is_expired(key) + + def __repr__(self): + return ( + "MemoryDict(" + + str({k: v for k, v in self._store.items() if not self._is_expired(k)}) + + ")" + ) + + def get(self, key: str, default: V | None = None) -> V | None: + try: + return self[key] + except KeyError: + return default + + +# +# A dict-like class backed by DynamoDB with TTL-based expiration. Drop-in +# replacement for MemoryDict when shared state across instances is needed. +# Automatically creates the DynamoDB table if it doesn't exist. 
+#
+class SharedDict(Generic[V]):
+    # Marks "no live item", so that a stored JSON null (None) is still
+    # treated as a regular value, the same way MemoryDict treats it.
+    _MISSING: Any = object()
+
+    def __init__(self, value_type: Type[V], *, ttl: float, table: str):
+        # value_type drives _deserialize(): when it is a dataclass, stored
+        # JSON objects are turned back into instances of it.
+        self.ttl = ttl
+        self.table = table
+        self.value_type = value_type
+        self._ensure_table()
+
+    def _ensure_table(self) -> None:
+        # Probes the table with describe-table. Only when the CLI reports
+        # ResourceNotFoundException is the table created, awaited and set up
+        # for TTL; any other CLI failure is re-raised unchanged.
+        try:
+            aws(
+                "dynamodb",
+                "describe-table",
+                f"--table-name={self.table}",
+            )
+        except subprocess.CalledProcessError as e:
+            if "ResourceNotFoundException" not in (e.stderr or ""):
+                raise
+            aws(
+                "dynamodb",
+                "create-table",
+                f"--table-name={self.table}",
+                "--attribute-definitions",
+                "AttributeName=pk,AttributeType=S",
+                "--key-schema",
+                "AttributeName=pk,KeyType=HASH",
+                "--billing-mode=PAY_PER_REQUEST",
+            )
+            # Poll up to 60 times, 5 seconds apart, until the table is ACTIVE.
+            for _ in range(60):
+                res = aws_json(
+                    "dynamodb",
+                    "describe-table",
+                    f"--table-name={self.table}",
+                )
+                if res and res["Table"]["TableStatus"] == "ACTIVE":
+                    break
+                time.sleep(5)
+            else:
+                # Fail with a clear error instead of going on to configure
+                # TTL on a table that never became ACTIVE.
+                raise TimeoutError(
+                    f"DynamoDB table '{self.table}' did not become ACTIVE in time"
+                )
+            aws(
+                "dynamodb",
+                "update-time-to-live",
+                f"--table-name={self.table}",
+                "--time-to-live-specification",
+                "Enabled=true,AttributeName=ttl",
+            )
+
+    def __setitem__(self, key: str, value: V):
+        # The value is stored as a JSON string next to an absolute expiration
+        # time (whole epoch seconds) in the "ttl" attribute.
+        ttl_epoch = math.floor(time.time() + self.ttl)
+        item_json = dumps(
+            {
+                "pk": {"S": key},
+                "val": {"S": dumps(value, default=_json_default)},
+                "ttl": {"N": str(ttl_epoch)},
+            }
+        )
+        aws(
+            "dynamodb",
+            "put-item",
+            f"--table-name={self.table}",
+            f"--item={item_json}",
+        )
+
+    def _deserialize(self, raw: Any) -> V:
+        # JSON objects become value_type instances when value_type is a
+        # dataclass; everything else is returned as decoded.
+        if dataclasses.is_dataclass(self.value_type) and isinstance(raw, dict):
+            return self.value_type(**raw)
+        else:
+            return raw
+
+    def _lookup(self, key: str) -> Any:
+        # Returns the deserialized value, or _MISSING when there is no item
+        # or its "ttl" attribute is already in the past. The expiration is
+        # checked here on every read, so an item whose ttl has passed is never
+        # returned even if it is still present in the table.
+        res = aws_json(
+            "dynamodb",
+            "get-item",
+            f"--table-name={self.table}",
+            f"--key={dumps({'pk': {'S': key}})}",
+            "--consistent-read",
+        )
+        if res is None or "Item" not in res:
+            return self._MISSING
+        item = res["Item"]
+        ttl_val = int(item["ttl"]["N"])
+        if ttl_val < time.time():
+            return self._MISSING
+        return self._deserialize(loads(item["val"]["S"]))
+
+    def _get_item(self, key: str) -> V | None:
+        # Kept with its original contract: None for a missing or expired key.
+        value = self._lookup(key)
+        return None if value is self._MISSING else value
+
+    def __getitem__(self, key: str) -> V:
+        value = self._lookup(key)
+        if value is self._MISSING:
+            raise KeyError(f"Key '{key}' not found or expired")
+        return value
+
+    def __delitem__(self, key: str):
+        aws(
+            "dynamodb",
+            "delete-item",
+            f"--table-name={self.table}",
+            f"--key={dumps({'pk': {'S': key}})}",
+        )
+
+    def __contains__(self, key: str):
+        return self._lookup(key) is not self._MISSING
+
+    def __repr__(self):
+        # Scans the table and shows only items whose ttl has not passed yet.
+        now = time.time()
+        res = aws_json(
+            "dynamodb",
+            "scan",
+            f"--table-name={self.table}",
+            "--consistent-read",
+        )
+        items = {}
+        if res and "Items" in res:
+            for item in res["Items"]:
+                ttl_val = int(item["ttl"]["N"])
+                if ttl_val >= now:
+                    items[item["pk"]["S"]] = self._deserialize(loads(item["val"]["S"]))
+        return f"SharedDict({self.table}, {items})"
+
+    def get(self, key: str, default: V | None = None) -> V | None:
+        # Same lookup as self[key], but falls back to `default` on a miss.
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+
+#
+# Creates either MemoryDict or SharedDict instances depending on whether a
+# DynamoDB table prefix is configured.
+#
+class StorageFactory:
+    def __init__(self, *, dynamodb_table_prefix: str | None = None):
+        self.dynamodb_table_prefix = dynamodb_table_prefix
+
+    def create(
+        self,
+        value_type: Type[V],
+        *,
+        ttl: float,
+        name: str,
+    ) -> MemoryDict[V] | SharedDict[V]:
+        # With a prefix, the backing table is named "<prefix>-<name>";
+        # without one, `name` and `value_type` are unused.
+        if self.dynamodb_table_prefix is not None:
+            return SharedDict[V](
+                value_type,
+                ttl=ttl,
+                table=f"{self.dynamodb_table_prefix}-{name}",
+            )
+        else:
+            return MemoryDict[V](ttl=ttl)
+
+
+def _json_default(obj: Any) -> Any:
+    # Fallback encoder for dumps(): dataclass instances become dicts and sets
+    # become sorted lists; anything else is rejected.
+    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
+        return dataclasses.asdict(obj)
+    if isinstance(obj, set):
+        return sorted(obj)
+    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
diff --git a/docker/ci-scaler/guest/scaler/tests/test_api_gh.py b/docker/ci-scaler/guest/scaler/tests/test_api_gh.py index aff5a95..86a76e3 100644 --- a/docker/ci-scaler/guest/scaler/tests/test_api_gh.py +++ b/docker/ci-scaler/guest/scaler/tests/test_api_gh.py @@ -6,7 +6,6 @@ gh_predict_workflow_labels, gh_runner_ensure_absent, 
gh_webhook_ensure_absent, - gh_webhook_ping, gh, ) from unittest import TestCase @@ -37,9 +36,6 @@ def test_gh_get_webhook_secret(self): def test_gh_webhook_ensure_exists(self): pass - def test_gh_webhook_ping(self): - gh_webhook_ping(repository=self.repository, url="https://example.com") - def test_gh_webhook_ensure_absent(self): gh_webhook_ensure_absent(repository=self.repository, url="https://example.com") diff --git a/docker/ci-scaler/root/entrypoint.01-validate.sh b/docker/ci-scaler/root/entrypoint.01-validate.sh index 89241ab..a086d89 100644 --- a/docker/ci-scaler/root/entrypoint.01-validate.sh +++ b/docker/ci-scaler/root/entrypoint.01-validate.sh @@ -22,3 +22,8 @@ if [[ "${DOMAIN:=}" != "" && "$DOMAIN" != *.* ]]; then say "If DOMAIN is set, it should be a fully qualified domain name." exit 1 fi + +export DYNAMODB_TABLE_PREFIX=${DYNAMODB_TABLE_PREFIX:-} +export AWS_ENDPOINT_URL=${AWS_ENDPOINT_URL:-} +export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} +export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} diff --git a/docker/compose-up-dev.sh b/docker/compose-up-dev.sh index 820238b..c6decf7 100755 --- a/docker/compose-up-dev.sh +++ b/docker/compose-up-dev.sh @@ -14,7 +14,7 @@ fi GH_TOKEN=$(gh auth token) \ GH_REPOSITORY=$(gh repo view --json owner,name -q '.owner.login + "/" + .name') \ GH_LABELS=ci-storage-dev \ -FORWARD_HOST=host.docker.internal \ +FORWARD_HOST="host.docker.internal:42 example.com" \ TZ=America/Los_Angeles \ BTIME="$btime" \ ASGS=$(gh repo view --json owner,name -q '.owner.login + "/" + .name'):ci-storage-dev:myasg \ diff --git a/docker/compose.yml b/docker/compose.yml index eefa0ba..9870c2b 100644 --- a/docker/compose.yml +++ b/docker/compose.yml @@ -22,6 +22,28 @@ services: - DOMAIN - TZ + ci-scaler-shared: + build: + context: ci-scaler + dockerfile: Dockerfile + stop_grace_period: 1m + healthcheck: + test: ["CMD", "bash", "-c", "netstat -ltn | grep -c :8088"] + interval: 1s + timeout: 3s + retries: 10 + ports: + - 18089:8088 + 
environment: + - GH_TOKEN + - ASGS + - DOMAIN + - DYNAMODB_TABLE_PREFIX=ci-storage + - AWS_ENDPOINT_URL=http://dynamodb:8000 + - AWS_ACCESS_KEY_ID=debug + - AWS_SECRET_ACCESS_KEY=debug + - TZ + ci-storage: build: context: ci-storage @@ -57,7 +79,7 @@ services: - GH_LABELS - TZ - FORWARD_HOST - - FORWARD_PORTS=10022/tcp 8125/udp + - FORWARD_PORTS=10022/tcp 80/tcp-backup 8125/udp - CI_STORAGE_HOST=127.0.0.1:10022 - BTIME - DEBUG_SHUTDOWN_DELAY_SEC=1 @@ -68,6 +90,9 @@ services: tmpfs: - /mnt:exec + dynamodb: + image: amazon/dynamodb-local + volumes: ci-storage-mnt: external: false