Skip to content

Commit

Permalink
worker: Rewrite MissedMessageWorker to not be lossy.
Browse files Browse the repository at this point in the history
Previously, we stored up to 2 minutes worth of email events in memory
before processing them. So, if the server were to go down we would lose
those events.

To fix this, we store the events in the database.

This is a prep change for allowing users to set a custom grace period for
email notifications, since the bug noted above would be aggravated by
longer grace periods.
  • Loading branch information
abhijeetbodas2001 committed Jul 8, 2021
1 parent 81a91a9 commit 6f19600
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 49 deletions.
116 changes: 96 additions & 20 deletions zerver/tests/test_queue_worker.py
@@ -1,4 +1,5 @@
import base64
import datetime
import os
import signal
import time
Expand All @@ -19,7 +20,14 @@
from zerver.lib.send_email import EmailNotDeliveredException, FromAddress
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import mock_queue_publish, simulated_queue_client
from zerver.models import PreregistrationUser, UserActivity, get_client, get_realm, get_stream
from zerver.models import (
MissedMessageEmailEntry,
PreregistrationUser,
UserActivity,
get_client,
get_realm,
get_stream,
)
from zerver.tornado.event_queue import build_offline_notification
from zerver.worker import queue_processors
from zerver.worker.queue_processors import (
Expand Down Expand Up @@ -144,17 +152,27 @@ def test_missed_message_worker(self) -> None:
content="where art thou, othello?",
)

events = [
dict(user_profile_id=hamlet.id, message_id=hamlet1_msg_id),
dict(user_profile_id=hamlet.id, message_id=hamlet2_msg_id),
dict(user_profile_id=othello.id, message_id=othello_msg_id),
]
hamlet_event1 = dict(
user_profile_id=hamlet.id, message_id=hamlet1_msg_id, trigger="private_message"
)
hamlet_event2 = dict(
user_profile_id=hamlet.id,
message_id=hamlet2_msg_id,
trigger="private_message",
mentioned_user_group_id=4,
)
othello_event = dict(
user_profile_id=othello.id, message_id=othello_msg_id, trigger="private_message"
)

events = [hamlet_event1, hamlet_event2, othello_event]

fake_client = self.FakeClient()
for event in events:
fake_client.enqueue("missedmessage_emails", event)

mmw = MissedMessageWorker()
batch_duration = datetime.timedelta(seconds=mmw.BATCH_DURATION)

class MockTimer:
is_running = False
Expand All @@ -174,36 +192,94 @@ def start(self) -> None:
send_mock = patch(
"zerver.lib.email_notifications.do_send_missedmessage_events_reply_in_zulip",
)
mmw.BATCH_DURATION = 0

bonus_event = dict(user_profile_id=hamlet.id, message_id=hamlet3_msg_id)
bonus_event_hamlet = dict(
user_profile_id=hamlet.id, message_id=hamlet3_msg_id, trigger="private_message"
)

def check_row(
    row: MissedMessageEmailEntry,
    expiry_time: datetime.datetime,
    mentioned_user_group_id: Optional[int],
) -> None:
    """Assert that a MissedMessageEmailEntry row stored by the worker
    has the expected trigger, expiry time, and mentioned user group.
    All events in this test are enqueued with trigger "private_message",
    so that field is checked against a fixed value."""
    self.assertEqual(row.trigger, "private_message")
    self.assertEqual(row.expiry_time, expiry_time)
    self.assertEqual(row.mentioned_user_group_id, mentioned_user_group_id)

with send_mock as sm, timer_mock as tm:
with simulated_queue_client(lambda: fake_client):
self.assertFalse(timer.is_alive())
mmw.setup()
mmw.start()

time_zero = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc)
with patch("zerver.worker.queue_processors.timezone_now", return_value=time_zero):
mmw.setup()
mmw.start()
self.assertTrue(timer.is_alive())
fake_client.enqueue("missedmessage_emails", bonus_event)

expected_expiry_time = time_zero + batch_duration

# The events should be saved in the database
hamlet_row1 = MissedMessageEmailEntry.objects.get(
user_profile_id=hamlet.id, message_id=hamlet1_msg_id
)
check_row(hamlet_row1, expected_expiry_time, None)

hamlet_row2 = MissedMessageEmailEntry.objects.get(
user_profile_id=hamlet.id, message_id=hamlet2_msg_id
)
check_row(hamlet_row2, expected_expiry_time, 4)

othello_row1 = MissedMessageEmailEntry.objects.get(
user_profile_id=othello.id, message_id=othello_msg_id
)
check_row(othello_row1, expected_expiry_time, None)

fake_client.enqueue("missedmessage_emails", bonus_event_hamlet)

# Double-calling start is our way to get it to run again
self.assertTrue(timer.is_alive())
mmw.start()
with self.assertLogs(level="INFO") as info_logs:
# Now, we actually send the emails.
with patch("zerver.worker.queue_processors.timezone_now", return_value=time_zero):
mmw.start()

# Check that Hamlet got a new row
hamlet_row3 = MissedMessageEmailEntry.objects.get(
user_profile_id=hamlet.id, message_id=hamlet3_msg_id
)
check_row(hamlet_row3, expected_expiry_time, None)

# If `maybe_send_batched_emails` is called too early, it shouldn't process batches.
one_minute_premature = expected_expiry_time - datetime.timedelta(seconds=60)
with patch(
"zerver.worker.queue_processors.timezone_now", return_value=one_minute_premature
):
mmw.maybe_send_batched_emails()
self.assertEqual(
self.assertEqual(MissedMessageEmailEntry.objects.count(), 4)

# This should process all batches.
one_minute_overdue = expected_expiry_time + datetime.timedelta(seconds=60)
with self.assertLogs(level="INFO") as info_logs, patch(
"zerver.worker.queue_processors.timezone_now", return_value=one_minute_overdue
):
mmw.maybe_send_batched_emails()
self.assertEqual(MissedMessageEmailEntry.objects.count(), 0)

self.assert_length(info_logs.output, 2)
self.assertIn(
"INFO:root:Batch-processing 3 missedmessage_emails events for user 10",
info_logs.output,
)
self.assertIn(
"INFO:root:Batch-processing 1 missedmessage_emails events for user 12",
info_logs.output,
[
"INFO:root:Batch-processing 3 missedmessage_emails events for user 10",
"INFO:root:Batch-processing 1 missedmessage_emails events for user 12",
],
)

# All batches got processed. Test that the timer isn't running.
self.assertEqual(mmw.timer_event, None)

self.assertEqual(tm.call_args[0][0], 5) # should sleep 5 seconds
# Check that the frequency of calling maybe_send_batched_emails is correct (5 seconds)
self.assertEqual(tm.call_args[0][0], 5)

# Verify the payloads now
args = [c[0] for c in sm.call_args_list]
arg_dict = {
arg[0].id: dict(
Expand Down
77 changes: 48 additions & 29 deletions zerver/worker/queue_processors.py
Expand Up @@ -13,7 +13,7 @@
import time
import urllib
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections import deque
from email.message import EmailMessage
from functools import wraps
from threading import Lock, Timer
Expand All @@ -37,7 +37,7 @@
import sentry_sdk
from django.conf import settings
from django.core.mail.backends.smtp import EmailBackend
from django.db import connection
from django.db import connection, transaction
from django.db.models import F
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _
Expand Down Expand Up @@ -90,6 +90,7 @@
from zerver.lib.url_preview import preview as url_preview
from zerver.models import (
Message,
MissedMessageEmailEntry,
PreregistrationUser,
Realm,
RealmAuditLog,
Expand Down Expand Up @@ -546,17 +547,9 @@ class MissedMessageWorker(QueueProcessingWorker):
#
# The timer is running whenever; we poll at most every TIMER_FREQUENCY
# seconds, to avoid excessive activity.
#
# TODO: Since this process keeps events in memory for up to 2
# minutes, it now will lose approximately BATCH_DURATION worth of
# missed_message emails whenever it is restarted as part of a
# server restart. We should probably add some sort of save/reload
# mechanism for that case.
TIMER_FREQUENCY = 5
BATCH_DURATION = 120
timer_event: Optional[Timer] = None
events_by_recipient: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
batch_start_by_recipient: Dict[int, float] = {}

# This lock protects access to all of the data structures declared
# above. A lock is required because maybe_send_batched_emails, as
Expand All @@ -575,11 +568,18 @@ def consume(self, event: Dict[str, Any]) -> None:
logging.debug("Received missedmessage_emails event: %s", event)

# When we process an event, just put it into the queue and ensure we have a timer going.
user_profile_id = event["user_profile_id"]
if user_profile_id not in self.batch_start_by_recipient:
self.batch_start_by_recipient[user_profile_id] = time.time()
self.events_by_recipient[user_profile_id].append(event)

user_profile_id: int = event["user_profile_id"]
batch_duration = datetime.timedelta(seconds=self.BATCH_DURATION)

entry = MissedMessageEmailEntry(
user_profile_id=user_profile_id,
message_id=event["message_id"],
trigger=event["trigger"],
expiry_time=timezone_now() + batch_duration,
)
if "mentioned_user_group_id" in event:
entry.mentioned_user_group_id = event["mentioned_user_group_id"]
entry.save()
self.ensure_timer()

def ensure_timer(self) -> None:
Expand All @@ -600,25 +600,44 @@ def maybe_send_batched_emails(self) -> None:
# is active.
self.timer_event = None

current_time = time.time()
for user_profile_id, timestamp in list(self.batch_start_by_recipient.items()):
if current_time - timestamp < self.BATCH_DURATION:
continue
events = self.events_by_recipient[user_profile_id]
logging.info(
"Batch-processing %s missedmessage_emails events for user %s",
len(events),
user_profile_id,
)
handle_missedmessage_emails(user_profile_id, events)
del self.events_by_recipient[user_profile_id]
del self.batch_start_by_recipient[user_profile_id]
current_time = timezone_now()

with transaction.atomic():
events_to_process = MissedMessageEmailEntry.objects.filter(
expiry_time__lte=current_time
).select_related()

# Batch the entries by user
events_by_recipient: Dict[int, List[Dict[str, Any]]] = {}
for event in events_to_process:
entry = dict(
user_profile_id=event.user_profile_id,
message_id=event.message_id,
trigger=event.trigger,
mentioned_user_group_id=event.mentioned_user_group_id,
)
if event.user_profile_id in events_by_recipient:
events_by_recipient[event.user_profile_id].append(entry)
else:
events_by_recipient[event.user_profile_id] = [entry]

for user_profile_id in events_by_recipient.keys():
events: List[Dict[str, Any]] = events_by_recipient[user_profile_id]

logging.info(
"Batch-processing %s missedmessage_emails events for user %s",
len(events),
user_profile_id,
)
handle_missedmessage_emails(user_profile_id, events)

events_to_process.delete()

# By only restarting the timer if there are actually events in
# the queue, we ensure this queue processor is idle when there
# are no missed-message emails to process. This avoids
# constant CPU usage when there is no work to do.
if len(self.batch_start_by_recipient) > 0:
if MissedMessageEmailEntry.objects.exists():
self.ensure_timer()


Expand Down

0 comments on commit 6f19600

Please sign in to comment.