Skip to content

Commit

Permalink
fix(circuits): enforce use of stateIndex from message
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrlc03 committed Feb 8, 2024
1 parent e38eb77 commit efc1626
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 64 deletions.
32 changes: 25 additions & 7 deletions circuits/circom/processMessages.circom
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ template ProcessMessages(
component processors[batchSize];
// topup type processor
component processors2[batchSize];

for (var i = batchSize - 1; i >= 0; i --) {
// process it as vote type message
processors[i] = ProcessOne(stateTreeDepth, voteOptionTreeDepth);
Expand Down Expand Up @@ -349,16 +350,20 @@ template ProcessMessages(
<== currentStateLeavesPathElements[i][j][k];
}
}


// pick the correct result by msg type
tmpStateRoot1[i] <== processors[i].newStateRoot * (2 - msgs[i][0]);
tmpStateRoot2[i] <== processors2[i].newStateRoot * (msgs[i][0] - 1);
tmpBallotRoot1[i] <== processors[i].newBallotRoot * (2 - msgs[i][0]);
tmpBallotRoot2[i] <== ballotRoots[i + 1] * (msgs[i][0] - 1);
stateRoots[i] <== tmpStateRoot1[i] + tmpStateRoot2[i];
ballotRoots[i] <== tmpBallotRoot1[i] + tmpBallotRoot2[i];

}

component sbCommitmentHasher = Hasher3();

sbCommitmentHasher.in[0] <== stateRoots[0];
sbCommitmentHasher.in[1] <== ballotRoots[0];
sbCommitmentHasher.in[2] <== newSbSalt;
Expand Down Expand Up @@ -526,22 +531,34 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) {

// -----------------------------------------------------------------------
// 2. If msgType = 0 and isValid is 0, generate indices for leaf 0
// Otherwise, generate indices for commmand.stateIndex or topupStateIndex depending on msgType
// Otherwise, generate indices for commmand.stateIndex or topupStateIndex depending on msgType
signal indexByType;
signal tmpIndex1;
signal tmpIndex2;
tmpIndex1 <== cmdStateIndex * (2 - msgType);
tmpIndex2 <== topupStateIndex * (msgType - 1);
indexByType <== tmpIndex1 + tmpIndex2;

component stateIndexMux = Mux1();
stateIndexMux.s <== transformer.isValid + msgType - 1;
stateIndexMux.c[0] <== 0;
stateIndexMux.c[1] <== indexByType;
// we can validate if the state index is within the numSignups
// if not, we use 0
// this is because decryption of an invalid message
// might result in random packed vals
component validStateLeafIndex = SafeLessThan(252);
validStateLeafIndex.in[0] <== indexByType;
validStateLeafIndex.in[1] <== numSignUps;

component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth);
stateLeafPathIndices.in <== stateIndexMux.out;
// use a mux to pick the correct index
component indexMux = Mux1();
indexMux.s <== validStateLeafIndex.out;
indexMux.c[0] <== 0;
indexMux.c[1] <== indexByType;

// @note that we expect a coordinator to send the state leaf corresponding to a message
// which specifies a valid state index. If this is not the case, the
// proof will fail to generate.
component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth);
stateLeafPathIndices.in <== indexMux.out;

// -----------------------------------------------------------------------
// 3. Verify that the original state leaf exists in the given state root
component stateLeafQip = QuinTreeInclusionProof(stateTreeDepth);
Expand Down Expand Up @@ -572,6 +589,7 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) {
ballotQip.path_elements[i][j] <== ballotPathElements[i][j];
}
}

ballotQip.root === currentBallotRoot;

// -----------------------------------------------------------------------
Expand Down
24 changes: 18 additions & 6 deletions circuits/circom/processMessagesNonQv.circom
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,25 @@ template ProcessOneNonQv(stateTreeDepth, voteOptionTreeDepth) {
tmpIndex2 <== topupStateIndex * (msgType - 1);
indexByType <== tmpIndex1 + tmpIndex2;

component stateIndexMux = Mux1();
stateIndexMux.s <== transformer.isValid + msgType - 1;
stateIndexMux.c[0] <== 0;
stateIndexMux.c[1] <== indexByType;

// we can validate if the state index is within the numSignups
// if not, we use 0
// this is because decryption of an invalid message
// might result in random packed vals
component validStateLeafIndex = SafeLessThan(252);
validStateLeafIndex.in[0] <== indexByType;
validStateLeafIndex.in[1] <== numSignUps;

// use a mux to pick the correct index
component indexMux = Mux1();
indexMux.s <== validStateLeafIndex.out;
indexMux.c[0] <== 0;
indexMux.c[1] <== indexByType;

// @note that we expect a coordinator to send the state leaf corresponding to a message
// which specifies a valid state index. If this is not the case, the
// proof will fail to generate.
component stateLeafPathIndices = QuinGeneratePathIndices(stateTreeDepth);
stateLeafPathIndices.in <== stateIndexMux.out;
stateLeafPathIndices.in <== indexMux.out;

// -----------------------------------------------------------------------
// 3. Verify that the original state leaf exists in the given state root
Expand Down
2 changes: 1 addition & 1 deletion circuits/circom/stateLeafAndBallotTransformer.circom
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ template StateLeafAndBallotTransformer() {
messageValidator.voteWeight <== cmdNewVoteWeight;

// if the message is valid then we swap out the public key
// we have to do this in two Mux one for pucKey[0]
// we have to do this in two Mux one for pubKey[0]
// and one for pubKey[1]
component newSlPubKey0Mux = Mux1();
newSlPubKey0Mux.s <== messageValidator.isValid;
Expand Down
31 changes: 26 additions & 5 deletions circuits/ts/__tests__/ProcessMessages.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { expect } from "chai";
import { type WitnessTester } from "circomkit";
import { MaciState, Poll, packProcessMessageSmallVals, STATE_TREE_ARITY } from "maci-core";
import { hash5, IncrementalQuinTree, NOTHING_UP_MY_SLEEVE, AccQueue } from "maci-crypto";
import { PrivKey, Keypair, PCommand, Message, Ballot } from "maci-domainobjs";
import { PrivKey, Keypair, PCommand, Message, Ballot, PubKey } from "maci-domainobjs";

import { IProcessMessagesInputs } from "../types";

Expand Down Expand Up @@ -76,7 +76,7 @@ describe("ProcessMessage circuit", function test() {
describe("1 user, 2 messages", () => {
const maciState = new MaciState(STATE_TREE_DEPTH);
const voteWeight = BigInt(9);
const voteOptionIndex = BigInt(0);
const voteOptionIndex = BigInt(1);
let stateIndex: bigint;
let pollId: bigint;
let poll: Poll;
Expand All @@ -101,6 +101,26 @@ describe("ProcessMessage circuit", function test() {
poll = maciState.polls.get(pollId)!;
poll.updatePoll(BigInt(maciState.stateLeaves.length));

const nothing = new Message(1n, [
8370432830353022751713833565135785980866757267633941821328460903436894336785n,
0n,
0n,
0n,
0n,
0n,
0n,
0n,
0n,
0n,
]);

const encP = new PubKey([
10457101036533406547632367118273992217979173478358440826365724437999023779287n,
19824078218392094440610104313265183977899662750282163392862422243483260492317n,
]);

poll.publishMessage(nothing, encP);

// First command (valid)
const command = new PCommand(
stateIndex, // BigInt(1),
Expand Down Expand Up @@ -144,6 +164,7 @@ describe("ProcessMessage circuit", function test() {
STATE_TREE_ARITY,
NOTHING_UP_MY_SLEEVE,
);
accumulatorQueue.enqueue(nothing.hash(encP));
accumulatorQueue.enqueue(message.hash(ecdhKeypair.pubKey));
accumulatorQueue.enqueue(message2.hash(ecdhKeypair2.pubKey));
accumulatorQueue.mergeSubRoots(0);
Expand Down Expand Up @@ -187,7 +208,7 @@ describe("ProcessMessage circuit", function test() {
BigInt(maxValues.maxVoteOptions),
BigInt(poll.maciStateRef.numSignUps),
0,
2,
3,
);

// Test the ProcessMessagesInputHasher circuit
Expand Down Expand Up @@ -554,7 +575,7 @@ describe("ProcessMessage circuit", function test() {

// Second batch is not a full batch
const numMessages = messageBatchSize * NUM_BATCHES - 1;
for (let i = 0; i < numMessages; i += 1) {
for (let i = 0; i < 6; i += 1) {
const command = new PCommand(
BigInt(index),
userKeypair.pubKey,
Expand All @@ -572,7 +593,7 @@ describe("ProcessMessage circuit", function test() {
selectedPoll?.publishMessage(message, ecdhKeypair.pubKey);
}

for (let i = 0; i < NUM_BATCHES; i += 1) {
for (let i = 0; i < 2; i += 1) {
const inputs = selectedPoll?.processMessages(id) as unknown as IProcessMessagesInputs;
// eslint-disable-next-line no-await-in-loop
const witness = await circuit.calculateWitness(inputs);
Expand Down
46 changes: 34 additions & 12 deletions cli/tests/e2e/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import { cleanVanilla, isArm } from "../utils";
/**
Test scenarios:
1 signup, 1 message
4 signups, 6 messages
4 signups, 8 messages
5 signups, 1 message
8 signups, 10 messages
4 signups, 4 messages
Expand Down Expand Up @@ -190,7 +190,7 @@ describe("e2e tests", function test() {
});
});

describe("4 signups, 6 messages", () => {
describe("4 signups, 8 messages", () => {
after(() => {
cleanVanilla();
});
Expand All @@ -212,7 +212,29 @@ describe("e2e tests", function test() {
}
});

it("should publish six messages", async () => {
it("should publish eight messages", async () => {
await publish({
pubkey: users[0].pubKey.serialize(),
stateIndex: 1n,
voteOptionIndex: 0n,
nonce: 2n,
pollId: 0n,
newVoteWeight: 4n,
maciContractAddress: maciAddresses.maciAddress,
salt: genRandomSalt(),
privateKey: users[0].privKey.serialize(),
});
await publish({
pubkey: users[0].pubKey.serialize(),
stateIndex: 1n,
voteOptionIndex: 0n,
nonce: 2n,
pollId: 0n,
newVoteWeight: 3n,
maciContractAddress: maciAddresses.maciAddress,
salt: genRandomSalt(),
privateKey: users[0].privKey.serialize(),
});
await publish({
pubkey: users[0].pubKey.serialize(),
stateIndex: 1n,
Expand All @@ -227,7 +249,7 @@ describe("e2e tests", function test() {
await publish({
pubkey: users[1].pubKey.serialize(),
stateIndex: 2n,
voteOptionIndex: 0n,
voteOptionIndex: 2n,
nonce: 1n,
pollId: 0n,
newVoteWeight: 9n,
Expand All @@ -238,7 +260,7 @@ describe("e2e tests", function test() {
await publish({
pubkey: users[2].pubKey.serialize(),
stateIndex: 3n,
voteOptionIndex: 0n,
voteOptionIndex: 2n,
nonce: 1n,
pollId: 0n,
newVoteWeight: 9n,
Expand All @@ -249,29 +271,29 @@ describe("e2e tests", function test() {
await publish({
pubkey: users[3].pubKey.serialize(),
stateIndex: 4n,
voteOptionIndex: 0n,
nonce: 1n,
voteOptionIndex: 2n,
nonce: 3n,
pollId: 0n,
newVoteWeight: 9n,
newVoteWeight: 3n,
maciContractAddress: maciAddresses.maciAddress,
salt: genRandomSalt(),
privateKey: users[3].privKey.serialize(),
});
await publish({
pubkey: users[3].pubKey.serialize(),
stateIndex: 4n,
voteOptionIndex: 0n,
nonce: 1n,
voteOptionIndex: 2n,
nonce: 2n,
pollId: 0n,
newVoteWeight: 9n,
newVoteWeight: 2n,
maciContractAddress: maciAddresses.maciAddress,
salt: genRandomSalt(),
privateKey: users[3].privKey.serialize(),
});
await publish({
pubkey: users[3].pubKey.serialize(),
stateIndex: 4n,
voteOptionIndex: 0n,
voteOptionIndex: 1n,
nonce: 1n,
pollId: 0n,
newVoteWeight: 9n,
Expand Down
4 changes: 3 additions & 1 deletion cli/ts/commands/genProofs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ export const genProofs = async ({
while (poll.hasUnprocessedMessages()) {
// process messages in batches
const circuitInputs = poll.processMessages(pollId, useQuadraticVoting, quiet) as unknown as CircuitInputs;

try {
// generate the proof for this batch
// eslint-disable-next-line no-await-in-loop
Expand All @@ -297,11 +298,12 @@ export const genProofs = async ({
witnessExePath: processWitgen,
wasmPath: processWasm,
});

// verify it
// eslint-disable-next-line no-await-in-loop
const isValid = await verifyProof(r.publicSignals, r.proof, processVk);
if (!isValid) {
logError("Error: generated an invalid proof");
throw new Error("Generated an invalid proof");
}

const thisProof = {
Expand Down

0 comments on commit efc1626

Please sign in to comment.