Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions daiv/notifications/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.conf import settings
from django.db import Error as DatabaseError
from django.db import IntegrityError
from django.db.models import Count, Q
from django.db.models import Count, Q, Sum
from django.db.models.signals import post_save
from django.dispatch import receiver
from django.urls import reverse
Expand Down Expand Up @@ -93,6 +93,10 @@ def _render_payload(activity: Activity) -> tuple[str, str, dict]:
"trigger_owner": owner,
"repo_id": repo,
"duration_seconds": activity.duration,
"input_tokens": activity.input_tokens,
"output_tokens": activity.output_tokens,
"total_tokens": activity.total_tokens,
"cost_usd": float(activity.cost_usd) if activity.cost_usd is not None else None,
}
return subject, body, context

Expand Down Expand Up @@ -155,6 +159,10 @@ def _handle_batch_completion(activity: Activity, siblings, total: int) -> None:
agg = siblings.aggregate(
terminal=Count("id", filter=Q(status__in=ActivityStatus.terminal())),
successful=Count("id", filter=Q(status=ActivityStatus.SUCCESSFUL)),
total_input_tokens=Sum("input_tokens"),
total_output_tokens=Sum("output_tokens"),
total_total_tokens=Sum("total_tokens"),
total_cost_usd=Sum("cost_usd"),
)
if agg["terminal"] < total:
return
Expand All @@ -175,12 +183,18 @@ def _handle_batch_completion(activity: Activity, siblings, total: int) -> None:
failed = total - successful
agg_status = ActivityStatus.SUCCESSFUL if failed == 0 else ActivityStatus.FAILED

rows = list(siblings.values_list("repo_id", "started_at", "finished_at"))
rows = list(siblings.values_list("repo_id", "started_at", "finished_at", "status"))

effective = activity.effective_notify_on
channels = [cls.channel_type for cls in enabled_channels()] if _status_matches(effective, agg_status) else []

subject, body, context = _render_batch_payload(activity, rows, total, successful, failed, agg_status)
usage = {
"input_tokens": agg["total_input_tokens"],
"output_tokens": agg["total_output_tokens"],
"total_tokens": agg["total_total_tokens"],
"cost_usd": float(agg["total_cost_usd"]) if agg["total_cost_usd"] is not None else None,
}
subject, body, context = _render_batch_payload(activity, rows, total, successful, failed, agg_status, usage)
link_url = f"{reverse('activity_list')}?batch={activity.batch_id}"

for recipient in recipients.values():
Expand Down Expand Up @@ -232,11 +246,14 @@ def _rollup_exists(recipient, batch_id) -> bool:


def _render_batch_payload(
activity: Activity, rows: list[tuple], total: int, successful: int, failed: int, agg_status: str
activity: Activity, rows: list[tuple], total: int, successful: int, failed: int, agg_status: str, usage: dict
) -> tuple[str, str, dict]:
is_schedule = _is_schedule(activity)
ok = failed == 0
repo_ids = sorted({repo for repo, _start, _end in rows if repo})
repo_ids = sorted({repo for repo, _start, _end, _status in rows if repo})
repo_results = [
{"repo": repo, "ok": status == ActivityStatus.SUCCESSFUL} for repo, _start, _end, status in rows if repo
]
name = activity.scheduled_job.name if is_schedule else ""
owner = str(activity.scheduled_job.user) if is_schedule else ""

Expand Down Expand Up @@ -277,11 +294,16 @@ def _render_batch_payload(
"trigger_owner": owner,
"repo_id": repo_ids[0] if len(repo_ids) == 1 else "",
"repo_ids": repo_ids,
"repo_results": repo_results,
"total": total,
"successful_count": successful,
"failed_count": failed,
"duration_seconds": _batch_duration(rows),
"batch_id": str(activity.batch_id),
"input_tokens": usage["input_tokens"],
"output_tokens": usage["output_tokens"],
"total_tokens": usage["total_tokens"],
"cost_usd": usage["cost_usd"],
}
return subject, body, context

Expand All @@ -297,7 +319,7 @@ def _summarize_repos(repo_ids: list[str], limit: int = 3) -> str:

def _batch_duration(rows: list[tuple]) -> float | None:
"""Wall-clock span from earliest start to latest finish across the batch."""
pairs = [(start, end) for _repo, start, end in rows if start and end]
pairs = [(start, end) for _repo, start, end, _status in rows if start and end]
if not pairs:
return None
earliest = min(start for start, _end in pairs)
Expand Down
95 changes: 95 additions & 0 deletions tests/unit_tests/notifications/test_signals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import uuid
from decimal import Decimal
from unittest.mock import patch

from django.utils import timezone
Expand Down Expand Up @@ -420,6 +421,44 @@ def test_job_rendered_subject_and_event_type(self, member_user):
assert n.context["trigger_label"]
assert n.context["repo_id"] == "acme/app"

def test_job_context_carries_token_and_cost_usage(self, member_user):
member_user.notify_on_jobs = NotifyOn.ALWAYS
member_user.save(update_fields=["notify_on_jobs"])

activity = Activity.objects.create(
trigger_type=TriggerType.UI_JOB,
user=member_user,
repo_id="acme/app",
status=ActivityStatus.SUCCESSFUL,
input_tokens=12345,
output_tokens=6789,
total_tokens=19134,
cost_usd=Decimal("0.214321"),
)
activity_finished.send(sender=Activity, activity=activity)

n = Notification.objects.get(recipient=member_user, event_type="job.finished")
assert n.context["input_tokens"] == 12345
assert n.context["output_tokens"] == 6789
assert n.context["total_tokens"] == 19134
# JSONField round-trips Decimal as float; renderers receive a number, not a string.
assert n.context["cost_usd"] == pytest.approx(0.214321)

def test_job_context_token_fields_default_to_none_when_unset(self, member_user):
member_user.notify_on_jobs = NotifyOn.ALWAYS
member_user.save(update_fields=["notify_on_jobs"])

activity = Activity.objects.create(
trigger_type=TriggerType.UI_JOB, user=member_user, repo_id="acme/app", status=ActivityStatus.SUCCESSFUL
)
activity_finished.send(sender=Activity, activity=activity)

n = Notification.objects.get(recipient=member_user, event_type="job.finished")
assert n.context["input_tokens"] is None
assert n.context["output_tokens"] is None
assert n.context["total_tokens"] is None
assert n.context["cost_usd"] is None


@pytest.mark.django_db
class TestUserBindingSeeder:
Expand Down Expand Up @@ -765,6 +804,62 @@ def test_schedule_batch_mixed_outcomes_renders_count_in_subject(self, member_use
assert schedule.name in rollup.subject
assert "2/3" in rollup.subject

def test_batch_context_carries_repo_results_and_summed_usage(self, member_user):
member_user.notify_on_jobs = NotifyOn.ALWAYS
member_user.save(update_fields=["notify_on_jobs"])

bid = uuid.uuid4()
a = Activity.objects.create(
trigger_type=TriggerType.API_JOB,
user=member_user,
repo_id="acme/api",
status=ActivityStatus.SUCCESSFUL,
batch_id=bid,
notify_on=NotifyOn.ALWAYS,
input_tokens=100,
output_tokens=200,
total_tokens=300,
cost_usd=Decimal("0.10"),
)
b = Activity.objects.create(
trigger_type=TriggerType.API_JOB,
user=member_user,
repo_id="acme/legacy",
status=ActivityStatus.FAILED,
batch_id=bid,
notify_on=NotifyOn.ALWAYS,
input_tokens=50,
output_tokens=75,
total_tokens=125,
cost_usd=Decimal("0.05"),
)
activity_finished.send(sender=Activity, activity=a)
activity_finished.send(sender=Activity, activity=b)

rollup = Notification.objects.get(recipient=member_user, event_type="job_batch.finished")
assert rollup.context["input_tokens"] == 150
assert rollup.context["output_tokens"] == 275
assert rollup.context["total_tokens"] == 425
assert rollup.context["cost_usd"] == pytest.approx(0.15)

# repo_results preserves per-repo outcome so renderers can show ✓/✗ per row.
results_by_repo = {r["repo"]: r["ok"] for r in rollup.context["repo_results"]}
assert results_by_repo == {"acme/api": True, "acme/legacy": False}

def test_batch_context_usage_totals_are_none_when_no_activity_has_usage(self, member_user):
member_user.notify_on_jobs = NotifyOn.ALWAYS
member_user.save(update_fields=["notify_on_jobs"])

a, b = self._make_batch(member_user, statuses=[ActivityStatus.SUCCESSFUL, ActivityStatus.SUCCESSFUL])
activity_finished.send(sender=Activity, activity=a)
activity_finished.send(sender=Activity, activity=b)

rollup = Notification.objects.get(event_type="job_batch.finished")
assert rollup.context["input_tokens"] is None
assert rollup.context["output_tokens"] is None
assert rollup.context["total_tokens"] is None
assert rollup.context["cost_usd"] is None

def test_empty_recipients_on_multi_job_batch_logs_warning(self, caplog):
bid = uuid.uuid4()
a = Activity.objects.create(
Expand Down