Skip to content

Commit

Permalink
Fix possible dlist sync crash, improved debugging.
Browse files Browse the repository at this point in the history
Fixes #12795
  • Loading branch information
greyson-signal committed Feb 22, 2023
1 parent b689ea6 commit 691ab35
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class DistributionListTables constructor(context: Context?, databaseHelper: Sign
}
}

private object ListTable {
object ListTable {
const val TABLE_NAME = "distribution_list"

const val ID = "_id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1170,36 +1170,47 @@ open class RecipientTable(context: Context, databaseHelper: SignalDatabase) : Da
* @return All storage IDs for synced records, excluding the ones that need to be deleted.
*/
fun getContactStorageSyncIdsMap(): Map<RecipientId, StorageId> {
val inPart = "(?, ?)"
val args = SqlUtil.buildArgs(GroupType.NONE.id, Recipient.self().id, GroupType.SIGNAL_V1.id, GroupType.DISTRIBUTION_LIST.id)

val query = """
$STORAGE_SERVICE_ID NOT NULL AND (
($GROUP_TYPE = ? AND $SERVICE_ID NOT NULL AND $ID != ?)
OR
$GROUP_TYPE IN $inPart
)
""".trimIndent()
val out: MutableMap<RecipientId, StorageId> = HashMap()

readableDatabase.query(TABLE_NAME, arrayOf(ID, STORAGE_SERVICE_ID, GROUP_TYPE), query, args, null, null, null).use { cursor ->
while (cursor != null && cursor.moveToNext()) {
val id = RecipientId.from(cursor.requireLong(ID))
val encodedKey = cursor.requireNonNullString(STORAGE_SERVICE_ID)
val groupType = GroupType.fromId(cursor.requireInt(GROUP_TYPE))
val key = Base64.decodeOrThrow(encodedKey)

when (groupType) {
GroupType.NONE -> out[id] = StorageId.forContact(key)
GroupType.SIGNAL_V1 -> out[id] = StorageId.forGroupV1(key)
GroupType.DISTRIBUTION_LIST -> out[id] = StorageId.forStoryDistributionList(key)
else -> throw AssertionError()
readableDatabase
.select(ID, STORAGE_SERVICE_ID, GROUP_TYPE)
.from(TABLE_NAME)
.where(
"""
$STORAGE_SERVICE_ID NOT NULL AND (
($GROUP_TYPE = ? AND $SERVICE_ID NOT NULL AND $ID != ?)
OR
$GROUP_TYPE = ?
OR
$DISTRIBUTION_LIST_ID NOT NULL AND $DISTRIBUTION_LIST_ID IN (
SELECT ${DistributionListTables.ListTable.ID}
FROM ${DistributionListTables.ListTable.TABLE_NAME}
)
)
""".toSingleLine(),
GroupType.NONE.id,
Recipient.self().id,
GroupType.SIGNAL_V1.id
)
.run()
.use { cursor ->
while (cursor.moveToNext()) {
val id = RecipientId.from(cursor.requireLong(ID))
val encodedKey = cursor.requireNonNullString(STORAGE_SERVICE_ID)
val groupType = GroupType.fromId(cursor.requireInt(GROUP_TYPE))
val key = Base64.decodeOrThrow(encodedKey)

when (groupType) {
GroupType.NONE -> out[id] = StorageId.forContact(key)
GroupType.SIGNAL_V1 -> out[id] = StorageId.forGroupV1(key)
GroupType.DISTRIBUTION_LIST -> out[id] = StorageId.forStoryDistributionList(key)
else -> throw AssertionError()
}
}
}
}

for (id in groups.getAllGroupV2Ids()) {
val recipient = Recipient.externalGroupExact(id!!)
val recipient = Recipient.externalGroupExact(id)
val recipientId = recipient.id
val existing: RecipientRecord = getRecordForSync(recipientId) ?: throw AssertionError()
val key = existing.storageId ?: throw AssertionError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.util.Base64;
import org.signal.core.util.SetUtil;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.storage.SignalContactRecord;
import org.whispersystems.signalservice.api.storage.SignalStorageManifest;
import org.whispersystems.signalservice.api.storage.SignalStorageRecord;
Expand Down Expand Up @@ -103,6 +102,19 @@ public static void validateForcePush(@NonNull SignalStorageManifest manifest, @N
}

private static void validateManifestAndInserts(@NonNull SignalStorageManifest manifest, @NonNull List<SignalStorageRecord> inserts, @NonNull Recipient self) {
int accountCount = 0;
for (StorageId id : manifest.getStorageIds()) {
accountCount += id.getType() == ManifestRecord.Identifier.Type.ACCOUNT_VALUE ? 1 : 0;
}

if (accountCount > 1) {
throw new MultipleAccountError();
}

if (accountCount == 0) {
throw new MissingAccountError();
}

Set<StorageId> allSet = new HashSet<>(manifest.getStorageIds());
Set<StorageId> insertSet = new HashSet<>(Stream.of(inserts).map(SignalStorageRecord::getId).toList());
Set<ByteBuffer> rawIdSet = Stream.of(allSet).map(id -> ByteBuffer.wrap(id.getRaw())).collect(Collectors.toSet());
Expand All @@ -112,26 +124,34 @@ private static void validateManifestAndInserts(@NonNull SignalStorageManifest ma
}

if (rawIdSet.size() != allSet.size()) {
throw new DuplicateRawIdError();
}
List<StorageId> ids = manifest.getStorageIdsByType().get(ManifestRecord.Identifier.Type.CONTACT_VALUE);
if (ids.size() != new HashSet<>(ids).size()) {
throw new DuplicateContactIdError();
}

if (inserts.size() > insertSet.size()) {
throw new DuplicateInsertInWriteError();
}
ids = manifest.getStorageIdsByType().get(ManifestRecord.Identifier.Type.GROUPV1_VALUE);
if (ids.size() != new HashSet<>(ids).size()) {
throw new DuplicateGroupV1IdError();
}

int accountCount = 0;
for (StorageId id : manifest.getStorageIds()) {
accountCount += id.getType() == ManifestRecord.Identifier.Type.ACCOUNT_VALUE ? 1 : 0;
}
ids = manifest.getStorageIdsByType().get(ManifestRecord.Identifier.Type.GROUPV2_VALUE);
if (ids.size() != new HashSet<>(ids).size()) {
throw new DuplicateGroupV2IdError();
}

if (accountCount > 1) {
throw new MultipleAccountError();
ids = manifest.getStorageIdsByType().get(ManifestRecord.Identifier.Type.STORY_DISTRIBUTION_LIST_VALUE);
if (ids.size() != new HashSet<>(ids).size()) {
throw new DuplicateDistributionListIdError();
}

throw new DuplicateRawIdAcrossTypesError();
}

if (accountCount == 0) {
throw new MissingAccountError();
if (inserts.size() > insertSet.size()) {
throw new DuplicateInsertInWriteError();
}


for (SignalStorageRecord insert : inserts) {
if (!allSet.contains(insert.getId())) {
throw new InsertNotPresentInFullIdSetError();
Expand Down Expand Up @@ -161,7 +181,19 @@ private static void validateManifestAndInserts(@NonNull SignalStorageManifest ma
private static final class DuplicateStorageIdError extends Error {
}

private static final class DuplicateRawIdError extends Error {
private static final class DuplicateRawIdAcrossTypesError extends Error {
}

private static final class DuplicateContactIdError extends Error {
}

private static final class DuplicateGroupV1IdError extends Error {
}

private static final class DuplicateGroupV2IdError extends Error {
}

private static final class DuplicateDistributionListIdError extends Error {
}

private static final class DuplicateInsertInWriteError extends Error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ public Optional<StorageId> getAccountStorageId() {
}
}

public Map<Integer, List<StorageId>> getStorageIdsByType() {
return storageIdsByType;
}

public byte[] serialize() {
List<ManifestRecord.Identifier> ids = new ArrayList<>(storageIds.size());

Expand Down

0 comments on commit 691ab35

Please sign in to comment.