Skip to content

Commit

Permalink
Add a more generic system for handling early messages.
Browse files Browse the repository at this point in the history
  • Loading branch information
greyson-signal committed Apr 21, 2020
1 parent 2afb939 commit 83f6640
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 27 deletions.
Expand Up @@ -231,7 +231,6 @@ public class MmsDatabase extends MessagingDatabase {
private static final String OUTGOING_SECURE_MESSAGES_CLAUSE = "(" + MESSAGE_BOX + " & " + Types.BASE_TYPE_MASK + ") = " + Types.BASE_SENT_TYPE + " AND (" + MESSAGE_BOX + " & " + (Types.SECURE_MESSAGE_BIT | Types.PUSH_MESSAGE_BIT) + ")";

private final EarlyReceiptCache earlyDeliveryReceiptCache = new EarlyReceiptCache("MmsDelivery");
private final EarlyReceiptCache earlyReadReceiptCache = new EarlyReceiptCache("MmsRead");

public MmsDatabase(Context context, SQLCipherOpenHelper databaseHelper) {
super(context, databaseHelper);
Expand Down Expand Up @@ -336,7 +335,7 @@ public void removeFailure(long messageId, NetworkFailure failure) {
}
}

public void incrementReceiptCount(SyncMessageId messageId, long timestamp, boolean deliveryReceipt, boolean readReceipt) {
public boolean incrementReceiptCount(SyncMessageId messageId, long timestamp, boolean deliveryReceipt) {
SQLiteDatabase database = databaseHelper.getWritableDatabase();
Cursor cursor = null;
boolean found = false;
Expand Down Expand Up @@ -368,10 +367,12 @@ public void incrementReceiptCount(SyncMessageId messageId, long timestamp, boole
}
}

if (!found) {
if (deliveryReceipt) earlyDeliveryReceiptCache.increment(messageId.getTimetamp(), messageId.getRecipientId());
if (readReceipt) earlyReadReceiptCache.increment(messageId.getTimetamp(), messageId.getRecipientId());
if (!found && deliveryReceipt) {
earlyDeliveryReceiptCache.increment(messageId.getTimetamp(), messageId.getRecipientId());
return true;
}

return found;
} finally {
if (cursor != null)
cursor.close();
Expand Down Expand Up @@ -1097,7 +1098,6 @@ public long insertMessageOutbox(@NonNull OutgoingMediaMessage message,
}

Map<RecipientId, Long> earlyDeliveryReceipts = earlyDeliveryReceiptCache.remove(message.getSentTimeMillis());
Map<RecipientId, Long> earlyReadReceipts = earlyReadReceiptCache.remove(message.getSentTimeMillis());

ContentValues contentValues = new ContentValues();
contentValues.put(DATE_SENT, message.getSentTimeMillis());
Expand All @@ -1112,7 +1112,6 @@ public long insertMessageOutbox(@NonNull OutgoingMediaMessage message,
contentValues.put(VIEW_ONCE, message.isViewOnce());
contentValues.put(RECIPIENT_ID, message.getRecipient().getId().serialize());
contentValues.put(DELIVERY_RECEIPT_COUNT, Stream.of(earlyDeliveryReceipts.values()).mapToLong(Long::longValue).sum());
contentValues.put(READ_RECEIPT_COUNT, Stream.of(earlyReadReceipts.values()).mapToLong(Long::longValue).sum());

List<Attachment> quoteAttachments = new LinkedList<>();

Expand Down Expand Up @@ -1148,7 +1147,6 @@ public long insertMessageOutbox(@NonNull OutgoingMediaMessage message,
receiptDatabase.insert(members, messageId, defaultReceiptStatus, message.getSentTimeMillis());

for (RecipientId recipientId : earlyDeliveryReceipts.keySet()) receiptDatabase.update(recipientId, messageId, GroupReceiptDatabase.STATUS_DELIVERED, -1);
for (RecipientId recipientId : earlyReadReceipts.keySet()) receiptDatabase.update(recipientId, messageId, GroupReceiptDatabase.STATUS_READ, -1);
}

DatabaseFactory.getThreadDatabase(context).setLastSeen(threadId);
Expand Down
Expand Up @@ -269,13 +269,17 @@ public long getThreadForMessageId(long messageId) {
}

public void incrementDeliveryReceiptCount(SyncMessageId syncMessageId, long timestamp) {
DatabaseFactory.getSmsDatabase(context).incrementReceiptCount(syncMessageId, true, false);
DatabaseFactory.getMmsDatabase(context).incrementReceiptCount(syncMessageId, timestamp, true, false);
DatabaseFactory.getSmsDatabase(context).incrementReceiptCount(syncMessageId, true);
DatabaseFactory.getMmsDatabase(context).incrementReceiptCount(syncMessageId, timestamp, true);
}

public void incrementReadReceiptCount(SyncMessageId syncMessageId, long timestamp) {
DatabaseFactory.getSmsDatabase(context).incrementReceiptCount(syncMessageId, false, true);
DatabaseFactory.getMmsDatabase(context).incrementReceiptCount(syncMessageId, timestamp, false, true);
public boolean incrementReadReceiptCount(SyncMessageId syncMessageId, long timestamp) {
boolean handled = false;

handled |= DatabaseFactory.getSmsDatabase(context).incrementReceiptCount(syncMessageId, false);
handled |= DatabaseFactory.getMmsDatabase(context).incrementReceiptCount(syncMessageId, timestamp, false);

return handled;
}

public int getQuotedMessagePosition(long threadId, long quoteId, @NonNull RecipientId recipientId) {
Expand Down
Expand Up @@ -134,7 +134,6 @@ public class SmsDatabase extends MessagingDatabase {
private final String OUTGOING_SECURE_MESSAGE_CLAUSE = "(" + TYPE + " & " + Types.BASE_TYPE_MASK + ") = " + Types.BASE_SENT_TYPE + " AND (" + TYPE + " & " + (Types.SECURE_MESSAGE_BIT | Types.PUSH_MESSAGE_BIT) + ")";

private static final EarlyReceiptCache earlyDeliveryReceiptCache = new EarlyReceiptCache("SmsDelivery");
private static final EarlyReceiptCache earlyReadReceiptCache = new EarlyReceiptCache("SmsRead");

public SmsDatabase(Context context, SQLCipherOpenHelper databaseHelper) {
super(context, databaseHelper);
Expand Down Expand Up @@ -395,7 +394,7 @@ public void markAsNotified(long id) {
database.update(TABLE_NAME, contentValues, ID_WHERE, new String[] {String.valueOf(id)});
}

public void incrementReceiptCount(SyncMessageId messageId, boolean deliveryReceipt, boolean readReceipt) {
public boolean incrementReceiptCount(SyncMessageId messageId, boolean deliveryReceipt) {
SQLiteDatabase database = databaseHelper.getWritableDatabase();
Cursor cursor = null;
boolean foundMessage = false;
Expand Down Expand Up @@ -426,11 +425,12 @@ public void incrementReceiptCount(SyncMessageId messageId, boolean deliveryRecei
}
}

if (!foundMessage) {
if (deliveryReceipt) earlyDeliveryReceiptCache.increment(messageId.getTimetamp(), messageId.getRecipientId());
if (readReceipt) earlyReadReceiptCache.increment(messageId.getTimetamp(), messageId.getRecipientId());
if (!foundMessage && deliveryReceipt) {
earlyDeliveryReceiptCache.increment(messageId.getTimetamp(), messageId.getRecipientId());
return true;
}

return foundMessage;
} finally {
if (cursor != null)
cursor.close();
Expand Down Expand Up @@ -721,7 +721,6 @@ public long insertMessageOutbox(long threadId, OutgoingTextMessage message,

RecipientId recipientId = message.getRecipient().getId();
Map<RecipientId, Long> earlyDeliveryReceipts = earlyDeliveryReceiptCache.remove(date);
Map<RecipientId, Long> earlyReadReceipts = earlyReadReceiptCache.remove(date);

ContentValues contentValues = new ContentValues(6);
contentValues.put(RECIPIENT_ID, recipientId.serialize());
Expand All @@ -734,7 +733,6 @@ public long insertMessageOutbox(long threadId, OutgoingTextMessage message,
contentValues.put(SUBSCRIPTION_ID, message.getSubscriptionId());
contentValues.put(EXPIRES_IN, message.getExpiresIn());
contentValues.put(DELIVERY_RECEIPT_COUNT, Stream.of(earlyDeliveryReceipts.values()).mapToLong(Long::longValue).sum());
contentValues.put(READ_RECEIPT_COUNT, Stream.of(earlyReadReceipts.values()).mapToLong(Long::longValue).sum());

SQLiteDatabase db = databaseHelper.getWritableDatabase();
long messageId = db.insert(TABLE_NAME, null, contentValues);
Expand Down
Expand Up @@ -13,6 +13,7 @@
import org.thoughtcrime.securesms.push.SignalServiceNetworkAccess;
import org.thoughtcrime.securesms.recipients.LiveRecipientCache;
import org.thoughtcrime.securesms.service.IncomingMessageObserver;
import org.thoughtcrime.securesms.util.EarlyMessageCache;
import org.thoughtcrime.securesms.util.FeatureFlags;
import org.thoughtcrime.securesms.util.FrameRateTracker;
import org.thoughtcrime.securesms.util.IasKeyStore;
Expand Down Expand Up @@ -47,6 +48,7 @@ public class ApplicationDependencies {
private static KeyValueStore keyValueStore;
private static MegaphoneRepository megaphoneRepository;
private static GroupsV2Operations groupsV2Operations;
private static EarlyMessageCache earlyMessageCache;

public static synchronized void init(@NonNull Application application, @NonNull Provider provider) {
if (ApplicationDependencies.application != null || ApplicationDependencies.provider != null) {
Expand Down Expand Up @@ -195,6 +197,16 @@ public static synchronized void resetSignalServiceMessageReceiver() {
return megaphoneRepository;
}

public static synchronized @NonNull EarlyMessageCache getEarlyMessageCache() {
assertInitialization();

if (earlyMessageCache == null) {
earlyMessageCache = provider.provideEarlyMessageCache();
}

return earlyMessageCache;
}

private static void assertInitialization() {
if (application == null || provider == null) {
throw new UninitializedException();
Expand All @@ -214,6 +226,7 @@ public interface Provider {
@NonNull FrameRateTracker provideFrameRateTracker();
@NonNull KeyValueStore provideKeyValueStore();
@NonNull MegaphoneRepository provideMegaphoneRepository();
@NonNull EarlyMessageCache provideEarlyMessageCache();
}

private static class UninitializedException extends IllegalStateException {
Expand Down
Expand Up @@ -25,6 +25,7 @@
import org.thoughtcrime.securesms.recipients.LiveRecipientCache;
import org.thoughtcrime.securesms.service.IncomingMessageObserver;
import org.thoughtcrime.securesms.util.AlarmSleepTimer;
import org.thoughtcrime.securesms.util.EarlyMessageCache;
import org.thoughtcrime.securesms.util.FeatureFlags;
import org.thoughtcrime.securesms.util.FrameRateTracker;
import org.thoughtcrime.securesms.util.TextSecurePreferences;
Expand Down Expand Up @@ -146,6 +147,11 @@ public ApplicationDependencyProvider(@NonNull Application context, @NonNull Sign
return new MegaphoneRepository(context);
}

@Override
public @NonNull EarlyMessageCache provideEarlyMessageCache() {
return new EarlyMessageCache();
}

private static class DynamicCredentialsProvider implements CredentialsProvider {

private final Context context;
Expand Down
Expand Up @@ -237,8 +237,8 @@ public void onRun() {
Optional<Long> optionalSmsMessageId = smsMessageId > 0 ? Optional.of(smsMessageId) : Optional.absent();

if (messageState == MessageState.DECRYPTED_OK) {
//noinspection ConstantConditions
handleMessage(serializedPlaintextContent, optionalSmsMessageId);
SignalServiceContent content = SignalServiceContent.deserialize(serializedPlaintextContent);
handleMessage(content, optionalSmsMessageId);
} else {
//noinspection ConstantConditions
handleExceptionMessage(exceptionMetadata, optionalSmsMessageId);
Expand All @@ -254,10 +254,9 @@ public boolean onShouldRetry(@NonNull Exception exception) {
public void onFailure() {
}

private void handleMessage(@NonNull byte[] plaintextDataBuffer, @NonNull Optional<Long> smsMessageId) {
private void handleMessage(@Nullable SignalServiceContent content, @NonNull Optional<Long> smsMessageId) {
try {
GroupDatabase groupDatabase = DatabaseFactory.getGroupDatabase(context);
SignalServiceContent content = SignalServiceContent.deserialize(plaintextDataBuffer);
GroupDatabase groupDatabase = DatabaseFactory.getGroupDatabase(context);

if (content == null || shouldIgnore(content)) {
Log.i(TAG, "Ignoring message.");
Expand Down Expand Up @@ -333,6 +332,13 @@ private void handleMessage(@NonNull byte[] plaintextDataBuffer, @NonNull Optiona

resetRecipientToPush(Recipient.externalPush(context, content.getSender()));

Optional<SignalServiceContent> earlyContent = ApplicationDependencies.getEarlyMessageCache()
.retrieve(Recipient.externalPush(context, content.getSender()).getId(),
content.getTimestamp());
if (earlyContent.isPresent()) {
Log.i(TAG, "Found dependent content that was retrieved earlier. Processing.");
handleMessage(earlyContent.get(), Optional.absent());
}
} catch (StorageFailedException e) {
Log.w(TAG, e);
handleCorruptMessage(e.getSender(), e.getSenderDevice(), timestamp, smsMessageId);
Expand Down Expand Up @@ -625,6 +631,7 @@ private void handleReaction(@NonNull SignalServiceContent content, @NonNull Sign
Log.w(TAG, "[handleReaction] Found a matching message, but it's flagged as remotely deleted. timestamp: " + reaction.getTargetSentTimestamp() + " author: " + targetAuthor.getId());
} else {
Log.w(TAG, "[handleReaction] Could not find matching message! timestamp: " + reaction.getTargetSentTimestamp() + " author: " + targetAuthor.getId());
ApplicationDependencies.getEarlyMessageCache().store(targetAuthor.getId(), reaction.getTargetSentTimestamp(), content);
}
}

Expand All @@ -640,6 +647,7 @@ private void handleRemoteDelete(@NonNull SignalServiceContent content, @NonNull
MessageNotifier.updateNotification(context, targetMessage.getThreadId(), false);
} else if (targetMessage == null) {
Log.w(TAG, "[handleRemoteDelete] Could not find matching message! timestamp: " + delete.getTargetSentTimestamp() + " author: " + sender.getId());
ApplicationDependencies.getEarlyMessageCache().store(sender.getId(), delete.getTargetSentTimestamp(), content);
} else {
Log.w(TAG, String.format(Locale.ENGLISH, "[handleRemoteDelete] Invalid remote delete! deleteTime: %d, targetTime: %d, deleteAuthor: %s, targetAuthor: %s",
content.getServerTimestamp(), targetMessage.getServerTimestamp(), sender.getId(), targetMessage.getRecipient().getId()));
Expand Down Expand Up @@ -1339,8 +1347,14 @@ private void handleReadReceipt(@NonNull SignalServiceContent content,
for (long timestamp : message.getTimestamps()) {
Log.i(TAG, String.format("Received encrypted read receipt: (XXXXX, %d)", timestamp));

DatabaseFactory.getMmsSmsDatabase(context)
.incrementReadReceiptCount(new SyncMessageId(Recipient.externalPush(context, content.getSender()).getId(), timestamp), content.getTimestamp());
Recipient sender = Recipient.externalPush(context, content.getSender());
SyncMessageId id = new SyncMessageId(sender.getId(), timestamp);
boolean handled = DatabaseFactory.getMmsSmsDatabase(context)
.incrementReadReceiptCount(id, content.getTimestamp());

if (!handled) {
ApplicationDependencies.getEarlyMessageCache().store(sender.getId(), timestamp, content);
}
}
}
}
Expand Down
Expand Up @@ -26,12 +26,15 @@
import org.whispersystems.signalservice.api.SignalServiceMessageSender;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccessPair;
import org.whispersystems.signalservice.api.crypto.UntrustedIdentityException;
import org.whispersystems.signalservice.api.messages.SendMessageResult;
import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage;
import org.whispersystems.signalservice.api.messages.SignalServiceGroup;
import org.whispersystems.signalservice.api.messages.multidevice.SignalServiceSyncMessage;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.push.exceptions.UnregisteredUserException;

import java.io.IOException;
import java.util.List;

public class PushTextSendJob extends PushSendJob {

Expand Down
@@ -0,0 +1,60 @@
package org.thoughtcrime.securesms.util;

import androidx.annotation.NonNull;

import org.thoughtcrime.securesms.recipients.RecipientId;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.messages.SignalServiceContent;

import java.util.Objects;

/**
* Sometimes a message that is referencing another message can arrive out of order. In these cases,
* we want to temporarily hold on (i.e. keep a memory cache) to these messages and apply them after
* we receive the referenced message.
*/
public final class EarlyMessageCache {

private final LRUCache<MessageId, SignalServiceContent> cache = new LRUCache<>(100);

/**
* @param targetSender The sender of the message this message depends on.
* @param targetSentTimestamp The sent timestamp of the message this message depends on.
*/
public void store(@NonNull RecipientId targetSender, long targetSentTimestamp, @NonNull SignalServiceContent content) {
cache.put(new MessageId(targetSender, targetSentTimestamp), content);
}

/**
* Returns and removes any content that is dependent on the provided message id.
* @param sender The sender of the message in question.
* @param sentTimestamp The sent timestamp of the message in question.
*/
public Optional<SignalServiceContent> retrieve(@NonNull RecipientId sender, long sentTimestamp) {
return Optional.fromNullable(cache.remove(new MessageId(sender, sentTimestamp)));
}

private static final class MessageId {
private final RecipientId sender;
private final long sentTimestamp;

private MessageId(@NonNull RecipientId sender, long sentTimestamp) {
this.sender = sender;
this.sentTimestamp = sentTimestamp;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MessageId messageId = (MessageId) o;
return sentTimestamp == messageId.sentTimestamp &&
Objects.equals(sender, messageId.sender);
}

@Override
public int hashCode() {
return Objects.hash(sentTimestamp, sender);
}
}
}

0 comments on commit 83f6640

Please sign in to comment.