Skip to content

Commit

Permalink
Guard against malformed group ids.
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-signal authored and greyson-signal committed Apr 21, 2020
1 parent 00ee6d0 commit 9a8094c
Show file tree
Hide file tree
Showing 21 changed files with 200 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ private Drawable getDefaultGroupAvatar() {
}

private void initializeExistingGroup() {
final GroupId groupId = GroupId.parseNullable(getIntent().getStringExtra(GROUP_ID_EXTRA));
final GroupId groupId = GroupId.parseNullableOrThrow(getIntent().getStringExtra(GROUP_ID_EXTRA));

if (groupId != null) {
new FillExistingGroupInfoAsyncTask(this).executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR, groupId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.thoughtcrime.securesms.logging.Log;
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.recipients.RecipientId;
import org.thoughtcrime.securesms.util.Util;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.signalservice.api.groupsv2.DecryptedGroupUtil;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer;
Expand Down Expand Up @@ -186,7 +185,7 @@ public GroupId.Mms getOrCreateMmsGroupForMembers(List<RecipientId> members) {
null, null, null);
try {
if (cursor != null && cursor.moveToNext()) {
return GroupId.parse(cursor.getString(cursor.getColumnIndexOrThrow(GROUP_ID)))
return GroupId.parseOrThrow(cursor.getString(cursor.getColumnIndexOrThrow(GROUP_ID)))
.requireMms();
} else {
GroupId.Mms groupId = GroupId.createMms(new SecureRandom());
Expand Down Expand Up @@ -519,7 +518,7 @@ public Reader(Cursor cursor) {
return null;
}

return new GroupRecord(GroupId.parse(cursor.getString(cursor.getColumnIndexOrThrow(GROUP_ID))),
return new GroupRecord(GroupId.parseOrThrow(cursor.getString(cursor.getColumnIndexOrThrow(GROUP_ID))),
RecipientId.from(cursor.getLong(cursor.getColumnIndexOrThrow(RECIPIENT_ID))),
cursor.getString(cursor.getColumnIndexOrThrow(TITLE)),
cursor.getString(cursor.getColumnIndexOrThrow(MEMBERS)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ public void applyStorageSyncUpdates(@NonNull Collection<SignalContactRecord>
for (SignalGroupV1Record insert : groupV1Inserts) {
db.insertOrThrow(TABLE_NAME, null, getValuesForStorageGroupV1(insert));

Recipient recipient = Recipient.externalGroup(context, GroupId.v1(insert.getGroupId()));
Recipient recipient = Recipient.externalGroup(context, GroupId.v1orThrow(insert.getGroupId()));

threadDatabase.setArchived(recipient.getId(), insert.isArchived());
recipient.live().refresh();
Expand All @@ -575,7 +575,7 @@ public void applyStorageSyncUpdates(@NonNull Collection<SignalContactRecord>
throw new AssertionError("Had an update, but it didn't match any rows!");
}

Recipient recipient = Recipient.externalGroup(context, GroupId.v1(update.getOld().getGroupId()));
Recipient recipient = Recipient.externalGroup(context, GroupId.v1orThrow(update.getOld().getGroupId()));

threadDatabase.setArchived(recipient.getId(), update.getNew().isArchived());
recipient.live().refresh();
Expand Down Expand Up @@ -670,7 +670,7 @@ public void updatePhoneNumbers(@NonNull Map<String, String> mapping) {

private static @NonNull ContentValues getValuesForStorageGroupV1(@NonNull SignalGroupV1Record groupV1) {
ContentValues values = new ContentValues();
values.put(GROUP_ID, GroupId.v1(groupV1.getGroupId()).toString());
values.put(GROUP_ID, GroupId.v1orThrow(groupV1.getGroupId()).toString());
values.put(GROUP_TYPE, GroupType.SIGNAL_V1.getId());
values.put(PROFILE_SHARING, groupV1.isProfileSharingEnabled() ? "1" : "0");
values.put(BLOCKED, groupV1.isBlocked() ? "1" : "0");
Expand Down Expand Up @@ -733,7 +733,7 @@ public List<StorageId> getContactStorageSyncIds() {
String username = cursor.getString(cursor.getColumnIndexOrThrow(USERNAME));
String e164 = cursor.getString(cursor.getColumnIndexOrThrow(PHONE));
String email = cursor.getString(cursor.getColumnIndexOrThrow(EMAIL));
GroupId groupId = GroupId.parseNullable(cursor.getString(cursor.getColumnIndexOrThrow(GROUP_ID)));
GroupId groupId = GroupId.parseNullableOrThrow(cursor.getString(cursor.getColumnIndexOrThrow(GROUP_ID)));
int groupType = cursor.getInt(cursor.getColumnIndexOrThrow(GROUP_TYPE));
boolean blocked = cursor.getInt(cursor.getColumnIndexOrThrow(BLOCKED)) == 1;
String messageRingtone = cursor.getString(cursor.getColumnIndexOrThrow(MESSAGE_RINGTONE));
Expand Down Expand Up @@ -1406,7 +1406,7 @@ public void applyBlockedUpdate(@NonNull List<SignalServiceAddress> blocked, List
db.update(TABLE_NAME, setBlocked, UUID + " = ?", new String[] { uuid });
}

List<GroupId.V1> groupIdStrings = Stream.of(groupIds).map(GroupId::v1).toList();
List<GroupId.V1> groupIdStrings = Stream.of(groupIds).map(GroupId::v1orThrow).toList();

for (GroupId.V1 groupId : groupIdStrings) {
db.update(TABLE_NAME, setBlocked, GROUP_ID + " = ?", new String[] { groupId.toString() });
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.thoughtcrime.securesms.groups;

public final class BadGroupIdException extends Exception {
BadGroupIdException(String message) {
super(message);
}

BadGroupIdException() {
super();
}

BadGroupIdException(Exception e) {
super(e);
}
}
62 changes: 51 additions & 11 deletions app/src/main/java/org/thoughtcrime/securesms/groups/GroupId.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,46 @@ private GroupId(@NonNull String prefix, @NonNull byte[] bytes) {
return new GroupId.Mms(mmsGroupIdBytes);
}

public static @NonNull GroupId.V1 v1(byte[] gv1GroupIdBytes) {
public static @NonNull GroupId.V1 v1orThrow(byte[] gv1GroupIdBytes) {
try {
return v1(gv1GroupIdBytes);
} catch (BadGroupIdException e) {
throw new AssertionError(e);
}
}

public static @NonNull GroupId.V1 v1(byte[] gv1GroupIdBytes) throws BadGroupIdException {
if (gv1GroupIdBytes.length == V2_BYTE_LENGTH) {
throw new AssertionError();
throw new BadGroupIdException();
}
return new GroupId.V1(gv1GroupIdBytes);
}

public static GroupId.V1 createV1(@NonNull SecureRandom secureRandom) {
return v1(Util.getSecretBytes(secureRandom, V1_MMS_BYTE_LENGTH));
return v1orThrow(Util.getSecretBytes(secureRandom, V1_MMS_BYTE_LENGTH));
}

public static GroupId.Mms createMms(@NonNull SecureRandom secureRandom) {
return mms(Util.getSecretBytes(secureRandom, MMS_BYTE_LENGTH));
}

public static GroupId.V2 v2(@NonNull byte[] bytes) {
public static GroupId.V2 v2orThrow(@NonNull byte[] bytes) {
try {
return v2(bytes);
} catch (BadGroupIdException e) {
throw new AssertionError(e);
}
}

public static GroupId.V2 v2(@NonNull byte[] bytes) throws BadGroupIdException {
if (bytes.length != V2_BYTE_LENGTH) {
throw new AssertionError();
throw new BadGroupIdException();
}
return new GroupId.V2(bytes);
}

public static GroupId.V2 v2(@NonNull GroupIdentifier groupIdentifier) {
return v2(groupIdentifier.serialize());
return v2orThrow(groupIdentifier.serialize());
}

public static GroupId.V2 v2(@NonNull GroupMasterKey masterKey) {
Expand All @@ -62,32 +78,56 @@ public static GroupId.V2 v2(@NonNull GroupMasterKey masterKey) {
.getGroupIdentifier());
}

public static GroupId.Push push(byte[] bytes) {
public static GroupId.Push push(byte[] bytes) throws BadGroupIdException {
return bytes.length == V2_BYTE_LENGTH ? v2(bytes) : v1(bytes);
}

public static @NonNull GroupId parse(@NonNull String encodedGroupId) {
public static GroupId.Push pushOrThrow(byte[] bytes) {
try {
return push(bytes);
} catch (BadGroupIdException e) {
throw new AssertionError(e);
}
}

public static @NonNull GroupId parseOrThrow(@NonNull String encodedGroupId) {
try {
return parse(encodedGroupId);
} catch (BadGroupIdException e) {
throw new AssertionError(e);
}
}

public static @NonNull GroupId parse(@NonNull String encodedGroupId) throws BadGroupIdException {
try {
if (!isEncodedGroup(encodedGroupId)) {
throw new IOException("Invalid encoding");
throw new BadGroupIdException("Invalid encoding");
}

byte[] bytes = extractDecodedId(encodedGroupId);

return encodedGroupId.startsWith(ENCODED_MMS_GROUP_PREFIX) ? mms(bytes) : push(bytes);
} catch (IOException e) {
throw new AssertionError(e);
throw new BadGroupIdException(e);
}
}

public static @Nullable GroupId parseNullable(@Nullable String encodedGroupId) {
public static @Nullable GroupId parseNullable(@Nullable String encodedGroupId) throws BadGroupIdException {
if (encodedGroupId == null) {
return null;
}

return parse(encodedGroupId);
}

public static @Nullable GroupId parseNullableOrThrow(@Nullable String encodedGroupId) {
if (encodedGroupId == null) {
return null;
}

return parseOrThrow(encodedGroupId);
}

public static boolean isEncodedGroup(@NonNull String groupId) {
return groupId.startsWith(ENCODED_SIGNAL_GROUP_PREFIX) || groupId.startsWith(ENCODED_MMS_GROUP_PREFIX);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public final class GroupV1MessageProcessor {

GroupDatabase database = DatabaseFactory.getGroupDatabase(context);
SignalServiceGroup group = groupV1.get();
GroupId id = GroupId.v1(group.getGroupId());
GroupId id = GroupId.v1orThrow(group.getGroupId());
Optional<GroupRecord> record = database.getGroup(id);

if (record.isPresent() && group.getType() == Type.UPDATE) {
Expand All @@ -93,7 +93,7 @@ public final class GroupV1MessageProcessor {
boolean outgoing)
{
GroupDatabase database = DatabaseFactory.getGroupDatabase(context);
GroupId.V1 id = GroupId.v1(group.getGroupId());
GroupId.V1 id = GroupId.v1orThrow(group.getGroupId());
GroupContext.Builder builder = createGroupContext(group);
builder.setType(GroupContext.Type.UPDATE);

Expand Down Expand Up @@ -127,7 +127,7 @@ public final class GroupV1MessageProcessor {
{

GroupDatabase database = DatabaseFactory.getGroupDatabase(context);
GroupId.V1 id = GroupId.v1(group.getGroupId());
GroupId.V1 id = GroupId.v1orThrow(group.getGroupId());

Set<RecipientId> recordMembers = new HashSet<>(groupRecord.getMembers());
Set<RecipientId> messageMembers = new HashSet<>();
Expand Down Expand Up @@ -203,7 +203,7 @@ private static Long handleGroupLeave(@NonNull Context context,
boolean outgoing)
{
GroupDatabase database = DatabaseFactory.getGroupDatabase(context);
GroupId id = GroupId.v1(group.getGroupId());
GroupId id = GroupId.v1orThrow(group.getGroupId());
List<RecipientId> members = record.getMembers();

GroupContext.Builder builder = createGroupContext(group);
Expand All @@ -228,13 +228,13 @@ private static Long handleGroupLeave(@NonNull Context context,
{
if (group.getAvatar().isPresent()) {
ApplicationDependencies.getJobManager()
.add(new AvatarGroupsV1DownloadJob(GroupId.v1(group.getGroupId())));
.add(new AvatarGroupsV1DownloadJob(GroupId.v1orThrow(group.getGroupId())));
}

try {
if (outgoing) {
MmsDatabase mmsDatabase = DatabaseFactory.getMmsDatabase(context);
RecipientId recipientId = DatabaseFactory.getRecipientDatabase(context).getOrInsertFromGroupId(GroupId.v1(group.getGroupId()));
RecipientId recipientId = DatabaseFactory.getRecipientDatabase(context).getOrInsertFromGroupId(GroupId.v1orThrow(group.getGroupId()));
Recipient recipient = Recipient.resolved(recipientId);
OutgoingGroupMediaMessage outgoingMessage = new OutgoingGroupMediaMessage(recipient, storage, null, content.getTimestamp(), 0, false, null, Collections.emptyList(), Collections.emptyList());
long threadId = DatabaseFactory.getThreadDatabase(context).getThreadIdFor(recipient);
Expand All @@ -246,7 +246,7 @@ private static Long handleGroupLeave(@NonNull Context context,
} else {
SmsDatabase smsDatabase = DatabaseFactory.getSmsDatabase(context);
String body = Base64.encodeBytes(storage.toByteArray());
IncomingTextMessage incoming = new IncomingTextMessage(Recipient.externalPush(context, content.getSender()).getId(), content.getSenderDevice(), content.getTimestamp(), content.getServerTimestamp(), body, Optional.of(GroupId.v1(group.getGroupId())), 0, content.isNeedsReceipt());
IncomingTextMessage incoming = new IncomingTextMessage(Recipient.externalPush(context, content.getSender()).getId(), content.getSenderDevice(), content.getTimestamp(), content.getServerTimestamp(), body, Optional.of(GroupId.v1orThrow(group.getGroupId())), 0, content.isNeedsReceipt());
IncomingGroupMessage groupMessage = new IncomingGroupMessage(incoming, storage, body);

Optional<InsertResult> insertResult = smsDatabase.insertMessageInbox(groupMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ protected void onCreate(Bundle savedInstanceState, boolean ready) {

if (savedInstanceState == null) {
getSupportFragmentManager().beginTransaction()
.replace(R.id.container, PendingMemberInvitesFragment.newInstance(GroupId.parse(getIntent().getStringExtra(GROUP_ID)).requireV2()))
.replace(R.id.container, PendingMemberInvitesFragment.newInstance(GroupId.parseOrThrow(getIntent().getStringExtra(GROUP_ID)).requireV2()))
.commitNow();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void onCancelAllInvites(@NonNull GroupMemberEntry.UnknownPendingMemberCou
public void onActivityCreated(@Nullable Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);

GroupId.V2 groupId = GroupId.parse(Objects.requireNonNull(requireArguments().getString(GROUP_ID))).requireV2();
GroupId.V2 groupId = GroupId.parseOrThrow(Objects.requireNonNull(requireArguments().getString(GROUP_ID))).requireV2();

PendingMemberInvitesViewModel.Factory factory = new PendingMemberInvitesViewModel.Factory(requireContext(), groupId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public boolean onShouldRetry(@NonNull Exception exception) {
public static final class Factory implements Job.Factory<AvatarGroupsV1DownloadJob> {
@Override
public @NonNull AvatarGroupsV1DownloadJob create(@NonNull Parameters parameters, @NonNull Data data) {
return new AvatarGroupsV1DownloadJob(parameters, GroupId.parse(data.getString(KEY_GROUP_ID)).requireV1());
return new AvatarGroupsV1DownloadJob(parameters, GroupId.parseOrThrow(data.getString(KEY_GROUP_ID)).requireV1());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public static final class Factory implements Job.Factory<AvatarGroupsV2DownloadJ
@Override
public @NonNull AvatarGroupsV2DownloadJob create(@NonNull Parameters parameters, @NonNull Data data) {
return new AvatarGroupsV2DownloadJob(parameters,
GroupId.parse(data.getString(KEY_GROUP_ID)).requireV2(),
GroupId.parseOrThrow(data.getString(KEY_GROUP_ID)).requireV2(),
data.getString(CDN_KEY));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public void onFailure() {
public static class Factory implements Job.Factory<LeaveGroupJob> {
@Override
public @NonNull LeaveGroupJob create(@NonNull Parameters parameters, @NonNull Data data) {
return new LeaveGroupJob(GroupId.v1(Base64.decodeOrThrow(data.getString(KEY_GROUP_ID))),
return new LeaveGroupJob(GroupId.v1orThrow(Base64.decodeOrThrow(data.getString(KEY_GROUP_ID))),
data.getString(KEY_GROUP_NAME),
RecipientId.fromSerializedList(data.getString(KEY_MEMBERS)),
RecipientId.fromSerializedList(data.getString(KEY_RECIPIENTS)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.thoughtcrime.securesms.database.NoSuchMessageException;
import org.thoughtcrime.securesms.database.PushDatabase;
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.groups.BadGroupIdException;
import org.thoughtcrime.securesms.groups.GroupId;
import org.thoughtcrime.securesms.jobmanager.Data;
import org.thoughtcrime.securesms.jobmanager.Job;
import org.thoughtcrime.securesms.jobmanager.JobManager;
Expand Down Expand Up @@ -226,14 +228,23 @@ private void postMigrationNotification() {
}
}

private static PushProcessMessageJob.ExceptionMetadata toExceptionMetadata(@NonNull UnsupportedDataMessageException e) throws NoSenderException {
private static PushProcessMessageJob.ExceptionMetadata toExceptionMetadata(@NonNull UnsupportedDataMessageException e)
throws NoSenderException
{
String sender = e.getSender();

if (sender == null) throw new NoSenderException();

GroupId groupId = null;
try {
groupId = GroupUtil.idFromGroupContext(e.getGroup().orNull());
} catch (BadGroupIdException ex) {
Log.w(TAG, "Bad group id found in unsupported data message", ex);
}

return new PushProcessMessageJob.ExceptionMetadata(sender,
e.getSenderDevice(),
e.getGroup().transform(GroupUtil::idFromGroupContext).orNull());
groupId);
}

private static PushProcessMessageJob.ExceptionMetadata toExceptionMetadata(@NonNull ProtocolException e) throws NoSenderException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
Expand Down Expand Up @@ -140,7 +139,7 @@ public static final class Factory implements Job.Factory<PushGroupUpdateJob> {
public @NonNull PushGroupUpdateJob create(@NonNull Parameters parameters, @NonNull org.thoughtcrime.securesms.jobmanager.Data data) {
return new PushGroupUpdateJob(parameters,
RecipientId.from(data.getString(KEY_SOURCE)),
GroupId.parse(data.getString(KEY_GROUP_ID)));
GroupId.parseOrThrow(data.getString(KEY_GROUP_ID)));
}
}
}

0 comments on commit 9a8094c

Please sign in to comment.