From 49a4b2a8d3be2496db0b9c789f5c48bb4574ef05 Mon Sep 17 00:00:00 2001 From: Willian Wang Date: Thu, 31 Aug 2023 00:46:39 -0300 Subject: [PATCH] Move parallelization to rust This should minimize the overhead and add parallelization to web. --- lib/findMy/decrypt_reports.dart | 26 +++++----- lib/findMy/find_my_controller.dart | 25 ++------- native/Cargo.lock | 81 ++++++++++++++++++++++++++++++ native/Cargo.toml | 1 + native/src/api.rs | 29 +++++------ 5 files changed, 115 insertions(+), 47 deletions(-) diff --git a/lib/findMy/decrypt_reports.dart b/lib/findMy/decrypt_reports.dart index 6d551f0..746adbc 100644 --- a/lib/findMy/decrypt_reports.dart +++ b/lib/findMy/decrypt_reports.dart @@ -12,34 +12,34 @@ import 'package:openhaystack_mobile/ffi/ffi.dart' class DecryptReports { /// Decrypts a given [FindMyReport] with the given private key. - static Future> decryptReportChunk(List reportChunk, Uint8List privateKeyBytes) async { + static Future> decryptReports(List reports, Uint8List privateKeyBytes) async { final curveDomainParam = ECCurve_secp224r1(); - final ephemeralKeyChunk = reportChunk.map((report) { + final ephemeralKeys = reports.map((report) { final payloadData = report.payload; final ephemeralKeyBytes = payloadData.sublist(5, 62); return ephemeralKeyBytes; }).toList(); - late final List sharedKeyChunk; + late final List sharedKeys; try { debugPrint("Trying native ECDH"); - final ephemeralKeyBlob = Uint8List.fromList(ephemeralKeyChunk.expand((element) => element).toList()); + final ephemeralKeyBlob = Uint8List.fromList(ephemeralKeys.expand((element) => element).toList()); final sharedKeyBlob = await api.ecdh(publicKeyBlob: ephemeralKeyBlob, privateKey: privateKeyBytes); - final chunkSize = (sharedKeyBlob.length / ephemeralKeyChunk.length).ceil(); - sharedKeyChunk = [ - for (var i = 0; i < sharedKeyBlob.length; i += chunkSize) - sharedKeyBlob.sublist(i, i + chunkSize < sharedKeyBlob.length ? i + chunkSize : sharedKeyBlob.length), + final keySize = (sharedKeyBlob.length / ephemeralKeys.length).ceil(); + sharedKeys = [ + for (var i = 0; i < sharedKeyBlob.length; i += keySize) + sharedKeyBlob.sublist(i, i + keySize < sharedKeyBlob.length ? i + keySize : sharedKeyBlob.length), ]; } catch (e) { debugPrint("Native ECDH failed: $e"); - debugPrint("Falling back to pure Dart ECDH!"); + debugPrint("Falling back to pure Dart ECDH on single thread!"); final privateKey = ECPrivateKey( pc_utils.decodeBigIntWithSign(1, privateKeyBytes), curveDomainParam); - sharedKeyChunk = ephemeralKeyChunk.map((ephemeralKey) { + sharedKeys = ephemeralKeys.map((ephemeralKey) { final decodePoint = curveDomainParam.curve.decodePoint(ephemeralKey); final ephemeralPublicKey = ECPublicKey(decodePoint, curveDomainParam); @@ -48,8 +48,8 @@ class DecryptReports { }).toList(); } - final decryptedLocationChunk = reportChunk.mapIndexed((index, report) { - final derivedKey = _kdf(sharedKeyChunk[index], ephemeralKeyChunk[index]); + final decryptedLocations = reports.mapIndexed((index, report) { + final derivedKey = _kdf(sharedKeys[index], ephemeralKeys[index]); final payloadData = report.payload; _decodeTimeAndConfidence(payloadData, report); final encData = payloadData.sublist(62, 72); @@ -59,7 +59,7 @@ class DecryptReports { return locationReport; }).toList(); - return decryptedLocationChunk; + return decryptedLocations; } /// Decodes the unencrypted timestamp and confidence diff --git a/lib/findMy/find_my_controller.dart b/lib/findMy/find_my_controller.dart index b583c70..0f85fc5 100644 --- a/lib/findMy/find_my_controller.dart +++ b/lib/findMy/find_my_controller.dart @@ -2,7 +2,6 @@ import 'dart:collection'; import 'dart:convert'; import 'dart:isolate'; import 'dart:typed_data'; -import 'dart:io' as IO; import 'package:flutter/foundation.dart'; import 'package:flutter_secure_storage/flutter_secure_storage.dart'; @@ -38,18 +37,8 @@ class FindMyController { FindMyKeyPair keyPair = args[0]; String seemooEndpoint = args[1]; final jsonReports = await ReportsFetcher.fetchLocationReports(keyPair.getHashedAdvertisementKey(), seemooEndpoint); - final numChunks = kIsWeb ? 1 : IO.Platform.numberOfProcessors+1; - final chunkSize = (jsonReports.length / numChunks).ceil(); - final chunks = [ - for (var i = 0; i < jsonReports.length; i += chunkSize) - jsonReports.sublist(i, i + chunkSize < jsonReports.length ? i + chunkSize : jsonReports.length), - ]; - final decryptedLocations = await Future.wait(chunks.map((jsonChunk) async { - final decryptedChunk = await compute(_decryptChunk, [jsonChunk, keyPair, keyPair.privateKeyBase64!]); - return decryptedChunk; - })); - final results = decryptedLocations.expand((element) => element).toList(); - return results; + final decryptedLocations = await _decryptReports(jsonReports, keyPair, keyPair.privateKeyBase64!); + return decryptedLocations; } /// Loads the private key from the local cache or secure storage and adds it @@ -77,12 +66,8 @@ class FindMyController { /// Decrypts the encrypted reports with the given list of [FindMyKeyPair] and private key. /// Returns the list of decrypted reports as a list of [FindMyLocationReport]. - static Future> _decryptChunk(List args) async { - List jsonChunk = args[0]; - FindMyKeyPair keyPair = args[1]; - String privateKey = args[2]; - - final reportChunk = jsonChunk.map((jsonReport) { + static Future> _decryptReports(List jsonRerportList, FindMyKeyPair keyPair, String privateKey) async { + final reportChunk = jsonRerportList.map((jsonReport) { assert (jsonReport["id"]! == keyPair.getHashedAdvertisementKey(), "Returned FindMyReport hashed key != requested hashed key"); @@ -98,7 +83,7 @@ class FindMyController { return report; }).toList(); - final decryptedReports = await DecryptReports.decryptReportChunk(reportChunk, base64Decode(privateKey)); + final decryptedReports = await DecryptReports.decryptReports(reportChunk, base64Decode(privateKey)); return decryptedReports; } diff --git a/native/Cargo.lock b/native/Cargo.lock index c7223f6..a1afbe4 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -152,6 +152,49 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + [[package]] name = "crypto-bigint" version = "0.5.2" @@ -219,6 +262,12 @@ dependencies = [ "signature", ] +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + [[package]] name = "elliptic-curve" version = "0.13.5" @@ -400,6 +449,15 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + [[package]] name = "native" version = "0.1.0" @@ -407,6 +465,7 @@ dependencies = [ "flutter_rust_bridge", "getrandom", "p224", + "rayon", ] [[package]] @@ -524,6 +583,28 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + [[package]] name = "redox_syscall" version = "0.3.5" diff --git a/native/Cargo.toml b/native/Cargo.toml index 4b18e93..5d57ce5 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -10,6 +10,7 @@ crate-type = ["staticlib", "cdylib", "rlib"] flutter_rust_bridge = "^1.77.0" p224 = "^0.13.2" getrandom = "^0.2.9" +rayon = "1.7.0" [features] default = ["p224/ecdh", "getrandom/js"] diff --git a/native/src/api.rs b/native/src/api.rs index 67c8d1c..dbc6764 100644 --- a/native/src/api.rs +++ b/native/src/api.rs @@ -1,31 +1,32 @@ use p224::{SecretKey, PublicKey, ecdh::diffie_hellman}; +use rayon::prelude::*; +use std::sync::{Arc, Mutex}; const PRIVATE_LEN : usize = 28; const PUBLIC_LEN : usize = 57; pub fn ecdh(public_key_blob : Vec, private_key : Vec) -> Vec { let num_keys = public_key_blob.len() / PUBLIC_LEN; - let mut vec_shared_secret = vec![0u8; num_keys*PRIVATE_LEN]; + let vec_shared_secret = Arc::new(Mutex::new(vec![0u8; num_keys*PRIVATE_LEN])); let private_key = SecretKey::from_slice(&private_key).unwrap(); let secret_scalar = private_key.to_nonzero_scalar(); - - let mut i = 0; - let mut j = 0; - for _i in 0..num_keys { - let public_key = PublicKey::from_sec1_bytes(&public_key_blob[i..i+PUBLIC_LEN]).unwrap(); + (0..num_keys).into_par_iter().for_each(|i| { + let start = i * PUBLIC_LEN; + let end = start + PUBLIC_LEN; + let public_key = PublicKey::from_sec1_bytes(&public_key_blob[start..end]).unwrap(); let public_affine = public_key.as_affine(); - - let shared_secret = diffie_hellman(secret_scalar, public_affine); + + let shared_secret = diffie_hellman(secret_scalar, public_affine); let shared_secret_ref = shared_secret.raw_secret_bytes().as_ref(); + let start = i * PRIVATE_LEN; + let end = start + PRIVATE_LEN; - vec_shared_secret[j..j+PRIVATE_LEN].copy_from_slice(shared_secret_ref); + let mut vec_shared_secret = vec_shared_secret.lock().unwrap(); + vec_shared_secret[start..end].copy_from_slice(shared_secret_ref); + }); - i += PUBLIC_LEN; - j += PRIVATE_LEN; - } - - return vec_shared_secret; + Arc::try_unwrap(vec_shared_secret).unwrap().into_inner().unwrap() }