Skip to content

Commit

Permalink
feat: proof parallelization
Browse files Browse the repository at this point in the history
- [x] Prepare circuit inputs and run all proofs async
- [x] Minor optimization for MACI contract
  • Loading branch information
0xmad committed Jul 3, 2024
1 parent 89fef5a commit 7ebf165
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 27 deletions.
5 changes: 4 additions & 1 deletion contracts/contracts/MACI.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ contract MACI is IMACI, DomainObjs, Params, Utilities {
/// if we change the state tree depth!
uint8 public immutable stateTreeDepth;

uint256 public immutable signUpsLimit;

uint8 internal constant TREE_ARITY = 2;
uint8 internal constant MESSAGE_TREE_ARITY = 5;

Expand Down Expand Up @@ -112,6 +114,7 @@ contract MACI is IMACI, DomainObjs, Params, Utilities {
signUpGatekeeper = _signUpGatekeeper;
initialVoiceCreditProxy = _initialVoiceCreditProxy;
stateTreeDepth = _stateTreeDepth;
signUpsLimit = uint256(TREE_ARITY) ** uint256(_stateTreeDepth);

// Verify linked poseidon libraries
if (hash2([uint256(1), uint256(1)]) == 0) revert PoseidonHashLibrariesNotLinked();
Expand All @@ -135,7 +138,7 @@ contract MACI is IMACI, DomainObjs, Params, Utilities {
bytes memory _initialVoiceCreditProxyData
) public virtual {
// ensure we do not have more signups than what the circuits support
if (lazyIMTData.numberOfLeaves >= uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups();
if (lazyIMTData.numberOfLeaves >= signUpsLimit) revert TooManySignups();

// ensure that the public key is on the baby jubjub curve
if (!CurveBabyJubJub.isOnCurve(_pubKey.x, _pubKey.y)) {
Expand Down
41 changes: 27 additions & 14 deletions contracts/tasks/helpers/ProofGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,13 @@ export class ProofGenerator {
maciPrivateKey,
coordinatorKeypair,
signer,
outputDir,
options: { transactionHash, stateFile, startBlock, endBlock, blocksPerBatch },
}: IPrepareStateParams): Promise<MaciState> {
if (!fs.existsSync(path.resolve(outputDir))) {
await fs.promises.mkdir(path.resolve(outputDir));
}

if (stateFile) {
const content = JSON.parse(fs.readFileSync(stateFile).toString()) as unknown as IJsonMaciState;
const serializedPrivateKey = maciPrivateKey.serialize();
Expand Down Expand Up @@ -175,7 +180,6 @@ export class ProofGenerator {
performance.mark("mp-proofs-start");

console.log(`Generating proofs of message processing...`);
const proofs: Proof[] = [];
const { messageBatchSize } = this.poll.batchSizes;
const numMessages = this.poll.messages.length;
let totalMessageBatches = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize);
Expand All @@ -184,6 +188,8 @@ export class ProofGenerator {
totalMessageBatches += 1;
}

const inputs: CircuitInputs[] = [];

// while we have unprocessed messages, process them
while (this.poll.hasUnprocessedMessages()) {
// process messages in batches
Expand All @@ -193,14 +199,19 @@ export class ProofGenerator {
) as unknown as CircuitInputs;

// generate the proof for this batch
// eslint-disable-next-line no-await-in-loop
await this.generateProofs(circuitInputs, this.mp, `process_${this.poll.numBatchesProcessed - 1}.json`).then(
(data) => proofs.push(...data),
);
inputs.push(circuitInputs);

console.log(`Progress: ${this.poll.numBatchesProcessed} / ${totalMessageBatches}`);
}

console.log("Wait until proof generation is finished");

const proofs = await Promise.all(
inputs.map((circuitInputs, index) => this.generateProofs(circuitInputs, this.mp, `process_${index}.json`)),
).then((data) => data.reduce((acc, x) => acc.concat(x), []));

console.log("Proof generation is finished");

performance.mark("mp-proofs-end");
performance.measure("Generate message processor proofs", "mp-proofs-start", "mp-proofs-end");

Expand All @@ -217,7 +228,6 @@ export class ProofGenerator {
performance.mark("tally-proofs-start");

console.log(`Generating proofs of vote tallying...`);
const proofs: Proof[] = [];
const { tallyBatchSize } = this.poll.batchSizes;
const numStateLeaves = this.poll.stateLeaves.length;
let totalTallyBatches = numStateLeaves <= tallyBatchSize ? 1 : Math.floor(numStateLeaves / tallyBatchSize);
Expand All @@ -226,19 +236,26 @@ export class ProofGenerator {
}

let tallyCircuitInputs: CircuitInputs;
const inputs: CircuitInputs[] = [];

while (this.poll.hasUntalliedBallots()) {
tallyCircuitInputs = (this.useQuadraticVoting
? this.poll.tallyVotes()
: this.poll.tallyVotesNonQv()) as unknown as CircuitInputs;

// eslint-disable-next-line no-await-in-loop
await this.generateProofs(tallyCircuitInputs, this.tally, `tally_${this.poll.numBatchesTallied - 1}.json`).then(
(data) => proofs.push(...data),
);
inputs.push(tallyCircuitInputs);

console.log(`Progress: ${this.poll.numBatchesTallied} / ${totalTallyBatches}`);
}

console.log("Wait until proof generation is finished");

const proofs = await Promise.all(
inputs.map((circuitInputs, index) => this.generateProofs(circuitInputs, this.tally, `tally_${index}.json`)),
).then((data) => data.reduce((acc, x) => acc.concat(x), []));

console.log("Proof generation is finished");

// verify the results
// Compute newResultsCommitment
const newResultsCommitment = genTreeCommitment(
Expand Down Expand Up @@ -359,10 +376,6 @@ export class ProofGenerator {
publicInputs: publicSignals,
});

if (!fs.existsSync(path.resolve(this.outputDir))) {
await fs.promises.mkdir(path.resolve(this.outputDir));
}

await fs.promises.writeFile(
path.resolve(this.outputDir, outputFile),
JSON.stringify(proofs[proofs.length - 1], null, 4),
Expand Down
5 changes: 5 additions & 0 deletions contracts/tasks/helpers/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ export interface IPrepareStateParams {
*/
signer: Signer;

/**
* The directory to store the proofs
*/
outputDir: string;

/**
* Options for state (on-chain fetching or local file)
*/
Expand Down
1 change: 1 addition & 0 deletions contracts/tasks/runner/prove.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ task("prove", "Command to generate proof and prove the result of a poll on-chain
coordinatorKeypair,
pollId: poll,
signer,
outputDir,
options: {
stateFile,
transactionHash,
Expand Down
23 changes: 12 additions & 11 deletions coordinator/ts/proof/proof.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,31 @@ export class ProofGeneratorService {
address: maciContractAddress,
});

const signer = await this.deployment.getDeployer();
const pollAddress = await maciContract.polls(poll);
const [signer, pollAddress] = await Promise.all([this.deployment.getDeployer(), maciContract.polls(poll)]);

if (pollAddress.toLowerCase() === ZeroAddress.toLowerCase()) {
this.logger.error(`Error: ${ErrorCodes.POLL_NOT_FOUND}, Poll ${poll} not found`);
throw new Error(ErrorCodes.POLL_NOT_FOUND);
}

const pollContract = await this.deployment.getContract<Poll>({ name: EContracts.Poll, address: pollAddress });
const [{ messageAq: messageAqAddress }, coordinatorPublicKey] = await Promise.all([
pollContract.extContracts(),
pollContract.coordinatorPubKey(),
]);
const [{ messageAq: messageAqAddress }, coordinatorPublicKey, isStateAqMerged, messageTreeDepth] =
await Promise.all([
pollContract.extContracts(),
pollContract.coordinatorPubKey(),
pollContract.stateMerged(),
pollContract.treeDepths().then((depths) => Number(depths[2])),
]);
const messageAq = await this.deployment.getContract<AccQueue>({
name: EContracts.AccQueue,
address: messageAqAddress,
});

const isStateAqMerged = await pollContract.stateMerged();

if (!isStateAqMerged) {
this.logger.error(`Error: ${ErrorCodes.NOT_MERGED_STATE_TREE}, state tree is not merged`);
throw new Error(ErrorCodes.NOT_MERGED_STATE_TREE);
}

const messageTreeDepth = await pollContract.treeDepths().then((depths) => Number(depths[2]));

const mainRoot = await messageAq.getMainRoot(messageTreeDepth.toString());

if (mainRoot.toString() === "0") {
Expand All @@ -108,6 +106,8 @@ export class ProofGeneratorService {
throw new Error(ErrorCodes.PRIVATE_KEY_MISMATCH);
}

const outputDir = path.resolve("./proofs");

const maciState = await ProofGenerator.prepareState({
maciContract,
pollContract,
Expand All @@ -116,6 +116,7 @@ export class ProofGeneratorService {
coordinatorKeypair,
pollId: poll,
signer,
outputDir,
options: {
startBlock,
endBlock,
Expand All @@ -137,7 +138,7 @@ export class ProofGeneratorService {
tally: this.fileService.getZkeyFilePaths(process.env.COORDINATOR_TALLY_ZKEY_NAME!, useQuadraticVoting),
mp: this.fileService.getZkeyFilePaths(process.env.COORDINATOR_MESSAGE_PROCESS_ZKEY_NAME!, useQuadraticVoting),
rapidsnark: process.env.COORDINATOR_RAPIDSNARK_EXE,
outputDir: path.resolve("./proofs"),
outputDir,
tallyOutputFile: path.resolve("./tally.json"),
useQuadraticVoting,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ function signUp(
bytes memory _initialVoiceCreditProxyData
) public virtual {
// ensure we do not have more signups than what the circuits support
if (lazyIMTData.numberOfLeaves >= uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups();
if (lazyIMTData.numberOfLeaves >= signUpsLimit) revert TooManySignups();

// ensure that the public key is on the baby jubjub curve
if (!CurveBabyJubJub.isOnCurve(_pubKey.x, _pubKey.y)) {
Expand Down

0 comments on commit 7ebf165

Please sign in to comment.