From 47a15c22af0e494dbb35163342b90b928dda1b1f Mon Sep 17 00:00:00 2001 From: Sandro Date: Wed, 20 May 2026 13:10:31 +0100 Subject: [PATCH] refactor(notifications): Plumb token + cost usage into notification context Extends the context dicts produced by _render_payload (single job / schedule) and _render_batch_payload (multi-job rollup) with token and cost data sourced from Activity.input_tokens / output_tokens / total_tokens / cost_usd. For batches the four values are aggregated via SUM across siblings, and rows now also carry per-activity status so the batch context can include a `repo_results` list (one {repo, ok} entry per sibling) for renderers to show per-repo outcome. Bell, email, and in-app behavior are unchanged: existing consumers ignore the new keys, and tests covering the established context fields still pass. Groundwork for the upcoming Rocket Chat renderer registry, which will surface usage as channel fields. --- daiv/notifications/signals.py | 34 +++++-- .../unit_tests/notifications/test_signals.py | 95 +++++++++++++++++++ 2 files changed, 123 insertions(+), 6 deletions(-) diff --git a/daiv/notifications/signals.py b/daiv/notifications/signals.py index 1db2ce8c..5548227c 100644 --- a/daiv/notifications/signals.py +++ b/daiv/notifications/signals.py @@ -5,7 +5,7 @@ from django.conf import settings from django.db import Error as DatabaseError from django.db import IntegrityError -from django.db.models import Count, Q +from django.db.models import Count, Q, Sum from django.db.models.signals import post_save from django.dispatch import receiver from django.urls import reverse @@ -93,6 +93,10 @@ def _render_payload(activity: Activity) -> tuple[str, str, dict]: "trigger_owner": owner, "repo_id": repo, "duration_seconds": activity.duration, + "input_tokens": activity.input_tokens, + "output_tokens": activity.output_tokens, + "total_tokens": activity.total_tokens, + "cost_usd": float(activity.cost_usd) if activity.cost_usd is not None else None, } return subject, body, context @@ -155,6 +159,10 @@ def _handle_batch_completion(activity: Activity, siblings, total: int) -> None: agg = siblings.aggregate( terminal=Count("id", filter=Q(status__in=ActivityStatus.terminal())), successful=Count("id", filter=Q(status=ActivityStatus.SUCCESSFUL)), + total_input_tokens=Sum("input_tokens"), + total_output_tokens=Sum("output_tokens"), + total_total_tokens=Sum("total_tokens"), + total_cost_usd=Sum("cost_usd"), ) if agg["terminal"] < total: return @@ -175,12 +183,18 @@ def _handle_batch_completion(activity: Activity, siblings, total: int) -> None: failed = total - successful agg_status = ActivityStatus.SUCCESSFUL if failed == 0 else ActivityStatus.FAILED - rows = list(siblings.values_list("repo_id", "started_at", "finished_at")) + rows = list(siblings.values_list("repo_id", "started_at", "finished_at", "status")) effective = activity.effective_notify_on channels = [cls.channel_type for cls in enabled_channels()] if _status_matches(effective, agg_status) else [] - subject, body, context = _render_batch_payload(activity, rows, total, successful, failed, agg_status) + usage = { + "input_tokens": agg["total_input_tokens"], + "output_tokens": agg["total_output_tokens"], + "total_tokens": agg["total_total_tokens"], + "cost_usd": float(agg["total_cost_usd"]) if agg["total_cost_usd"] is not None else None, + } + subject, body, context = _render_batch_payload(activity, rows, total, successful, failed, agg_status, usage) link_url = f"{reverse('activity_list')}?batch={activity.batch_id}" for recipient in recipients.values(): @@ -232,11 +246,14 @@ def _rollup_exists(recipient, batch_id) -> bool: def _render_batch_payload( - activity: Activity, rows: list[tuple], total: int, successful: int, failed: int, agg_status: str + activity: Activity, rows: list[tuple], total: int, successful: int, failed: int, agg_status: str, usage: dict ) -> tuple[str, str, dict]: is_schedule = _is_schedule(activity) ok = failed == 0 - repo_ids = sorted({repo for repo, _start, _end in rows if repo}) + repo_ids = sorted({repo for repo, _start, _end, _status in rows if repo}) + repo_results = [ + {"repo": repo, "ok": status == ActivityStatus.SUCCESSFUL} for repo, _start, _end, status in rows if repo + ] name = activity.scheduled_job.name if is_schedule else "" owner = str(activity.scheduled_job.user) if is_schedule else "" @@ -277,11 +294,16 @@ def _render_batch_payload( "trigger_owner": owner, "repo_id": repo_ids[0] if len(repo_ids) == 1 else "", "repo_ids": repo_ids, + "repo_results": repo_results, "total": total, "successful_count": successful, "failed_count": failed, "duration_seconds": _batch_duration(rows), "batch_id": str(activity.batch_id), + "input_tokens": usage["input_tokens"], + "output_tokens": usage["output_tokens"], + "total_tokens": usage["total_tokens"], + "cost_usd": usage["cost_usd"], } return subject, body, context @@ -297,7 +319,7 @@ def _summarize_repos(repo_ids: list[str], limit: int = 3) -> str: def _batch_duration(rows: list[tuple]) -> float | None: """Wall-clock span from earliest start to latest finish across the batch.""" - pairs = [(start, end) for _repo, start, end in rows if start and end] + pairs = [(start, end) for _repo, start, end, _status in rows if start and end] if not pairs: return None earliest = min(start for start, _end in pairs) diff --git a/tests/unit_tests/notifications/test_signals.py b/tests/unit_tests/notifications/test_signals.py index 96b282e2..28e23086 100644 --- a/tests/unit_tests/notifications/test_signals.py +++ b/tests/unit_tests/notifications/test_signals.py @@ -1,5 +1,6 @@ import logging import uuid +from decimal import Decimal from unittest.mock import patch from django.utils import timezone @@ -420,6 +421,44 @@ def test_job_rendered_subject_and_event_type(self, member_user): assert n.context["trigger_label"] assert n.context["repo_id"] == "acme/app" + def test_job_context_carries_token_and_cost_usage(self, member_user): + member_user.notify_on_jobs = NotifyOn.ALWAYS + member_user.save(update_fields=["notify_on_jobs"]) + + activity = Activity.objects.create( + trigger_type=TriggerType.UI_JOB, + user=member_user, + repo_id="acme/app", + status=ActivityStatus.SUCCESSFUL, + input_tokens=12345, + output_tokens=6789, + total_tokens=19134, + cost_usd=Decimal("0.214321"), + ) + activity_finished.send(sender=Activity, activity=activity) + + n = Notification.objects.get(recipient=member_user, event_type="job.finished") + assert n.context["input_tokens"] == 12345 + assert n.context["output_tokens"] == 6789 + assert n.context["total_tokens"] == 19134 + # JSONField round-trips Decimal as float; renderers receive a number, not a string. + assert n.context["cost_usd"] == pytest.approx(0.214321) + + def test_job_context_token_fields_default_to_none_when_unset(self, member_user): + member_user.notify_on_jobs = NotifyOn.ALWAYS + member_user.save(update_fields=["notify_on_jobs"]) + + activity = Activity.objects.create( + trigger_type=TriggerType.UI_JOB, user=member_user, repo_id="acme/app", status=ActivityStatus.SUCCESSFUL + ) + activity_finished.send(sender=Activity, activity=activity) + + n = Notification.objects.get(recipient=member_user, event_type="job.finished") + assert n.context["input_tokens"] is None + assert n.context["output_tokens"] is None + assert n.context["total_tokens"] is None + assert n.context["cost_usd"] is None + @pytest.mark.django_db class TestUserBindingSeeder: @@ -765,6 +804,62 @@ def test_schedule_batch_mixed_outcomes_renders_count_in_subject(self, member_use assert schedule.name in rollup.subject assert "2/3" in rollup.subject + def test_batch_context_carries_repo_results_and_summed_usage(self, member_user): + member_user.notify_on_jobs = NotifyOn.ALWAYS + member_user.save(update_fields=["notify_on_jobs"]) + + bid = uuid.uuid4() + a = Activity.objects.create( + trigger_type=TriggerType.API_JOB, + user=member_user, + repo_id="acme/api", + status=ActivityStatus.SUCCESSFUL, + batch_id=bid, + notify_on=NotifyOn.ALWAYS, + input_tokens=100, + output_tokens=200, + total_tokens=300, + cost_usd=Decimal("0.10"), + ) + b = Activity.objects.create( + trigger_type=TriggerType.API_JOB, + user=member_user, + repo_id="acme/legacy", + status=ActivityStatus.FAILED, + batch_id=bid, + notify_on=NotifyOn.ALWAYS, + input_tokens=50, + output_tokens=75, + total_tokens=125, + cost_usd=Decimal("0.05"), + ) + activity_finished.send(sender=Activity, activity=a) + activity_finished.send(sender=Activity, activity=b) + + rollup = Notification.objects.get(recipient=member_user, event_type="job_batch.finished") + assert rollup.context["input_tokens"] == 150 + assert rollup.context["output_tokens"] == 275 + assert rollup.context["total_tokens"] == 425 + assert rollup.context["cost_usd"] == pytest.approx(0.15) + + # repo_results preserves per-repo outcome so renderers can show ✓/✗ per row. + results_by_repo = {r["repo"]: r["ok"] for r in rollup.context["repo_results"]} + assert results_by_repo == {"acme/api": True, "acme/legacy": False} + + def test_batch_context_usage_totals_are_none_when_no_activity_has_usage(self, member_user): + member_user.notify_on_jobs = NotifyOn.ALWAYS + member_user.save(update_fields=["notify_on_jobs"]) + + a, b = self._make_batch(member_user, statuses=[ActivityStatus.SUCCESSFUL, ActivityStatus.SUCCESSFUL]) + activity_finished.send(sender=Activity, activity=a) + activity_finished.send(sender=Activity, activity=b) + + rollup = Notification.objects.get(event_type="job_batch.finished") + assert rollup.context["input_tokens"] is None + assert rollup.context["output_tokens"] is None + assert rollup.context["total_tokens"] is None + assert rollup.context["cost_usd"] is None + def test_empty_recipients_on_multi_job_batch_logs_warning(self, caplog): bid = uuid.uuid4() a = Activity.objects.create(