diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..e1c9a2c
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,64 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+*.egg-info/
+dist/
+build/
+.eggs/
+*.egg
+
+# Virtual environments
+.venv/
+venv/
+ENV/
+env/
+
+# Testing
+.pytest_cache/
+.coverage
+htmlcov/
+.tox/
+*.cover
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# Git
+.git/
+.gitignore
+.gitattributes
+
+# CI/CD
+.github/
+.gitlab-ci.yml
+
+# Documentation
+*.md
+docs/
+
+# Docker
+Dockerfile*
+.dockerignore
+docker-compose*.yml
+build.sh
+push.sh
+
+# Database
+db/migrations/
+*.sql
+
+# Environment
+.env
+.env.local
+.env.*.local
+
+# Tests
+tests/
+pytest.ini
diff --git a/Dockerfile b/Dockerfile
index 09b858d..4c19544 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,19 @@
-FROM python:slim-bookworm
+FROM python:3.12-slim-bookworm AS builder
+
+WORKDIR /app
+
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
+
+ENV UV_COMPILE_BYTECODE=1 \
+    UV_LINK_MODE=copy
+
+COPY requirements.txt .
+
+RUN uv venv /app/.venv && \
+    uv pip install -r requirements.txt && \
+    uv run opentelemetry-bootstrap -a requirements | uv pip install --requirement -
+
+FROM python:3.12-slim-bookworm
 
 ARG VERSION
 ARG SEMVER_CORE
@@ -6,31 +21,25 @@
 ARG COMMIT_SHA
 ARG GITHUB_REPO
 ARG BUILD_DATE
 
-ENV VERSION=${VERSION}
-ENV SEMVER_CORE=${SEMVER_CORE}
-ENV COMMIT_SHA=${COMMIT_SHA}
-ENV BUILD_DATE=${BUILD_DATE}
-ENV GITHUB_REPO=${GITHUB_REPO}
-
 LABEL org.opencontainers.image.source=${GITHUB_REPO}
 LABEL org.opencontainers.image.created=${BUILD_DATE}
 LABEL org.opencontainers.image.version=${VERSION}
 LABEL org.opencontainers.image.revision=${COMMIT_SHA}
 
-RUN set -e \
-    && useradd -ms /bin/bash -d /app app
+RUN useradd -ms /bin/bash -d /app app
 
 WORKDIR /app
 
-USER app
-ENV PATH="$PATH:/app/.local/bin/"
+COPY --from=builder --chown=app:app /app/.venv /app/.venv
 
-COPY requirements.txt /app/
+COPY --chown=app:app . /app/
-RUN set -e \
-    && pip install --no-cache-dir -r /app/requirements.txt --break-system-packages \
-    && opentelemetry-bootstrap -a install
+USER app
 
-COPY --chown=app:app . /app/
+ENV VERSION=${VERSION} \
+    SEMVER_CORE=${SEMVER_CORE} \
+    COMMIT_SHA=${COMMIT_SHA} \
+    BUILD_DATE=${BUILD_DATE} \
+    GITHUB_REPO=${GITHUB_REPO}
 
 CMD ["/app/run.sh"]
diff --git a/app.py b/app.py
index 8c981f0..53cb069 100755
--- a/app.py
+++ b/app.py
@@ -2,18 +2,21 @@
 import asyncio
 import logging
 import os
+import sys
+import traceback
 import uuid
 from contextlib import asynccontextmanager
 from typing import Annotated
 
 import asyncpg
-import blibs
+import fastapi_structured_logging
 import httpx
-from asgi_logger.middleware import AccessLoggerMiddleware
 from fastapi import FastAPI
 from fastapi import Header
 from fastapi import HTTPException
-from fastapi.middleware import Middleware
+from fastapi import Request
+from fastapi import status
+from fastapi.responses import JSONResponse
 from fastapi.responses import RedirectResponse
 
 import webhook
@@ -24,16 +27,25 @@
 from gitlab_model import PipelinePayload
 from periodic_cleanup import periodic_cleanup
 
-# from fastapi.middleware.cors import CORSMiddleware
-
 config = DefaultConfig()
 
-# Configure logging
-blibs.init_root_logger()
-logger = logging.getLogger(__name__)
-logging.getLogger("urllib3").setLevel(logging.ERROR)
-logging.getLogger("msrest").setLevel(logging.ERROR)
-logging.getLogger("msal").setLevel(logging.ERROR)
+# Configure structured logging
+log_format = os.getenv("LOG_FORMAT", "auto").lower()
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+if log_format == "json":
+    fastapi_structured_logging.setup_logging(json_logs=True, log_level=log_level)
+elif log_format == "line":
+    fastapi_structured_logging.setup_logging(json_logs=False, log_level=log_level)
+else:
+    fastapi_structured_logging.setup_logging(log_level=log_level)
+
+logging.getLogger("uvicorn.error").disabled = True
+
+# Suppress traceback printing to stderr
+traceback.print_exception = lambda *args, **kwargs: None
+
+logger = fastapi_structured_logging.get_logger()
 
 
 @asynccontextmanager
@@ -52,22 +64,60 @@ async def lifespan(app: FastAPI):
     title="Teams Notifier gitlab-mr-api",
     version=os.environ.get("VERSION", "v0.0.0-dev"),
     lifespan=lifespan,
-    middleware=[
-        Middleware(
-            AccessLoggerMiddleware,  # type: ignore
-            format='%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" %(L)ss',  # noqa # type: ignore
-        )
-    ],
 )
 
-# Configure CORS
-# app.add_middleware(
-#     CORSMiddleware,
-#     allow_origins=["*"],  # Allows all origins
-#     allow_credentials=True,
-#     allow_methods=["*"],  # Allows all methods
-#     allow_headers=["*"],  # Allows all headers
-# )
+app.add_middleware(fastapi_structured_logging.AccessLogMiddleware)
+
+
+@app.exception_handler(asyncpg.UniqueViolationError)
+async def database_uniqueviolation_handler(request: Request, exc: asyncpg.UniqueViolationError):
+    logger.error(
+        "database unique violation error",
+        error_type=type(exc).__name__,
+        error_detail=str(exc),
+        constraint=getattr(exc, "constraint_name", None),
+        path=request.url.path,
+        method=request.method,
+    )
+    return JSONResponse(
+        status_code=status.HTTP_409_CONFLICT,
+        content={"detail": "Resource already exists", "error": str(exc)},
+    )
+
+
+@app.exception_handler(asyncpg.PostgresError)
+async def database_exception_handler(request: Request, exc: asyncpg.PostgresError):
+    logger.error(
+        "database error",
+        error_type=type(exc).__name__,
+        error_detail=str(exc),
+        sqlstate=getattr(exc, "sqlstate", None),
+        path=request.url.path,
+        method=request.method,
+    )
+    return JSONResponse(
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        content={"detail": "Database error occurred"},
+    )
+
+
+@app.exception_handler(Exception)
+async def general_exception_handler(request: Request, exc: Exception):
+    exc_type, exc_value, exc_traceback = sys.exc_info()
+
+    logger.error(
+        "unhandled exception",
+        error_type=type(exc).__name__,
+        error_detail=str(exc),
+        path=request.url.path,
+        method=request.method,
+        traceback="".join(traceback.format_exception(exc_type, exc_value, exc_traceback)),
+    )
+
+    return JSONResponse(
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        content={"detail": "Internal server error"},
+    )
 
 
 @app.get("/", response_class=RedirectResponse, status_code=302)
@@ -141,7 +191,12 @@ async def healthcheck():
         result = await connection.fetchval("SELECT true FROM merge_request_ref")
         return {"ok": result}
     except Exception as e:
-        logger.exception(f"health check failed with {type(e)}: {e}")
+        logger.error(
+            "health check failed",
+            error_type=type(e).__name__,
+            error_detail=str(e),
+            exc_info=True,
+        )
         raise HTTPException(status_code=500, detail=f"{type(e)}: {e}")
 
 
diff --git a/cards/merge_request.yaml.j2 b/cards/merge_request.yaml.j2
index 1b01317..5a2f0df 100644
--- a/cards/merge_request.yaml.j2
+++ b/cards/merge_request.yaml.j2
@@ -1,6 +1,8 @@
 $schema: https://adaptivecards.io/schemas/adaptive-card.json
 type: AdaptiveCard
 version: '1.5'
+msteams:
+  width: Full
 fallbackText: {{ fallback }}
 speak: {{ fallback }}
 body:
@@ -79,8 +81,8 @@
 - type: Table
   showGridLines: false
   columns:
-    - width: 3
-    - width: 7
+    - width: 130px
+    - width: 1
   firstRowAsHeaders: false
   rows:
     - type: TableRow
@@ -167,7 +169,7 @@
               verticalContentAlignment: Center
               items:
                 - type: TextBlock
-                  text: '{{ precalc.assignees | join(",") }}'
+                  text: '{{ precalc.assignees | join(", ") }}'
                   #size: Small
                   wrap: true
 {% endif %}
@@ -186,7 +188,7 @@
               verticalContentAlignment: Center
               items:
                 - type: TextBlock
-                  text: '{{ precalc.reviewers | join(",") }}'
+                  text: '{{ precalc.reviewers | join(", ") }}'
                   #size: Small
                   wrap: true
 {% endif %}
@@ -205,7 +207,7 @@
               verticalContentAlignment: Center
               items:
                 - type: TextBlock
-                  text: '{{ precalc.approvers | join(",") }}'
+                  text: '{{ precalc.approvers | join(", ") }}'
                   #size: Small
                   wrap: true
 {% endif %}
diff --git a/config.py b/config.py
index a3e6ebd..a2bd4ec 100644
--- a/config.py
+++ b/config.py
@@ -16,6 +16,7 @@ class DefaultConfig:
     DATABASE_POOL_MAX_SIZE = int(os.environ.get("DATABASE_POOL_MAX_SIZE", "10"))
     LOG_QUERIES = os.environ.get("LOG_QUERIES", "")
    VALID_X_GITLAB_TOKEN = os.environ.get("VALID_X_GITLAB_TOKEN", "")
+    MESSAGE_DELETE_DELAY_SECONDS = int(os.environ.get("MESSAGE_DELETE_DELAY_SECONDS", "30"))
     _valid_tokens: list[str]
 
     def __init__(self):
diff --git a/db.py b/db.py
index 1fb2328..70b0ae6 100644
--- a/db.py
+++ b/db.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
 import asyncio
 import json
-import logging
 import urllib.parse
 from typing import Any
 from typing import Literal
 
 import asyncpg.connect_utils
+import fastapi_structured_logging
 from pydantic import BaseModel
 
 from config import config
@@ -15,7 +15,7 @@
 from gitlab_model import MergeRequestPayload
 from gitlab_model import PipelinePayload
 
-log = logging.getLogger(__name__)
+log = fastapi_structured_logging.get_logger()
 
 __all__ = ["database", "dbh"]
 
diff --git a/db/migrations/20250121000000_add_webhook_fingerprint_table.sql b/db/migrations/20250121000000_add_webhook_fingerprint_table.sql
new file mode 100644
index 0000000..d6fb1d6
--- /dev/null
+++ b/db/migrations/20250121000000_add_webhook_fingerprint_table.sql
@@ -0,0 +1,11 @@
+-- migrate:up
+CREATE TABLE gitlab_mr_api.webhook_fingerprint (
+    fingerprint character varying(64) NOT NULL,
+    processed_at timestamp with time zone DEFAULT now() NOT NULL
+);
+
+ALTER TABLE ONLY gitlab_mr_api.webhook_fingerprint
+    ADD CONSTRAINT webhook_fingerprint_pkey PRIMARY KEY (fingerprint);
+
+-- migrate:down
+DROP TABLE gitlab_mr_api.webhook_fingerprint;
diff --git a/periodic_cleanup.py b/periodic_cleanup.py
index 7c5ef06..8aa1e48 100644
--- a/periodic_cleanup.py
+++ b/periodic_cleanup.py
@@ -1,15 +1,15 @@
 #!/usr/bin/env python3
 import asyncio
 import datetime
-import logging
 
 import asyncpg
+import fastapi_structured_logging
 import httpx
 
 from config import DefaultConfig
 from db import DatabaseLifecycleHandler
 
-logger = logging.getLogger(__name__)
+logger = fastapi_structured_logging.get_logger()
 
 signal = asyncio.Event()
 MAX_WAIT = 300
@@ -24,7 +24,8 @@ async def periodic_cleanup(config: DefaultConfig, database: DatabaseLifecycleHan
 
 
 async def _cleanup_task(config: DefaultConfig, database: DatabaseLifecycleHandler):
-    client = httpx.AsyncClient()
+    timeout = httpx.Timeout(10.0, connect=5.0)
+    client = httpx.AsyncClient(timeout=timeout)
     while True:
         # Cleanup message function goes here :)
         wait_sec = MAX_WAIT
@@ -55,12 +56,32 @@ async def _cleanup_task(config: DefaultConfig, database: DatabaseLifecycleHandle
                         )
                         logger.info("deleted message %s", record["msg_to_delete_id"])
                     except Exception as e:
-                        logger.exception(f"Error processing record {record['msg_to_delete_id']}: {e}")
+                        logger.error(
+                            "error processing deletion record",
+                            msg_to_delete_id=record["msg_to_delete_id"],
+                            error_type=type(e).__name__,
+                            error_detail=str(e),
+                            exc_info=True,
+                        )
+
+                deleted_fingerprints = await connection.fetch(
+                    """DELETE FROM webhook_fingerprint
+                    WHERE processed_at < NOW() - INTERVAL '24 hours'
+                    RETURNING fingerprint"""
+                )
+                if len(deleted_fingerprints) > 0:
+                    logger.info("cleaned up old webhook fingerprints", count=len(deleted_fingerprints))
+
                 value = await connection.fetchval("SELECT min(expire_at) FROM msg_to_delete")
                 if value is not None:
                     wait_sec = min(MAX_WAIT, (value - datetime.datetime.now(tz=datetime.UTC)).total_seconds())
         except Exception as e:
-            logger.exception(e)
+            logger.error(
+                "cleanup task error",
+                error_type=type(e).__name__,
+                error_detail=str(e),
+                exc_info=True,
+            )
         try:
             logger.debug(f"wait for signal or {wait_sec}s")
             await asyncio.wait_for(signal.wait(), wait_sec)
@@ -74,4 +95,9 @@ async def _log_exception(awaitable):
     try:
         return await awaitable
     except Exception as e:
-        logger.exception(e)
+        logger.error(
+            "periodic cleanup unhandled exception",
+            error_type=type(e).__name__,
+            error_detail=str(e),
+            exc_info=True,
+        )
diff --git a/requirements.txt b/requirements.txt
index c1549b5..3b70dd2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,11 +3,15 @@ python-dateutil
 pydantic
 asyncpg
 httpx
-blibs
-fastapi[standard]
+fastapi
+uvicorn
+uvloop
+httptools
 jinja2
 pyyaml
-asgi-logger
+fastapi-structured-logging
+structlog
+orjson
 opentelemetry-distro
 opentelemetry-exporter-otlp
 #opentelemetry-instrumentation-fastapi
diff --git a/run.sh b/run.sh
index a70a4fb..30f58ba 100755
--- a/run.sh
+++ b/run.sh
@@ -2,9 +2,9 @@
 
 export OTEL_PYTHON_EXCLUDED_URLS=${OTEL_PYTHON_EXCLUDED_URLS:-healthz}
 
-opentelemetry-instrument \
+exec /app/.venv/bin/opentelemetry-instrument \
     --traces_exporter otlp \
     --metrics_exporter otlp \
     --logs_exporter otlp \
     --service_name notifier-gitlab-mr-api \
-    fastapi run
+    /app/.venv/bin/uvicorn app:app --host 0.0.0.0 --port 8000
diff --git a/webhook/emoji.py b/webhook/emoji.py
index 11498c5..f7f3e93 100644
--- a/webhook/emoji.py
+++ b/webhook/emoji.py
@@ -65,7 +65,8 @@ async def update_message(mri: MergeRequestInfos, conversation_tokens: list[str])
     )
 
     connection: asyncpg.Connection
-    async with await database.acquire() as connection, httpx.AsyncClient() as client:
+    timeout = httpx.Timeout(10.0, connect=5.0)
+    async with await database.acquire() as connection, httpx.AsyncClient(timeout=timeout) as client:
         res = await connection.fetch(
             """
             SELECT merge_request_message_ref_id, conversation_token, message_id
diff --git a/webhook/merge_request.py b/webhook/merge_request.py
index 2cbbe45..67c15fb 100644
--- a/webhook/merge_request.py
+++ b/webhook/merge_request.py
@@ -1,12 +1,11 @@
 #!/usr/bin/env python3
 import datetime
 import hashlib
-import logging
 import uuid
 from typing import Any
 
 import asyncpg
-import dateutil.parser
+import fastapi_structured_logging
 import httpx
 from pydantic import BaseModel
 
@@ -17,7 +16,7 @@
 from db import dbh
 from gitlab_model import MergeRequestPayload
 
-logger = logging.getLogger(__name__)
+logger = fastapi_structured_logging.get_logger()
 
 
 class MRMessRef(BaseModel):
@@ -34,43 +33,54 @@ async def get_or_create_message_refs(
 
     connection: asyncpg.Connection
     async with await database.acquire() as connection:
-        resset = await connection.fetch(
-            """
-            SELECT
-                merge_request_message_ref_id,
-                conversation_token,
-                message_id
-            FROM
-                merge_request_message_ref
-            WHERE
-                merge_request_ref_id = $1
-                -- AND conversation_token = ANY($2::uuid[])
-            """,
-            merge_request_ref_id,
-            # conv_tokens,
-        )
-
-        for row in resset:
-            convtoken_to_msgrefs[str(row["conversation_token"])] = MRMessRef(**row)
+        async with connection.transaction():
+            # Lock the MR ref to prevent concurrent modifications
+            await connection.execute(
+                """SELECT 1 FROM merge_request_ref
+                WHERE merge_request_ref_id = $1
+                FOR UPDATE""",
+                merge_request_ref_id,
+            )
 
-        for conv_token in conv_tokens:
-            if conv_token in convtoken_to_msgrefs:
-                continue
-            row = await connection.fetchrow(
+            resset = await connection.fetch(
                 """
-                INSERT INTO merge_request_message_ref (
-                    merge_request_ref_id, conversation_token
-                ) VALUES (
-                    $1, $2
-                ) RETURNING merge_request_message_ref_id,
-                    conversation_token,
-                    message_id
+                SELECT
+                    merge_request_message_ref_id,
+                    conversation_token,
+                    message_id
+                FROM
+                    merge_request_message_ref
+                WHERE
+                    merge_request_ref_id = $1
+                    -- AND conversation_token = ANY($2::uuid[])
                 """,
                 merge_request_ref_id,
-                conv_token,
+                # conv_tokens,
             )
-            assert row is not None
-            convtoken_to_msgrefs[str(row["conversation_token"])] = MRMessRef(**row)
+
+            for row in resset:
+                convtoken_to_msgrefs[str(row["conversation_token"])] = MRMessRef(**row)
+
+            for conv_token in conv_tokens:
+                if conv_token in convtoken_to_msgrefs:
+                    continue
+                row = await connection.fetchrow(
+                    """
+                    INSERT INTO merge_request_message_ref (
+                        merge_request_ref_id, conversation_token
+                    ) VALUES (
+                        $1, $2
+                    ) ON CONFLICT (merge_request_ref_id, conversation_token) DO UPDATE
+                        SET merge_request_ref_id = EXCLUDED.merge_request_ref_id
+                    RETURNING merge_request_message_ref_id,
+                        conversation_token,
+                        message_id
+                    """,
+                    merge_request_ref_id,
+                    conv_token,
+                )
+                assert row is not None
+                convtoken_to_msgrefs[str(row["conversation_token"])] = MRMessRef(**row)
 
     return convtoken_to_msgrefs
 
@@ -98,12 +108,24 @@ async def create_or_update_message(
             return None
 
         payload["conversation_token"] = str(mrmsgref.conversation_token)
-        res = await client.request(
-            "POST",
-            config.ACTIVITY_API + "api/v1/message",
-            json=payload,
-        )
-        response = res.json()
+        try:
+            res = await client.request(
+                "POST",
+                config.ACTIVITY_API + "api/v1/message",
+                json=payload,
+            )
+            res.raise_for_status()
+            response = res.json()
+        except Exception:
+            logger.error(
+                "failed to create message",
+                method="POST",
+                url=config.ACTIVITY_API + "api/v1/message",
+                conversation_token=str(mrmsgref.conversation_token),
+                status_code=res.status_code if "res" in locals() else None,
+                exc_info=True,
+            )
+            raise
 
         connection: asyncpg.Connection
         async with await database.acquire() as connection:
@@ -118,23 +140,36 @@
                 mrmsgref.merge_request_message_ref_id,
             )
             if result is None or len(result) == 0:
-                # This case is a race condition so cleanup the second message :)
-                await client.request(
-                    "DELETE",
-                    config.ACTIVITY_API + "api/v1/message",
-                    json={
-                        "message_id": str(response.get("message_id")),
-                    },
-                )
+                try:
+                    await client.request(
+                        "DELETE",
+                        config.ACTIVITY_API + "api/v1/message",
+                        json={
+                            "message_id": str(response.get("message_id")),
+                        },
+                    )
+                except Exception:
+                    logger.exception("Failed to delete duplicate message %s", response.get("message_id"))
     else:
         payload["message_id"] = str(mrmsgref.message_id)
-        res = await client.request(
-            "PATCH",
-            config.ACTIVITY_API + "api/v1/message",
-            json=payload,
-        )
-        response = res.json()
-        res.raise_for_status()
+        try:
+            res = await client.request(
+                "PATCH",
+                config.ACTIVITY_API + "api/v1/message",
+                json=payload,
+            )
+            res.raise_for_status()
+            response = res.json()
+        except Exception:
+            logger.error(
+                "failed to update message",
+                method="PATCH",
+                url=config.ACTIVITY_API + "api/v1/message",
+                message_id=str(mrmsgref.message_id),
+                status_code=res.status_code if "res" in locals() else None,
+                exc_info=True,
+            )
+            raise
 
     return uuid.UUID(response.get("message_id"))
 
@@ -146,8 +181,16 @@ async def merge_request(
 ):
     payload_fingerprint = hashlib.sha256(mr.model_dump_json().encode("utf8")).hexdigest()
     logger.debug("payload fingerprint: %s", payload_fingerprint)
+
+    is_closing_action = mr.object_attributes.action in ("merge", "close") or mr.object_attributes.state in (
+        "closed",
+        "merged",
+    )
+
     mri = await dbh.get_merge_request_ref_infos(mr)
 
+    need_cleanup_reschedule = False
+
     participant_found = True
     if participant_ids_filter:
         participant_found = False
@@ -163,7 +206,6 @@
 
     connection: asyncpg.Connection
     if mr.object_attributes.action in ("update"):
-        # Update MR info (head_pipeline_id)
         async with await database.acquire() as connection:
            row = await connection.fetchrow(
                 """UPDATE merge_request_ref
@@ -177,8 +219,6 @@
             if row is not None:
                 mri.merge_request_extra_state = row["merge_request_extra_state"]
 
-        # If update and oldrev field is set => new commit in MR
-        # Approvals must be reset
         if new_commits_revoke_approvals and mr.object_attributes.oldrev:
             row = await connection.fetchrow(
                 """UPDATE merge_request_ref
@@ -193,41 +233,109 @@
             if row is not None:
                 mri.merge_request_extra_state = row["merge_request_extra_state"]
 
-        # if it's a transition from draft to ready
-        # - Delete all messages related to this MR prior to the current event update
-        #   then create_or_update_message will re-post new message
-        #   to have cards being the most recent in the feeds.
-        # Rows are used as lock to avoid race condition when multiple instances can receive hooks
-        # for the same MR (multiple webhook same project, multiple instances [kube?])
         if mr.changes and "draft" in mr.changes and not mr.object_attributes.draft:
-            assert mr.object_attributes.updated_at is not None
-            update_ref_datetime = dateutil.parser.parse(mr.object_attributes.updated_at)
+            logger.info(
+                "draft to ready transition detected - locking and updating all messages",
+                mr_ref_id=mri.merge_request_ref_id,
+                fingerprint=payload_fingerprint,
+            )
+
             message_expiration = datetime.timedelta(seconds=0)
+            timeout = httpx.Timeout(10.0, connect=5.0)
+
             async with connection.transaction():
-                res = await connection.fetch(
-                    """SELECT merge_request_message_ref_id, message_id
+                locked_messages = await connection.fetch(
+                    """SELECT merge_request_message_ref_id, conversation_token, message_id
                     FROM merge_request_message_ref
-                    WHERE merge_request_ref_id = $1 AND created_at < $2
+                    WHERE merge_request_ref_id = $1
                     FOR UPDATE""",
                     mri.merge_request_ref_id,
-                    update_ref_datetime,
                 )
-                for row in res:
-                    message_id = row.get("message_id")
-                    if message_id is not None:
+
+                if len(locked_messages) > 0:
+                    logger.info(
+                        "locked messages for draft-to-ready update",
+                        mr_ref_id=mri.merge_request_ref_id,
+                        message_count=len(locked_messages),
+                    )
+
+                temp_mri = await dbh.get_merge_request_ref_infos(mr)
+                temp_card = render(
+                    temp_mri,
+                    collapsed=False,
+                    show_collapsible=False,
+                )
+                temp_summary = (
+                    f"MR {temp_mri.merge_request_payload.object_attributes.state}:"
+                    f" {temp_mri.merge_request_payload.object_attributes.title}\n"
+                    f"on {temp_mri.merge_request_payload.project.path_with_namespace}"
+                )
+
+                async with httpx.AsyncClient(timeout=timeout) as client:
+                    for row in locked_messages:
+                        message_id = row.get("message_id")
+                        if message_id is not None:
+                            try:
+                                payload = {
+                                    "message_id": str(message_id),
+                                    "card": temp_card,
+                                }
+                                if temp_summary:
+                                    payload["summary"] = temp_summary
+
+                                res = await client.request(
+                                    "PATCH",
+                                    config.ACTIVITY_API + "api/v1/message",
+                                    json=payload,
+                                )
+                                res.raise_for_status()
+                                logger.debug(
+                                    "updated message for draft-to-ready",
+                                    message_id=str(message_id),
+                                    mr_ref_id=mri.merge_request_ref_id,
+                                )
+                            except Exception:
+                                logger.error(
+                                    "failed to update message during draft-to-ready",
+                                    method="PATCH",
+                                    url=config.ACTIVITY_API + "api/v1/message",
+                                    message_id=str(message_id),
+                                    mr_ref_id=mri.merge_request_ref_id,
+                                    status_code=res.status_code if "res" in locals() else None,
+                                    exc_info=True,
+                                )
+
+                for row in locked_messages:
+                    message_id = row.get("message_id")
+                    if message_id is not None:
+                        await connection.execute(
+                            """INSERT INTO msg_to_delete
+                            (message_id, expire_at)
+                            VALUES
+                            ($1, now()+$2::INTERVAL)""",
+                            str(message_id),
+                            message_expiration,
+                        )
                         await connection.execute(
-                            """INSERT INTO msg_to_delete
-                            (message_id, expire_at)
-                            VALUES
-                            ($1, now()+$2::INTERVAL)""",
-                            str(message_id),
-                            message_expiration,
+                            """DELETE FROM merge_request_message_ref
+                            WHERE merge_request_message_ref_id = $1""",
+                            row.get("merge_request_message_ref_id"),
                         )
-                        await connection.execute(
-                            "DELETE FROM merge_request_message_ref WHERE merge_request_message_ref_id = $1",
-                            row.get("merge_request_message_ref_id"),
-                        )
+
+                await connection.execute(
+                    """INSERT INTO webhook_fingerprint (fingerprint, processed_at)
+                    VALUES ($1, now())
+                    ON CONFLICT (fingerprint) DO NOTHING""",
+                    payload_fingerprint,
+                )
+
+                logger.info(
+                    "marked all draft messages for deletion",
+                    mr_ref_id=mri.merge_request_ref_id,
+                    message_count=len(locked_messages),
+                    fingerprint=payload_fingerprint,
+                )
-            periodic_cleanup.reschedule()
+            need_cleanup_reschedule = True
 
     if mr.object_attributes.action in ("approved", "unapproved"):
         v = mr.user.model_dump()
@@ -265,66 +373,137 @@
         f"on {mri.merge_request_payload.project.path_with_namespace}"
     )
 
-    convtoken_to_msgrefs = await get_or_create_message_refs(
-        mri.merge_request_ref_id,
-        conversation_tokens,
-    )
+    if is_closing_action:
+        logger.info(
+            "close/merge action detected - locking and updating all messages",
+            mr_ref_id=mri.merge_request_ref_id,
+            action=mr.object_attributes.action,
+            state=mr.object_attributes.state,
+            fingerprint=payload_fingerprint,
+        )
+
+        message_expiration = datetime.timedelta(seconds=config.MESSAGE_DELETE_DELAY_SECONDS)
+        timeout = httpx.Timeout(10.0, connect=5.0)
 
-    if mr.object_attributes.action in ("open", "reopen") or True:
-        async with httpx.AsyncClient() as client:
+        async with await database.acquire() as connection:
+            async with connection.transaction():
+                locked_messages = await connection.fetch(
+                    """SELECT merge_request_message_ref_id, conversation_token, message_id
+                    FROM merge_request_message_ref
+                    WHERE merge_request_ref_id = $1
+                    FOR UPDATE""",
+                    mri.merge_request_ref_id,
+                )
+
+                if len(locked_messages) > 0:
+                    logger.info(
+                        "locked messages for update",
+                        mr_ref_id=mri.merge_request_ref_id,
+                        message_count=len(locked_messages),
+                    )
+
+                async with httpx.AsyncClient(timeout=timeout) as client:
+                    for row in locked_messages:
+                        message_id = row.get("message_id")
+                        if message_id is not None:
+                            try:
+                                payload = {
+                                    "message_id": str(message_id),
+                                    "card": card,
+                                }
+                                if summary:
+                                    payload["summary"] = summary
+
+                                res = await client.request(
+                                    "PATCH",
+                                    config.ACTIVITY_API + "api/v1/message",
+                                    json=payload,
+                                )
+                                res.raise_for_status()
+                                logger.debug(
+                                    "updated message",
+                                    message_id=str(message_id),
+                                    mr_ref_id=mri.merge_request_ref_id,
+                                )
+                            except Exception:
+                                logger.error(
+                                    "failed to update message during close/merge",
+                                    method="PATCH",
+                                    url=config.ACTIVITY_API + "api/v1/message",
+                                    message_id=str(message_id),
+                                    mr_ref_id=mri.merge_request_ref_id,
+                                    status_code=res.status_code if "res" in locals() else None,
+                                    exc_info=True,
+                                )
+
+                for row in locked_messages:
+                    message_id = row.get("message_id")
+                    if message_id is not None:
+                        await connection.execute(
+                            """INSERT INTO msg_to_delete
+                            (message_id, expire_at)
+                            VALUES
+                            ($1, now()+$2::INTERVAL)""",
+                            str(message_id),
+                            message_expiration,
+                        )
+                        await connection.execute(
+                            "DELETE FROM merge_request_message_ref WHERE merge_request_message_ref_id = $1",
+                            row.get("merge_request_message_ref_id"),
+                        )
+
+                await connection.execute(
+                    """INSERT INTO webhook_fingerprint (fingerprint, processed_at)
+                    VALUES ($1, now())
+                    ON CONFLICT (fingerprint) DO NOTHING""",
+                    payload_fingerprint,
+                )
+
+                logger.info(
+                    "marked all messages for deletion",
+                    mr_ref_id=mri.merge_request_ref_id,
+                    message_count=len(locked_messages),
+                    fingerprint=payload_fingerprint,
+                )
+        need_cleanup_reschedule = True
+    else:
+        convtoken_to_msgrefs = await get_or_create_message_refs(
+            mri.merge_request_ref_id,
+            conversation_tokens,
+        )
+
+        timeout = httpx.Timeout(10.0, connect=5.0)
+        async with httpx.AsyncClient(timeout=timeout) as client:
             for ct in conversation_tokens:
                 mrmsgref = convtoken_to_msgrefs[ct]
+                original_message_id = mrmsgref.message_id
+
                 mrmsgref.message_id = await create_or_update_message(
                     client,
                     mrmsgref,
                     card=card,
                     summary=summary,
-                    update_only=mr.object_attributes.state
-                    in (
-                        "closed",
-                        "merged",
-                    )
-                    or mr.object_attributes.draft
-                    or mr.object_attributes.work_in_progress
-                    or not participant_found,
+                    update_only=(
+                        (
+                            mr.object_attributes.action not in ("open", "reopen")
+                            and (mr.object_attributes.draft or mr.object_attributes.work_in_progress)
+                        )
+                        or not participant_found
+                    ),
                 )
 
-    if mr.object_attributes.action in (
-        "merge",
-        "close",
-    ) or mr.object_attributes.state in (
-        "closed",
-        "merged",
-    ):
-        message_expiration = datetime.timedelta(seconds=30)
-        async with await database.acquire() as connection:
-            res = await connection.fetch(
-                """SELECT merge_request_message_ref_id, message_id
-                FROM merge_request_message_ref
-                WHERE merge_request_ref_id = $1""",
-                mri.merge_request_ref_id,
-            )
-            for row in res:
-                message_id = row.get("message_id")
-                if message_id is not None:
-                    await connection.execute(
-                        """INSERT INTO msg_to_delete
-                        (message_id, expire_at)
-                        VALUES
-                        ($1, now()+$2::INTERVAL)""",
-                        str(message_id),
-                        message_expiration,
-                    )
-                    await connection.execute(
-                        "DELETE FROM merge_request_message_ref WHERE merge_request_message_ref_id = $1",
-                        row.get("merge_request_message_ref_id"),
-                    )
-            if len(res):
-                await connection.execute(
-                    "DELETE FROM merge_request_ref WHERE merge_request_ref_id = $1",
-                    mri.merge_request_ref_id,
-                )
-        periodic_cleanup.reschedule()
+                if original_message_id is None and mrmsgref.message_id is not None:
+                    async with await database.acquire() as conn:
+                        await conn.execute(
+                            """UPDATE merge_request_message_ref
+                            SET message_id = $1
+                            WHERE merge_request_message_ref_id = $2""",
+                            mrmsgref.message_id,
+                            mrmsgref.merge_request_message_ref_id,
+                        )
+
+    if need_cleanup_reschedule:
+        periodic_cleanup.reschedule()
 
     return {
         "merge_request_infos": mri,
diff --git a/webhook/pipeline.py b/webhook/pipeline.py
index 37627b6..56222cd 100644
--- a/webhook/pipeline.py
+++ b/webhook/pipeline.py
@@ -42,7 +42,8 @@ async def update_message(mri: MergeRequestInfos, conversation_tokens: list[str])
     )
 
     connection: asyncpg.Connection
-    async with await database.acquire() as connection, httpx.AsyncClient() as client:
+    timeout = httpx.Timeout(10.0, connect=5.0)
+    async with await database.acquire() as connection, httpx.AsyncClient(timeout=timeout) as client:
         res = await connection.fetch(
             """
             SELECT merge_request_message_ref_id, conversation_token, message_id