Skip to content

Commit

Permalink
Move parallelization to rust
Browse files Browse the repository at this point in the history
This should minimize the overhead and add parallelization to web.
  • Loading branch information
wangwillian0 committed Aug 31, 2023
1 parent b17324d commit 49a4b2a
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 47 deletions.
26 changes: 13 additions & 13 deletions lib/findMy/decrypt_reports.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,34 @@ import 'package:openhaystack_mobile/ffi/ffi.dart'

class DecryptReports {
/// Decrypts a given [FindMyReport] with the given private key.
static Future<List<FindMyLocationReport>> decryptReportChunk(List<FindMyReport> reportChunk, Uint8List privateKeyBytes) async {
static Future<List<FindMyLocationReport>> decryptReports(List<FindMyReport> reports, Uint8List privateKeyBytes) async {
final curveDomainParam = ECCurve_secp224r1();

final ephemeralKeyChunk = reportChunk.map((report) {
final ephemeralKeys = reports.map((report) {
final payloadData = report.payload;
final ephemeralKeyBytes = payloadData.sublist(5, 62);
return ephemeralKeyBytes;
}).toList();

late final List<Uint8List> sharedKeyChunk;
late final List<Uint8List> sharedKeys;

try {
debugPrint("Trying native ECDH");
final ephemeralKeyBlob = Uint8List.fromList(ephemeralKeyChunk.expand((element) => element).toList());
final ephemeralKeyBlob = Uint8List.fromList(ephemeralKeys.expand((element) => element).toList());
final sharedKeyBlob = await api.ecdh(publicKeyBlob: ephemeralKeyBlob, privateKey: privateKeyBytes);
final chunkSize = (sharedKeyBlob.length / ephemeralKeyChunk.length).ceil();
sharedKeyChunk = [
for (var i = 0; i < sharedKeyBlob.length; i += chunkSize)
sharedKeyBlob.sublist(i, i + chunkSize < sharedKeyBlob.length ? i + chunkSize : sharedKeyBlob.length),
final keySize = (sharedKeyBlob.length / ephemeralKeys.length).ceil();
sharedKeys = [
for (var i = 0; i < sharedKeyBlob.length; i += keySize)
sharedKeyBlob.sublist(i, i + keySize < sharedKeyBlob.length ? i + keySize : sharedKeyBlob.length),
];
}
catch (e) {
debugPrint("Native ECDH failed: $e");
debugPrint("Falling back to pure Dart ECDH!");
debugPrint("Falling back to pure Dart ECDH on single thread!");
final privateKey = ECPrivateKey(
pc_utils.decodeBigIntWithSign(1, privateKeyBytes),
curveDomainParam);
sharedKeyChunk = ephemeralKeyChunk.map((ephemeralKey) {
sharedKeys = ephemeralKeys.map((ephemeralKey) {
final decodePoint = curveDomainParam.curve.decodePoint(ephemeralKey);
final ephemeralPublicKey = ECPublicKey(decodePoint, curveDomainParam);

Expand All @@ -48,8 +48,8 @@ class DecryptReports {
}).toList();
}

final decryptedLocationChunk = reportChunk.mapIndexed((index, report) {
final derivedKey = _kdf(sharedKeyChunk[index], ephemeralKeyChunk[index]);
final decryptedLocations = reports.mapIndexed((index, report) {
final derivedKey = _kdf(sharedKeys[index], ephemeralKeys[index]);
final payloadData = report.payload;
_decodeTimeAndConfidence(payloadData, report);
final encData = payloadData.sublist(62, 72);
Expand All @@ -59,7 +59,7 @@ class DecryptReports {
return locationReport;
}).toList();

return decryptedLocationChunk;
return decryptedLocations;
}

/// Decodes the unencrypted timestamp and confidence
Expand Down
25 changes: 5 additions & 20 deletions lib/findMy/find_my_controller.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import 'dart:collection';
import 'dart:convert';
import 'dart:isolate';
import 'dart:typed_data';
import 'dart:io' as IO;

import 'package:flutter/foundation.dart';
import 'package:flutter_secure_storage/flutter_secure_storage.dart';
Expand Down Expand Up @@ -38,18 +37,8 @@ class FindMyController {
FindMyKeyPair keyPair = args[0];
String seemooEndpoint = args[1];
final jsonReports = await ReportsFetcher.fetchLocationReports(keyPair.getHashedAdvertisementKey(), seemooEndpoint);
final numChunks = kIsWeb ? 1 : IO.Platform.numberOfProcessors+1;
final chunkSize = (jsonReports.length / numChunks).ceil();
final chunks = [
for (var i = 0; i < jsonReports.length; i += chunkSize)
jsonReports.sublist(i, i + chunkSize < jsonReports.length ? i + chunkSize : jsonReports.length),
];
final decryptedLocations = await Future.wait(chunks.map((jsonChunk) async {
final decryptedChunk = await compute(_decryptChunk, [jsonChunk, keyPair, keyPair.privateKeyBase64!]);
return decryptedChunk;
}));
final results = decryptedLocations.expand((element) => element).toList();
return results;
final decryptedLocations = await _decryptReports(jsonReports, keyPair, keyPair.privateKeyBase64!);
return decryptedLocations;
}

/// Loads the private key from the local cache or secure storage and adds it
Expand Down Expand Up @@ -77,12 +66,8 @@ class FindMyController {

/// Decrypts the encrypted reports with the given list of [FindMyKeyPair] and private key.
/// Returns the list of decrypted reports as a list of [FindMyLocationReport].
static Future<List<FindMyLocationReport>> _decryptChunk(List<dynamic> args) async {
List<dynamic> jsonChunk = args[0];
FindMyKeyPair keyPair = args[1];
String privateKey = args[2];

final reportChunk = jsonChunk.map((jsonReport) {
static Future<List<FindMyLocationReport>> _decryptReports(List<dynamic> jsonRerportList, FindMyKeyPair keyPair, String privateKey) async {
final reportChunk = jsonRerportList.map((jsonReport) {
assert (jsonReport["id"]! == keyPair.getHashedAdvertisementKey(),
"Returned FindMyReport hashed key != requested hashed key");

Expand All @@ -98,7 +83,7 @@ class FindMyController {
return report;
}).toList();

final decryptedReports = await DecryptReports.decryptReportChunk(reportChunk, base64Decode(privateKey));
final decryptedReports = await DecryptReports.decryptReports(reportChunk, base64Decode(privateKey));

return decryptedReports;
}
Expand Down
81 changes: 81 additions & 0 deletions native/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions native/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ crate-type = ["staticlib", "cdylib", "rlib"]
flutter_rust_bridge = "^1.77.0"
p224 = "^0.13.2"
getrandom = "^0.2.9"
rayon = "1.7.0"

[features]
default = ["p224/ecdh", "getrandom/js"]
Expand Down
29 changes: 15 additions & 14 deletions native/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,32 @@
use p224::{SecretKey, PublicKey, ecdh::diffie_hellman};
use rayon::prelude::*;
use std::sync::{Arc, Mutex};

const PRIVATE_LEN : usize = 28;
const PUBLIC_LEN : usize = 57;

pub fn ecdh(public_key_blob : Vec<u8>, private_key : Vec<u8>) -> Vec<u8> {
let num_keys = public_key_blob.len() / PUBLIC_LEN;
let mut vec_shared_secret = vec![0u8; num_keys*PRIVATE_LEN];
let vec_shared_secret = Arc::new(Mutex::new(vec![0u8; num_keys*PRIVATE_LEN]));

let private_key = SecretKey::from_slice(&private_key).unwrap();
let secret_scalar = private_key.to_nonzero_scalar();

let mut i = 0;
let mut j = 0;

for _i in 0..num_keys {
let public_key = PublicKey::from_sec1_bytes(&public_key_blob[i..i+PUBLIC_LEN]).unwrap();
(0..num_keys).into_par_iter().for_each(|i| {
let start = i * PUBLIC_LEN;
let end = start + PUBLIC_LEN;
let public_key = PublicKey::from_sec1_bytes(&public_key_blob[start..end]).unwrap();
let public_affine = public_key.as_affine();
let shared_secret = diffie_hellman(secret_scalar, public_affine);

let shared_secret = diffie_hellman(secret_scalar, public_affine);
let shared_secret_ref = shared_secret.raw_secret_bytes().as_ref();

let start = i * PRIVATE_LEN;
let end = start + PRIVATE_LEN;

vec_shared_secret[j..j+PRIVATE_LEN].copy_from_slice(shared_secret_ref);
let mut vec_shared_secret = vec_shared_secret.lock().unwrap();
vec_shared_secret[start..end].copy_from_slice(shared_secret_ref);
});

i += PUBLIC_LEN;
j += PRIVATE_LEN;
}

return vec_shared_secret;
Arc::try_unwrap(vec_shared_secret).unwrap().into_inner().unwrap()
}

0 comments on commit 49a4b2a

Please sign in to comment.