Skip to content

Commit

Permalink
Cancel typing jobs when you send a group message.
Browse files Browse the repository at this point in the history
  • Loading branch information
greyson-signal committed Jun 12, 2020
1 parent 8891b6c commit 3fad007
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 19 deletions.
Expand Up @@ -46,6 +46,7 @@
import org.whispersystems.signalservice.api.websocket.ConnectivityListener;

import java.util.UUID;
import java.util.concurrent.Executors;

/**
* Implementation of {@link ApplicationDependencies.Provider} that provides real app dependencies.
Expand Down Expand Up @@ -91,7 +92,7 @@ public ApplicationDependencyProvider(@NonNull Application context, @NonNull Sign
Optional.fromNullable(IncomingMessageObserver.getUnidentifiedPipe()),
Optional.of(new SecurityEventListener(context)),
provideClientZkOperations().getProfileOperations(),
SignalExecutors.UNBOUNDED);
SignalExecutors.newCachedBoundedExecutor("signal-messages", 1, 16));
}

@Override
Expand Down
Expand Up @@ -23,6 +23,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
Expand Down Expand Up @@ -150,6 +151,14 @@ synchronized void cancelJob(@NonNull String id) {
}
}

@WorkerThread
synchronized void cancelAllInQueue(@NonNull String queue) {
Stream.of(runningJobs.values())
.filter(j -> Objects.equals(j.getParameters().getQueue(), queue))
.map(Job::getId)
.forEach(this::cancelJob);
}

@WorkerThread
synchronized void onRetry(@NonNull Job job) {
int nextRunAttempt = job.getRunAttempt() + 1;
Expand Down
Expand Up @@ -198,6 +198,13 @@ public void cancel(@NonNull String id) {
executor.execute(() -> jobController.cancelJob(id));
}

/**
* Cancels all jobs in the specified queue. See {@link #cancel(String)} for details.
*/
public void cancelAllInQueue(@NonNull String queue) {
executor.execute(() -> jobController.cancelAllInQueue(queue));
}

/**
* Runs the specified job synchronously. Beware: All normal dependencies are respected, meaning
* you must take great care where you call this. It could take a very long time to complete!
Expand Down
Expand Up @@ -151,6 +151,9 @@ public void onPushSend()
List<NetworkFailure> existingNetworkFailures = message.getNetworkFailures();
List<IdentityKeyMismatch> existingIdentityMismatches = message.getIdentityKeyMismatches();

long threadId = DatabaseFactory.getThreadDatabase(context).getThreadIdFor(message.getRecipient());
ApplicationDependencies.getJobManager().cancelAllInQueue(TypingSendJob.getQueue(threadId));

if (database.isSent(messageId)) {
log(TAG, "Message " + messageId + " was already sent. Ignoring.");
return;
Expand Down
Expand Up @@ -15,6 +15,7 @@
import org.thoughtcrime.securesms.recipients.RecipientUtil;
import org.thoughtcrime.securesms.util.TextSecurePreferences;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.CancelationException;
import org.whispersystems.signalservice.api.SignalServiceMessageSender;
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccessPair;
import org.whispersystems.signalservice.api.messages.SignalServiceTypingMessage;
Expand All @@ -39,14 +40,18 @@ public class TypingSendJob extends BaseJob {

public TypingSendJob(long threadId, boolean typing) {
this(new Job.Parameters.Builder()
.setQueue("TYPING_" + threadId)
.setQueue(getQueue(threadId))
.setMaxAttempts(1)
.setLifespan(TimeUnit.SECONDS.toMillis(5))
.build(),
threadId,
typing);
}

public static String getQueue(long threadId) {
return "TYPING_" + threadId;
}

private TypingSendJob(@NonNull Job.Parameters parameters, long threadId, boolean typing) {
super(parameters);

Expand Down Expand Up @@ -101,7 +106,16 @@ public void onRun() throws Exception {
List<Optional<UnidentifiedAccessPair>> unidentifiedAccess = Stream.of(recipients).map(r -> UnidentifiedAccessUtil.getAccessFor(context, r)).toList();
SignalServiceTypingMessage typingMessage = new SignalServiceTypingMessage(typing ? Action.STARTED : Action.STOPPED, System.currentTimeMillis(), groupId);

messageSender.sendTyping(addresses, unidentifiedAccess, typingMessage);
if (isCanceled()) {
Log.w(TAG, "Canceled before send!");
return;
}

try {
messageSender.sendTyping(addresses, unidentifiedAccess, typingMessage, this::isCanceled);
} catch (CancelationException e) {
Log.w(TAG, "Canceled during send!");
}
}

@Override
Expand Down
Expand Up @@ -6,9 +6,13 @@

import org.thoughtcrime.securesms.util.LinkedBlockingLifoQueue;

import java.util.Queue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
Expand All @@ -26,6 +30,44 @@ public static ExecutorService newCachedSingleThreadExecutor(final String name) {
return executor;
}

/**
* ThreadPoolExecutor will only create a new thread if the provided queue returns false from
* offer(). That means if you give it an unbounded queue, it'll only ever create 1 thread, no
* matter how long the queue gets.
*
* But if you bound the queue and submit more runnables than there are threads, your task is
* rejected and throws an exception.
*
* So we make a queue that will always return false if it's non-empty to ensure new threads get
* created. Then, if a task gets rejected, we simply add it to the queue.
*/
public static ExecutorService newCachedBoundedExecutor(final String name, int minThreads, int maxThreads) {
ThreadPoolExecutor threadPool = new ThreadPoolExecutor(minThreads,
maxThreads,
30,
TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>() {
@Override
public boolean offer(Runnable runnable) {
if (isEmpty()) {
return super.offer(runnable);
} else {
return false;
}
}
}, new NumberedThreadFactory(name));

threadPool.setRejectedExecutionHandler((runnable, executor) -> {
try {
executor.getQueue().put(runnable);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});

return threadPool;
}

/**
* Returns an executor that prioritizes newer work. This is the opposite of a traditional executor,
* which processor work in FIFO order.
Expand Down
@@ -0,0 +1,6 @@
package org.whispersystems.signalservice.api;

import java.io.IOException;

public class CancelationException extends IOException {
}
Expand Up @@ -80,6 +80,7 @@
import org.whispersystems.signalservice.internal.push.exceptions.MismatchedDevicesException;
import org.whispersystems.signalservice.internal.push.exceptions.StaleDevicesException;
import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory;
import org.whispersystems.signalservice.internal.push.http.CancelationSignal;
import org.whispersystems.signalservice.internal.push.http.ResumableUploadSpec;
import org.whispersystems.signalservice.internal.util.StaticCredentialsProvider;
import org.whispersystems.signalservice.internal.util.Util;
Expand Down Expand Up @@ -191,7 +192,7 @@ public void sendReceipt(SignalServiceAddress recipient,
{
byte[] content = createReceiptContent(message);

sendMessage(recipient, getTargetUnidentifiedAccess(unidentifiedAccess), message.getWhen(), content, false);
sendMessage(recipient, getTargetUnidentifiedAccess(unidentifiedAccess), message.getWhen(), content, false, null);
}

/**
Expand All @@ -209,16 +210,17 @@ public void sendTyping(SignalServiceAddress recipient,
{
byte[] content = createTypingContent(message);

sendMessage(recipient, getTargetUnidentifiedAccess(unidentifiedAccess), message.getTimestamp(), content, true);
sendMessage(recipient, getTargetUnidentifiedAccess(unidentifiedAccess), message.getTimestamp(), content, true, null);
}

public void sendTyping(List<SignalServiceAddress> recipients,
List<Optional<UnidentifiedAccessPair>> unidentifiedAccess,
SignalServiceTypingMessage message)
SignalServiceTypingMessage message,
CancelationSignal cancelationSignal)
throws IOException
{
byte[] content = createTypingContent(message);
sendMessage(recipients, getTargetUnidentifiedAccess(unidentifiedAccess), message.getTimestamp(), content, true);
sendMessage(recipients, getTargetUnidentifiedAccess(unidentifiedAccess), message.getTimestamp(), content, true, cancelationSignal);
}


Expand All @@ -235,7 +237,7 @@ public void sendCallMessage(SignalServiceAddress recipient,
throws IOException, UntrustedIdentityException
{
byte[] content = createCallContent(message);
sendMessage(recipient, getTargetUnidentifiedAccess(unidentifiedAccess), System.currentTimeMillis(), content, false);
sendMessage(recipient, getTargetUnidentifiedAccess(unidentifiedAccess), System.currentTimeMillis(), content, false, null);
}

/**
Expand All @@ -253,11 +255,11 @@ public SendMessageResult sendMessage(SignalServiceAddress recipient,
{
byte[] content = createMessageContent(message);
long timestamp = message.getTimestamp();
SendMessageResult result = sendMessage(recipient, getTargetUnidentifiedAccess(unidentifiedAccess), timestamp, content, false);
SendMessageResult result = sendMessage(recipient, getTargetUnidentifiedAccess(unidentifiedAccess), timestamp, content, false, null);

if (result.getSuccess() != null && result.getSuccess().isNeedsSync()) {
byte[] syncMessage = createMultiDeviceSentTranscriptContent(content, Optional.of(recipient), timestamp, Collections.singletonList(result), false);
sendMessage(localAddress, Optional.<UnidentifiedAccess>absent(), timestamp, syncMessage, false);
sendMessage(localAddress, Optional.<UnidentifiedAccess>absent(), timestamp, syncMessage, false, null);
}

if (message.isEndSession()) {
Expand Down Expand Up @@ -291,7 +293,7 @@ public List<SendMessageResult> sendMessage(List<SignalServiceAddress>
{
byte[] content = createMessageContent(message);
long timestamp = message.getTimestamp();
List<SendMessageResult> results = sendMessage(recipients, getTargetUnidentifiedAccess(unidentifiedAccess), timestamp, content, false);
List<SendMessageResult> results = sendMessage(recipients, getTargetUnidentifiedAccess(unidentifiedAccess), timestamp, content, false, null);
boolean needsSyncInResults = false;

for (SendMessageResult result : results) {
Expand All @@ -303,7 +305,7 @@ public List<SendMessageResult> sendMessage(List<SignalServiceAddress>

if (needsSyncInResults || isMultiDevice.get()) {
byte[] syncMessage = createMultiDeviceSentTranscriptContent(content, Optional.<SignalServiceAddress>absent(), timestamp, results, isRecipientUpdate);
sendMessage(localAddress, Optional.<UnidentifiedAccess>absent(), timestamp, syncMessage, false);
sendMessage(localAddress, Optional.<UnidentifiedAccess>absent(), timestamp, syncMessage, false, null);
}

return results;
Expand Down Expand Up @@ -347,7 +349,7 @@ public void sendMessage(SignalServiceSyncMessage message, Optional<UnidentifiedA
long timestamp = message.getSent().isPresent() ? message.getSent().get().getTimestamp()
: System.currentTimeMillis();

sendMessage(localAddress, Optional.<UnidentifiedAccess>absent(), timestamp, content, false);
sendMessage(localAddress, Optional.<UnidentifiedAccess>absent(), timestamp, content, false, null);
}

public void setSoTimeoutMillis(long soTimeoutMillis) {
Expand Down Expand Up @@ -478,11 +480,11 @@ private void sendMessage(VerifiedMessage message, Optional<UnidentifiedAccessPai
.build()
.toByteArray();

SendMessageResult result = sendMessage(message.getDestination(), getTargetUnidentifiedAccess(unidentifiedAccess), message.getTimestamp(), content, false);
SendMessageResult result = sendMessage(message.getDestination(), getTargetUnidentifiedAccess(unidentifiedAccess), message.getTimestamp(), content, false, null);

if (result.getSuccess().isNeedsSync()) {
byte[] syncMessage = createMultiDeviceVerifiedContent(message, nullMessage.toByteArray());
sendMessage(localAddress, Optional.<UnidentifiedAccess>absent(), message.getTimestamp(), syncMessage, false);
sendMessage(localAddress, Optional.<UnidentifiedAccess>absent(), message.getTimestamp(), syncMessage, false, null);
}
}

Expand Down Expand Up @@ -1200,7 +1202,8 @@ private List<SendMessageResult> sendMessage(List<SignalServiceAddress> r
List<Optional<UnidentifiedAccess>> unidentifiedAccess,
long timestamp,
byte[] content,
boolean online)
boolean online,
CancelationSignal cancelationSignal)
throws IOException
{
long startTime = System.currentTimeMillis();
Expand All @@ -1211,7 +1214,7 @@ private List<SendMessageResult> sendMessage(List<SignalServiceAddress> r
while (recipientIterator.hasNext()) {
SignalServiceAddress recipient = recipientIterator.next();
Optional<UnidentifiedAccess> access = unidentifiedAccessIterator.next();
futureResults.add(executor.submit(() -> sendMessage(recipient, access, timestamp, content, online)));
futureResults.add(executor.submit(() -> sendMessage(recipient, access, timestamp, content, online, cancelationSignal)));
}

List<SendMessageResult> results = new ArrayList<>(futureResults.size());
Expand Down Expand Up @@ -1247,14 +1250,24 @@ private SendMessageResult sendMessage(SignalServiceAddress recipient,
Optional<UnidentifiedAccess> unidentifiedAccess,
long timestamp,
byte[] content,
boolean online)
boolean online,
CancelationSignal cancelationSignal)
throws UntrustedIdentityException, IOException
{
long startTime = System.currentTimeMillis();

for (int i = 0; i < RETRY_COUNT; i++) {
if (cancelationSignal != null && cancelationSignal.isCanceled()) {
throw new CancelationException();
}

try {
OutgoingPushMessageList messages = getEncryptedMessages(socket, recipient, unidentifiedAccess, timestamp, content, online);
OutgoingPushMessageList messages = getEncryptedMessages(socket, recipient, unidentifiedAccess, timestamp, content, online);

if (cancelationSignal != null && cancelationSignal.isCanceled()) {
throw new CancelationException();
}

Optional<SignalServiceMessagePipe> pipe = this.pipe.get();
Optional<SignalServiceMessagePipe> unidentifiedPipe = this.unidentifiedPipe.get();

Expand All @@ -1278,6 +1291,10 @@ private SendMessageResult sendMessage(SignalServiceAddress recipient,
}
}

if (cancelationSignal != null && cancelationSignal.isCanceled()) {
throw new CancelationException();
}

SendMessageResponse response = socket.sendMessage(messages, unidentifiedAccess);

Log.d(TAG, "[sendMessage] Completed over REST in " + (System.currentTimeMillis() - startTime) + " ms and " + (i + 1) + " attempt(s)");
Expand Down

0 comments on commit 3fad007

Please sign in to comment.