Skip to content

Commit

Permalink
Make PreKeyWhisperMessage decrypt more reliably atomic.
Browse files Browse the repository at this point in the history
  • Loading branch information
moxie0 committed Oct 20, 2014
1 parent 1eb3884 commit c330eef
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,35 +88,32 @@ public SessionBuilder(SessionStore sessionStore,
* @throws org.whispersystems.libaxolotl.InvalidKeyException when the message is formatted incorrectly.
* @throws org.whispersystems.libaxolotl.UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted.
*/
/*package*/ boolean process(PreKeyWhisperMessage message)
/*package*/ void process(SessionRecord sessionRecord, PreKeyWhisperMessage message)
throws InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException
{
int messageVersion = message.getMessageVersion();
IdentityKey theirIdentityKey = message.getIdentityKey();

boolean createdSession;

if (!identityKeyStore.isTrustedIdentity(recipientId, theirIdentityKey)) {
throw new UntrustedIdentityException();
}

if (messageVersion == 2) createdSession = processV2(message);
else if (messageVersion == 3) createdSession = processV3(message);
else throw new AssertionError("Unknown version: " + messageVersion);
switch (messageVersion) {
case 2: processV2(sessionRecord, message); break;
case 3: processV3(sessionRecord, message); break;
default: throw new AssertionError("Unknown version: " + messageVersion);
}

identityKeyStore.saveIdentity(recipientId, theirIdentityKey);

return createdSession;
}

private boolean processV3(PreKeyWhisperMessage message)
private void processV3(SessionRecord sessionRecord, PreKeyWhisperMessage message)
throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException
{
SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);

if (sessionRecord.hasSessionState(message.getMessageVersion(), message.getBaseKey().serialize())) {
Log.w(TAG, "We've already setup a session for this V3 message, letting bundled message fall through...");
return false;
return;
}

boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage();
Expand Down Expand Up @@ -147,27 +144,22 @@ private boolean processV3(PreKeyWhisperMessage message)

if (simultaneousInitiate) sessionRecord.getSessionState().setNeedsRefresh(true);

sessionStore.storeSession(recipientId, deviceId, sessionRecord);

if (message.getPreKeyId() >= 0 && message.getPreKeyId() != Medium.MAX_VALUE) {
preKeyStore.removePreKey(message.getPreKeyId());
}

return true;
}

private boolean processV2(PreKeyWhisperMessage message)
private void processV2(SessionRecord sessionRecord, PreKeyWhisperMessage message)
throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException
{

if (!preKeyStore.containsPreKey(message.getPreKeyId()) &&
sessionStore.containsSession(recipientId, deviceId))
{
Log.w(TAG, "We've already processed the prekey part of this V2 session, letting bundled message fall through...");
return false;
return;
}

SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
ECKeyPair ourPreKey = preKeyStore.loadPreKey(message.getPreKeyId()).getKeyPair();
boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage();

Expand All @@ -193,10 +185,6 @@ private boolean processV2(PreKeyWhisperMessage message)
if (message.getPreKeyId() != Medium.MAX_VALUE) {
preKeyStore.removePreKey(message.getPreKeyId());
}

sessionStore.storeSession(recipientId, deviceId, sessionRecord);

return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,13 @@ public byte[] decrypt(PreKeyWhisperMessage ciphertext)
InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException, NoSessionException
{
synchronized (SESSION_LOCK) {
boolean sessionCreated = sessionBuilder.process(ciphertext);
SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);

try {
return decrypt(ciphertext.getWhisperMessage());
} catch (InvalidMessageException | DuplicateMessageException | LegacyMessageException e) {
if (sessionCreated) {
sessionStore.deleteSession(recipientId, deviceId);
}
sessionBuilder.process(sessionRecord, ciphertext);
byte[] plaintext = decrypt(sessionRecord, ciphertext.getWhisperMessage());

throw e;
}
sessionStore.storeSession(recipientId, deviceId, sessionRecord);
return plaintext;
}
}

Expand All @@ -182,26 +178,32 @@ public byte[] decrypt(WhisperMessage ciphertext)
throw new NoSessionException("No session for: " + recipientId + ", " + deviceId);
}

SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
byte[] plaintext = decrypt(sessionRecord, ciphertext);

sessionStore.storeSession(recipientId, deviceId, sessionRecord);

return plaintext;
}
}

private byte[] decrypt(SessionRecord sessionRecord, WhisperMessage ciphertext)
throws DuplicateMessageException, LegacyMessageException, InvalidMessageException
{
synchronized (SESSION_LOCK) {
SessionState sessionState = sessionRecord.getSessionState();
List<SessionState> previousStates = sessionRecord.getPreviousSessionStates();
List<Exception> exceptions = new LinkedList<>();

try {
byte[] plaintext = decrypt(sessionState, ciphertext);
sessionStore.storeSession(recipientId, deviceId, sessionRecord);

return plaintext;
return decrypt(sessionState, ciphertext);
} catch (InvalidMessageException e) {
exceptions.add(e);
}

for (SessionState previousState : previousStates) {
try {
byte[] plaintext = decrypt(previousState, ciphertext);
sessionStore.storeSession(recipientId, deviceId, sessionRecord);

return plaintext;
return decrypt(previousState, ciphertext);
} catch (InvalidMessageException e) {
exceptions.add(e);
}
Expand Down Expand Up @@ -240,7 +242,6 @@ private byte[] decrypt(SessionState sessionState, WhisperMessage ciphertextMessa
sessionState.clearUnacknowledgedPreKeyMessage();

return plaintext;

}

public int getRemoteRegistrationId() {
Expand Down

0 comments on commit c330eef

Please sign in to comment.