Skip to content

Commit

Permalink
Fix discrepancy in message counting between export and import backups.
Browse files Browse the repository at this point in the history
  • Loading branch information
cody-signal authored and greyson-signal committed Mar 17, 2021
1 parent cb6e3ad commit 9366596
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import androidx.annotation.RequiresApi;
import androidx.documentfile.provider.DocumentFile;

import com.annimon.stream.function.Consumer;
import com.annimon.stream.function.Predicate;
import com.google.protobuf.ByteString;

Expand Down Expand Up @@ -127,8 +126,10 @@ private static void internalExport(@NonNull Context context,

try {
outputStream.writeDatabaseVersion(input.getVersion());
count++;

List<String> tables = exportSchema(input, outputStream);
count += tables.size() * 3;

Stopwatch stopwatch = new Stopwatch("Backup");

Expand All @@ -141,9 +142,9 @@ private static void internalExport(@NonNull Context context,
} else if (table.equals(GroupReceiptDatabase.TABLE_NAME)) {
count = exportTable(table, input, outputStream, cursor -> isForNonExpiringMessage(input, cursor.getLong(cursor.getColumnIndexOrThrow(GroupReceiptDatabase.MMS_ID))), null, count, cancellationSignal);
} else if (table.equals(AttachmentDatabase.TABLE_NAME)) {
count = exportTable(table, input, outputStream, cursor -> isForNonExpiringMessage(input, cursor.getLong(cursor.getColumnIndexOrThrow(AttachmentDatabase.MMS_ID))), cursor -> exportAttachment(attachmentSecret, cursor, outputStream), count, cancellationSignal);
count = exportTable(table, input, outputStream, cursor -> isForNonExpiringMessage(input, cursor.getLong(cursor.getColumnIndexOrThrow(AttachmentDatabase.MMS_ID))), (cursor, innerCount) -> exportAttachment(attachmentSecret, cursor, outputStream, innerCount), count, cancellationSignal);
} else if (table.equals(StickerDatabase.TABLE_NAME)) {
count = exportTable(table, input, outputStream, cursor -> true, cursor -> exportSticker(attachmentSecret, cursor, outputStream), count, cancellationSignal);
count = exportTable(table, input, outputStream, cursor -> true, (cursor, innerCount) -> exportSticker(attachmentSecret, cursor, outputStream, innerCount), count, cancellationSignal);
} else if (!BLACKLISTED_TABLES.contains(table) && !table.startsWith("sqlite_")) {
count = exportTable(table, input, outputStream, null, null, count, cancellationSignal);
}
Expand Down Expand Up @@ -224,7 +225,7 @@ private static int exportTable(@NonNull String table,
@NonNull SQLiteDatabase input,
@NonNull BackupFrameOutputStream outputStream,
@Nullable Predicate<Cursor> predicate,
@Nullable Consumer<Cursor> postProcess,
@Nullable PostProcessor postProcess,
int count,
@NonNull BackupCancellationSignal cancellationSignal)
throws IOException
Expand All @@ -234,7 +235,6 @@ private static int exportTable(@NonNull String table,
try (Cursor cursor = input.rawQuery("SELECT * FROM " + table, null)) {
while (cursor != null && cursor.moveToNext()) {
throwIfCanceled(cancellationSignal);
EventBus.getDefault().post(new BackupEvent(BackupEvent.Type.PROGRESS, ++count));

if (predicate == null || predicate.test(cursor)) {
StringBuilder statement = new StringBuilder(template);
Expand Down Expand Up @@ -266,17 +266,20 @@ private static int exportTable(@NonNull String table,

statement.append(')');

EventBus.getDefault().post(new BackupEvent(BackupEvent.Type.PROGRESS, ++count));
outputStream.write(statementBuilder.setStatement(statement.toString()).build());

if (postProcess != null) postProcess.accept(cursor);
if (postProcess != null) {
count = postProcess.postProcess(cursor, count);
}
}
}
}

return count;
}

private static void exportAttachment(@NonNull AttachmentSecret attachmentSecret, @NonNull Cursor cursor, @NonNull BackupFrameOutputStream outputStream) {
private static int exportAttachment(@NonNull AttachmentSecret attachmentSecret, @NonNull Cursor cursor, @NonNull BackupFrameOutputStream outputStream, int count) {
try {
long rowId = cursor.getLong(cursor.getColumnIndexOrThrow(AttachmentDatabase.ROW_ID));
long uniqueId = cursor.getLong(cursor.getColumnIndexOrThrow(AttachmentDatabase.UNIQUE_ID));
Expand All @@ -301,14 +304,17 @@ private static void exportAttachment(@NonNull AttachmentSecret attachmentSecret,
if (random != null && random.length == 32) inputStream = ModernDecryptingPartInputStream.createFor(attachmentSecret, random, new File(data), 0);
else inputStream = ClassicDecryptingPartInputStream.createFor(attachmentSecret, new File(data));

EventBus.getDefault().post(new BackupEvent(BackupEvent.Type.PROGRESS, ++count));
outputStream.write(new AttachmentId(rowId, uniqueId), inputStream, size);
}
} catch (IOException e) {
Log.w(TAG, e);
}

return count;
}

private static void exportSticker(@NonNull AttachmentSecret attachmentSecret, @NonNull Cursor cursor, @NonNull BackupFrameOutputStream outputStream) {
private static int exportSticker(@NonNull AttachmentSecret attachmentSecret, @NonNull Cursor cursor, @NonNull BackupFrameOutputStream outputStream, int count) {
try {
long rowId = cursor.getLong(cursor.getColumnIndexOrThrow(StickerDatabase._ID));
long size = cursor.getLong(cursor.getColumnIndexOrThrow(StickerDatabase.FILE_LENGTH));
Expand All @@ -317,12 +323,15 @@ private static void exportSticker(@NonNull AttachmentSecret attachmentSecret, @N
byte[] random = cursor.getBlob(cursor.getColumnIndexOrThrow(StickerDatabase.FILE_RANDOM));

if (!TextUtils.isEmpty(data) && size > 0) {
EventBus.getDefault().post(new BackupEvent(BackupEvent.Type.PROGRESS, ++count));
InputStream inputStream = ModernDecryptingPartInputStream.createFor(attachmentSecret, random, new File(data), 0);
outputStream.writeSticker(rowId, inputStream, size);
}
} catch (IOException e) {
Log.w(TAG, e);
}

return count;
}

private static long calculateVeryOldStreamLength(@NonNull AttachmentSecret attachmentSecret, @Nullable byte[] random, @NonNull String data) throws IOException {
Expand Down Expand Up @@ -528,6 +537,10 @@ public void close() throws IOException {
}
}

public interface PostProcessor {
int postProcess(@NonNull Cursor cursor, int count);
}

public interface BackupCancellationSignal {
boolean isCanceled();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,16 @@ public static void importFile(@NonNull Context context, @NonNull AttachmentSecre
BackupFrame frame;

while (!(frame = inputStream.readFrame()).getEnd()) {
if (count++ % 100 == 0) EventBus.getDefault().post(new BackupEvent(BackupEvent.Type.PROGRESS, count));
if (count % 100 == 0) EventBus.getDefault().post(new BackupEvent(BackupEvent.Type.PROGRESS, count));
count++;

if (frame.hasVersion()) processVersion(db, frame.getVersion());
else if (frame.hasStatement()) processStatement(db, frame.getStatement());
else if (frame.hasPreference()) processPreference(context, frame.getPreference());
else if (frame.hasAttachment()) processAttachment(context, attachmentSecret, db, frame.getAttachment(), inputStream);
else if (frame.hasSticker()) processSticker(context, attachmentSecret, db, frame.getSticker(), inputStream);
else if (frame.hasAvatar()) processAvatar(context, db, frame.getAvatar(), inputStream);
else count--;
}

db.setTransactionSuccessful();
Expand Down

0 comments on commit 9366596

Please sign in to comment.