Skip to content

Commit

Permalink
Fix bugs in HMSS implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
renjiezh committed May 8, 2024
1 parent 652228f commit 7afb8f9
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ class Herald(
private val internalComputationsClient: ComputationsCoroutineStub,
private val systemComputationsClient: SystemComputationsCoroutineStub,
private val systemComputationParticipantClient: SystemComputationParticipantsCoroutineStub,
private val privateKeyStore: PrivateKeyStore<TinkKeyId, TinkPrivateKeyHandle>,
private val continuationTokenManager: ContinuationTokenManager,
private val protocolsSetupConfig: ProtocolsSetupConfig,
private val clock: Clock,
private val privateKeyStore: PrivateKeyStore<TinkKeyId, TinkPrivateKeyHandle>? = null,
private val blobStorageBucket: String = "computation-blob-storage",
private val maxAttempts: Int = 5,
private val maxStreamingAttempts: Int = 5,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ object HonestMajorityShareShuffleStarter {
suspend fun createComputation(
computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub,
systemComputation: Computation,
honestMajorityShareShuffleSetupConfig: HonestMajorityShareShuffleSetupConfig,
protocolSetupConfig: HonestMajorityShareShuffleSetupConfig,
blobStorageBucket: String,
privateKeyStore: PrivateKeyStore<TinkKeyId, TinkPrivateKeyHandle>,
privateKeyStore: PrivateKeyStore<TinkKeyId, TinkPrivateKeyHandle>? = null,
) {
require(systemComputation.name.isNotEmpty()) { "Resource name not specified" }
val globalId: String = systemComputation.key.computationId
val role = honestMajorityShareShuffleSetupConfig.role
val role = protocolSetupConfig.role

val initialComputationDetails = computationDetails {
blobsStoragePrefix = "$blobStorageBucket/$globalId"
Expand All @@ -76,11 +76,12 @@ object HonestMajorityShareShuffleStarter {
HonestMajorityShareShuffleKt.computationDetails {
this.role = role
parameters = systemComputation.toHonestMajorityShareShuffleParameters()
participants += systemComputation.computationParticipantsList.map { it.key.duchyId }
nonAggregators += getNonAggregators(protocolSetupConfig)
if (role != RoleInComputation.AGGREGATOR) {
randomSeed = generateRandomSeed()

val privateKeyHandle = TinkPrivateKeyHandle.generateHpke()
requireNotNull(privateKeyStore) { "privateKeyStore cannot be null" }
val privateKeyId = storePrivateKey(privateKeyStore, privateKeyHandle)
encryptionKeyPair = encryptionKeyPair {
this.privateKeyId = privateKeyId
Expand Down Expand Up @@ -197,6 +198,10 @@ object HonestMajorityShareShuffleStarter {
}
}

private fun getNonAggregators(setupConfig: HonestMajorityShareShuffleSetupConfig): List<String> {
return listOf(setupConfig.firstNonAggregatorDuchyId, setupConfig.secondNonAggregatorDuchyId)
}

private fun generateRandomSeed(): ByteString {
val secureRandom = SecureRandom()
return secureRandom.generateSeed(RANDOM_SEED_LENGTH_IN_BYTES).toByteString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/duchy/db/computation",
"//src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha:advance_computation_request_headers",
"//src/main/kotlin/org/wfanet/measurement/system/v1alpha:resource_key",
"//src/main/proto/wfa/any_sketch:frequency_vector_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc",
"//src/main/proto/wfa/measurement/internal/duchy:crypto_kt_jvm_proto",
"//src/main/proto/wfa/measurement/internal/duchy:differential_privacy_kt_jvm_proto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ import java.time.Clock
import java.time.Duration
import java.util.logging.Logger
import kotlinx.coroutines.flow.flowOf
import org.wfanet.frequencycount.FrequencyVector
import org.wfanet.measurement.api.Version
import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt
import org.wfanet.measurement.api.v2alpha.EncryptedMessage
import org.wfanet.measurement.api.v2alpha.MeasurementSpec
import org.wfanet.measurement.api.v2alpha.RandomSeed
import org.wfanet.measurement.api.v2alpha.SignedMessage
import org.wfanet.measurement.api.v2alpha.encryptedMessage
import org.wfanet.measurement.api.v2alpha.getCertificateRequest
import org.wfanet.measurement.api.v2alpha.unpack
import org.wfanet.measurement.common.ProtoReflection
import org.wfanet.measurement.common.crypto.PrivateKeyStore
import org.wfanet.measurement.common.crypto.SigningKeyHandle
import org.wfanet.measurement.common.crypto.authorityKeyIdentifier
Expand All @@ -50,12 +53,15 @@ import org.wfanet.measurement.duchy.daemon.utils.ReachAndFrequencyResult
import org.wfanet.measurement.duchy.daemon.utils.toV2AlphaEncryptionPublicKey
import org.wfanet.measurement.duchy.db.computation.ComputationDataClients
import org.wfanet.measurement.duchy.db.computation.ComputationDataClients.PermanentErrorException
import org.wfanet.measurement.duchy.service.internal.computations.inputPathList
import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHeader
import org.wfanet.measurement.duchy.toProtocolStage
import org.wfanet.measurement.internal.duchy.ComputationDetails
import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt
import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType
import org.wfanet.measurement.internal.duchy.config.HonestMajorityShareShuffleSetupConfig
import org.wfanet.measurement.internal.duchy.config.RoleInComputation
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.FIRST_NON_AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.SECOND_NON_AGGREGATOR
Expand All @@ -70,9 +76,9 @@ import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.ShufflePhaseInputKt
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.aggregationPhaseInput
import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.shufflePhaseInput
import org.wfanet.measurement.internal.duchy.protocol.ShareShuffleSketch
import org.wfanet.measurement.internal.duchy.protocol.completeAggregationPhaseRequest
import org.wfanet.measurement.internal.duchy.protocol.completeShufflePhaseRequest
import org.wfanet.measurement.internal.duchy.protocol.copy
import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineStub
import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt
import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey
Expand All @@ -93,12 +99,13 @@ class HonestMajorityShareShuffleMill(
systemComputationsClient: ComputationsGrpcKt.ComputationsCoroutineStub,
systemComputationLogEntriesClient: ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub,
computationStatsClient: ComputationStatsGrpcKt.ComputationStatsCoroutineStub,
private val privateKeyStore: PrivateKeyStore<TinkKeyId, TinkPrivateKeyHandle>,
private val certificateClient: CertificatesGrpcKt.CertificatesCoroutineStub,
private val workerStubs: Map<String, ComputationControlCoroutineStub>,
private val cryptoWorker: HonestMajorityShareShuffleCryptor,
private val protocolSetupConfig: HonestMajorityShareShuffleSetupConfig,
workLockDuration: Duration,
openTelemetry: OpenTelemetry,
private val privateKeyStore: PrivateKeyStore<TinkKeyId, TinkPrivateKeyHandle>? = null,
requestChunkSizeBytes: Int = 1024 * 32,
maximumAttempts: Int = 10,
clock: Clock = Clock.systemUTC(),
Expand Down Expand Up @@ -216,21 +223,29 @@ class HonestMajorityShareShuffleMill(
)
}

private fun peerDuchyStub(participants: List<String>): ComputationControlCoroutineStub {
// TODO(@renjiez): remove participants fields and use a map to get stub for specific
// RoleInComputation.

// The last participant is the aggregator.
val peerDuchy = participants.dropLast(1).find { it != duchyId }
private fun peerDuchyStub(role: RoleInComputation): ComputationControlCoroutineStub {
val peerDuchy =
when (role) {
FIRST_NON_AGGREGATOR -> {
protocolSetupConfig.secondNonAggregatorDuchyId
}
SECOND_NON_AGGREGATOR -> {
protocolSetupConfig.firstNonAggregatorDuchyId
}
AGGREGATOR,
RoleInComputation.NON_AGGREGATOR,
RoleInComputation.ROLE_IN_COMPUTATION_UNSPECIFIED,
RoleInComputation.UNRECOGNIZED -> error("Unexpected role:$role for peerDuchyStub")
}
return workerStubs[peerDuchy]
?: throw PermanentErrorException(
"No ComputationControlService stub for the peer duchy '$peerDuchy'"
)
}

private fun aggregatorStub(participants: List<String>): ComputationControlCoroutineStub {
private fun aggregatorStub(): ComputationControlCoroutineStub {
// The last participant is the aggregator.
val aggregator = participants.last()
val aggregator = protocolSetupConfig.aggregatorDuchyId
return workerStubs[aggregator]
?: throw PermanentErrorException(
"No ComputationControlService stub for aggregator '$aggregator'"
Expand Down Expand Up @@ -271,12 +286,21 @@ class HonestMajorityShareShuffleMill(
sendAdvanceComputationRequest(
header = advanceComputationHeader(headerDescription, token.globalComputationId),
content = addLoggingHook(token, flowOf(shufflePhaseInput.toByteString())),
stub = peerDuchyStub(hmssDetails.participantsList),
stub = peerDuchyStub(hmssDetails.role),
)

val nextStage = nextStage(token).toProtocolStage()
return dataClients.transitionComputationToStage(
token,
stage = nextStage(token).toProtocolStage(),
// For worker2, the input of SETUP_PHASE is the ShufflePhaseInput from peer worker. It should
// be forwarded to SHUFFLE_PHASE.
inputsToNextStage =
if (nextStage == Stage.SHUFFLE_PHASE.toProtocolStage()) {
token.inputPathList()
} else {
emptyList()
},
)
}

Expand All @@ -301,17 +325,25 @@ class HonestMajorityShareShuffleMill(
private suspend fun verifySecretSeed(
secretSeed: ShufflePhaseInput.SecretSeed,
duchyPrivateKeyId: String,
apiVersion: Version,
): RandomSeed {
requireNotNull(privateKeyStore) { "privateKeyStore is null for non-aggregator." }
val privateKey =
privateKeyStore.read(TinkKeyId(duchyPrivateKeyId.toLong()))
?: throw PermanentErrorException(
"Fail to get private key for requisition ${secretSeed.requisitionId}"
)

val encryptedAndSignedMessage = EncryptedMessage.parseFrom(secretSeed.secretSeedCiphertext)
val signedMessage = decryptRandomSeed(encryptedAndSignedMessage, privateKey)

val randomSeed: RandomSeed = signedMessage.unpack()
val encryptedAndSignedSeed =
when (apiVersion) {
Version.V2_ALPHA -> {
encryptedMessage {
ciphertext = secretSeed.secretSeedCiphertext
typeUrl = ProtoReflection.getTypeUrl(SignedMessage.getDescriptor())
}
}
}
val signedSeed = decryptRandomSeed(encryptedAndSignedSeed, privateKey)

val dataProviderCertificateName = secretSeed.dataProviderCertificate
val dataProviderCertificate =
Expand All @@ -331,19 +363,18 @@ class HonestMajorityShareShuffleMill(
)

try {
verifyRandomSeed(signedMessage, x509Certificate, trustedIssuer)
verifyRandomSeed(signedSeed, x509Certificate, trustedIssuer)
} catch (e: CertPathValidatorException) {
throw PermanentErrorException("Invalid certificate for $dataProviderCertificateName", e)
} catch (e: SignatureException) {
throw PermanentErrorException("Signature fails verification.", e)
}

return randomSeed
return signedSeed.unpack()
}

private suspend fun shufflePhase(token: ComputationToken): ComputationToken {
val requisitions = token.requisitionsList
requisitions.sortedBy { it.externalKey.externalRequisitionId }
val requisitions = token.requisitionsList.sortedBy { it.externalKey.externalRequisitionId }

val requisitionBlobs = dataClients.readRequisitionBlobs(token)
val shufflePhaseInput = getShufflePhaseInput(token)
Expand All @@ -365,31 +396,38 @@ class HonestMajorityShareShuffleMill(
}
noiseMechanism = hmss.parameters.noiseMechanism

val registerCounts = mutableListOf<Long>()
for (requisition in requisitions) {
val requisitionId = requisition.externalKey.externalRequisitionId

val blob = requisitionBlobs[requisitionId]
if (blob != null) {
// Requisition in format of blob.
registerCounts += requisition.details.honestMajorityShareShuffle.registerCount
sketchShares += sketchShare {
data =
CompleteShufflePhaseRequestKt.SketchShareKt.shareData {
// TODO(@renjiez): Use ShareShuffleSketch from any-sketch-java when it is available.
values += ShareShuffleSketch.parseFrom(blob).dataList
values += FrequencyVector.parseFrom(blob).dataList
}
}
} else {
// Requisition in format of random seed.
val secretSeed = secretSeeds.find { it.requisitionId == requisitionId }
require(secretSeed != null) {
"Neither blob and seed received for requisition $requisitionId"
}
val secretSeed =
secretSeeds.find { it.requisitionId == requisitionId }
?: error("Neither blob and seed received for requisition $requisitionId")

val seed = verifySecretSeed(secretSeed, hmss.encryptionKeyPair.privateKeyId)
val publicApiVersion =
Version.fromString(token.computationDetails.kingdomComputation.publicApiVersion)
val seed =
verifySecretSeed(secretSeed, hmss.encryptionKeyPair.privateKeyId, publicApiVersion)

sketchShares += sketchShare { this.seed = seed.data }
}
}
require(registerCounts.distinct().size == 1) {
"All RegisterCount from requisitions must be the same. $registerCounts"
}
sketchParams = hmss.parameters.sketchParams.copy { registerCount = registerCounts.first() }
}

val result = cryptoWorker.completeShufflePhase(request)
Expand All @@ -402,7 +440,7 @@ class HonestMajorityShareShuffleMill(
header =
advanceComputationHeader(Description.AGGREGATION_PHASE_INPUT, token.globalComputationId),
content = addLoggingHook(token, flowOf(aggregationPhaseInput.toByteString())),
stub = aggregatorStub(token.computationDetails.honestMajorityShareShuffle.participantsList),
stub = aggregatorStub(),
)

return completeComputation(token, ComputationDetails.CompletedReason.SUCCEEDED)
Expand Down Expand Up @@ -453,6 +491,7 @@ class HonestMajorityShareShuffleMill(
mapOf(
Pair(Stage.INITIALIZED, FIRST_NON_AGGREGATOR) to Stage.WAIT_TO_START,
Pair(Stage.INITIALIZED, SECOND_NON_AGGREGATOR) to Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE,
Pair(Stage.INITIALIZED, AGGREGATOR) to Stage.WAIT_ON_AGGREGATION_INPUT,
Pair(Stage.SETUP_PHASE, FIRST_NON_AGGREGATOR) to Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO,
Pair(Stage.SETUP_PHASE, SECOND_NON_AGGREGATOR) to Stage.SHUFFLE_PHASE,
Pair(Stage.SHUFFLE_PHASE, FIRST_NON_AGGREGATOR) to Stage.COMPLETE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ object HonestMajorityShareShuffleProtocol {
computationStageDetails {
honestMajorityShareShuffle = stageDetails {
waitOnAggregationInputDetails = waitOnAggregationInputDetails {
val participants = computationDetails.participantsList
val nonAggregators = participants.subList(0, participants.size - 1)
nonAggregators.mapIndexed { idx, duchyId ->
computationDetails.nonAggregatorsList.mapIndexed { idx, duchyId ->
externalDuchyLocalBlobId[duchyId] = idx.toLong()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package org.wfanet.measurement.duchy.service.internal.computations
import org.wfanet.measurement.internal.duchy.AdvanceComputationStageResponse
import org.wfanet.measurement.internal.duchy.ClaimWorkResponse
import org.wfanet.measurement.internal.duchy.ComputationBlobDependency
import org.wfanet.measurement.internal.duchy.ComputationBlobDependency.INPUT
import org.wfanet.measurement.internal.duchy.ComputationBlobDependency.OUTPUT
import org.wfanet.measurement.internal.duchy.ComputationBlobDependency.PASS_THROUGH
import org.wfanet.measurement.internal.duchy.ComputationStageBlobMetadata
Expand Down Expand Up @@ -70,6 +71,10 @@ fun ComputationToken.outputPathList(): List<String> =
.filter { it.dependencyType == OUTPUT || it.dependencyType == PASS_THROUGH }
.map { it.path }

/** Extract the list of input blob paths from a [ComputationToken]. */
fun ComputationToken.inputPathList(): List<String> =
this.blobsList.filter { it.dependencyType == INPUT }.map { it.path }

/** Creates a [ComputationStageBlobMetadata] for an input blob. */
fun newInputBlobMetadata(id: Long, key: String): ComputationStageBlobMetadata =
ComputationStageBlobMetadata.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ fun buildMpcProtocolConfig(
mpcProtocolConfig {
honestMajorityShareShuffle = honestMajorityShareShuffle {
sketchParams = shareShuffleSketchParams {
registerCount = protocolConfig.honestMajorityShareShuffle.sketchParams.registerCount
bytesPerRegister =
protocolConfig.honestMajorityShareShuffle.sketchParams.bytesPerRegister
ringModulus = protocolConfig.honestMajorityShareShuffle.sketchParams.ringModulus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,12 @@ message HonestMajorityShareShuffleSetupConfig {
// Role of duchy in Computations.
RoleInComputation role = 1;

// The external id of the aggregator duchy.
string external_aggregator_duchy_id = 2;
// The external id of the first non-aggregator duchy.
string first_non_aggregator_duchy_id = 2;

// The external id of the peer non-aggregator duchy.
string external_peer_non_aggregator_duchy_id = 3;
// The external id of the second non-aggregator duchy.
string second_non_aggregator_duchy_id = 3;

// The external id of the aggregator duchy.
string aggregator_duchy_id = 4;
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,8 @@ message HonestMajorityShareShuffle {
}
Parameters parameters = 2;

// The list of ids of duchies participating in this computation.
// The list is sorted by the duchy order by names, with the first element
// being the first non-aggregator following by other non-aggregators,
// and the last element being the aggregator.
repeated string participants = 3;
// The list of ids of non-aggregators participating in this computation.
repeated string non_aggregators = 3;

// Seed generated by this worker and will be used to generate noise and
// permutation.
Expand Down
Loading

0 comments on commit 7afb8f9

Please sign in to comment.