From f202705c52ac505cf0264b7994cd6961fcbae203 Mon Sep 17 00:00:00 2001 From: ctrlc03 <93448202+ctrlc03@users.noreply.github.com> Date: Sat, 20 Jan 2024 14:07:17 +0000 Subject: [PATCH 1/6] feat(ecdh): implement ECDH circom circuit --- packages/circuits/circom/circuits.json | 6 ++ packages/circuits/circom/ecdh.circom | 30 ++++++++++ packages/circuits/tests/ecdh.test.ts | 82 ++++++++++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 packages/circuits/circom/ecdh.circom create mode 100644 packages/circuits/tests/ecdh.test.ts diff --git a/packages/circuits/circom/circuits.json b/packages/circuits/circom/circuits.json index c56bc422d..f850c99c0 100644 --- a/packages/circuits/circom/circuits.json +++ b/packages/circuits/circom/circuits.json @@ -29,5 +29,11 @@ "file": "poseidon-cipher", "template": "PoseidonPerm", "params": [2] + }, + "ecdh": { + "file": "ecdh", + "template": "Ecdh", + "pubs": ["publicKey"], + "params": [] } } diff --git a/packages/circuits/circom/ecdh.circom b/packages/circuits/circom/ecdh.circom new file mode 100644 index 000000000..3af6fb2a1 --- /dev/null +++ b/packages/circuits/circom/ecdh.circom @@ -0,0 +1,30 @@ +pragma circom 2.1.5; + +// circomlib imports +include "./bitify.circom"; +include "./escalarmulany.circom"; + +// ECDH Is a a template which allows to generate a shared secret +// from a private key and a public key +// on the baby jubjub curve +// It is important that the private key is hashed and pruned first +// which can be accomplished using the function +// deriveScalar from @zk-kit/baby-jubjub +template Ecdh() { + // the private key must pass through deriveScalar first + signal input privateKey; + signal input publicKey[2]; + + signal output sharedKey[2]; + + // convert the private key to its bits representation + var out[253]; + out = Num2Bits(253)(privateKey); + + // multiply the public key by the private key + var mulFix[2]; + mulFix = EscalarMulAny(253)(out, publicKey); + + // we can then wire the output to the shared secret signal + sharedKey <== mulFix; +} diff --git a/packages/circuits/tests/ecdh.test.ts b/packages/circuits/tests/ecdh.test.ts new file mode 100644 index 000000000..5da98bfa5 --- /dev/null +++ b/packages/circuits/tests/ecdh.test.ts @@ -0,0 +1,82 @@ +import { WitnessTester } from "circomkit" +import { deriveSecretScalar } from "@zk-kit/eddsa-poseidon" + +import { circomkit, genEcdhSharedKey, genPublicKey, genRandomBabyJubValue } from "./common" + +describe("ECDH Shared Key derivation circuit", () => { + let circuit: WitnessTester<["privateKey", "publicKey"], ["sharedKey"]> + + before(async () => { + circuit = await circomkit.WitnessTester("ecdh", { + file: "ecdh", + template: "Ecdh" + }) + }) + + it("should correctly compute an ECDH shared key", async () => { + const privateKey1 = genRandomBabyJubValue() + const privateKey2 = genRandomBabyJubValue() + const publicKey2 = genPublicKey(privateKey2) + + // generate a shared key between the first private key and the second public key + const ecdhSharedKey = genEcdhSharedKey(privateKey1, publicKey2) + + const circuitInputs = { + privateKey: BigInt(deriveSecretScalar(privateKey1)), + publicKey: publicKey2 + } + + await circuit.expectPass(circuitInputs, { sharedKey: [ecdhSharedKey[0], ecdhSharedKey[1]] }) + }) + + it("should generate the same shared key from the same keypairs", async () => { + const privateKey1 = genRandomBabyJubValue() + const privateKey2 = genRandomBabyJubValue() + const publicKey1 = genPublicKey(privateKey1) + const publicKey2 = genPublicKey(privateKey2) + + // generate a shared key between the first private key and the second public key + const ecdhSharedKey = genEcdhSharedKey(privateKey1, publicKey2) + const ecdhSharedKey2 = genEcdhSharedKey(privateKey2, publicKey1) + + const circuitInputs = { + privateKey: BigInt(deriveSecretScalar(privateKey1)), + publicKey: publicKey2 + } + + const circuitInputs2 = { + privateKey: BigInt(deriveSecretScalar(privateKey2)), + publicKey: publicKey1 + } + + // calculate first time witness and check contraints + const witness = await circuit.calculateWitness(circuitInputs) + await circuit.expectConstraintPass(witness) + + const out = await circuit.readWitnessSignals(witness, ["sharedKey"]) + await circuit.expectPass(circuitInputs, { sharedKey: ecdhSharedKey }) + await circuit.expectPass(circuitInputs2, { sharedKey: out.sharedKey }) + await circuit.expectPass(circuitInputs2, { sharedKey: ecdhSharedKey2 }) + }) + + it("should generate the same ECDH key consistently for the same inputs", async () => { + const privateKey1 = BigInt(deriveSecretScalar(genRandomBabyJubValue())) + const privateKey2 = genRandomBabyJubValue() + const publicKey2 = genPublicKey(privateKey2) + + const circuitInputs = { + privateKey: privateKey1, + publicKey: publicKey2 + } + + // calculate first time witness and check contraints + const witness = await circuit.calculateWitness(circuitInputs) + await circuit.expectConstraintPass(witness) + + // read out + const out = await circuit.readWitnessSignals(witness, ["sharedKey"]) + + // calculate again + await circuit.expectPass(circuitInputs, { sharedKey: out.sharedKey }) + }) +}) From 5da36608ce8e382117f706e4779d510a68217ecd Mon Sep 17 00:00:00 2001 From: Jeeiii Date: Tue, 13 Feb 2024 17:25:56 +0100 Subject: [PATCH 2/6] feat: add safe comparators templates --- packages/circuits/circom/circuits.json | 24 ++++ .../circuits/circom/safe-comparators.circom | 56 +++++++++ .../circuits/tests/safe-comparators.test.ts | 116 ++++++++++++++++++ 3 files changed, 196 insertions(+) create mode 100644 packages/circuits/circom/safe-comparators.circom create mode 100644 packages/circuits/tests/safe-comparators.test.ts diff --git a/packages/circuits/circom/circuits.json b/packages/circuits/circom/circuits.json index f850c99c0..6c0e24fc9 100644 --- a/packages/circuits/circom/circuits.json +++ b/packages/circuits/circom/circuits.json @@ -35,5 +35,29 @@ "template": "Ecdh", "pubs": ["publicKey"], "params": [] + }, + "safe-less-than": { + "file": "safe-comparators", + "template": "SafeLessThan", + "pubs": ["in"], + "params": [252] + }, + "safe-less-eq-than": { + "file": "safe-comparators", + "template": "SafeLessEqThan", + "pubs": ["in"], + "params": [252] + }, + "safe-greater-than": { + "file": "safe-comparators", + "template": "SafeGreaterThan", + "pubs": ["in"], + "params": [252] + }, + "safe-greater-eq-than": { + "file": "safe-comparators", + "template": "SafeGreaterEqThan", + "pubs": ["in"], + "params": [252] } } diff --git a/packages/circuits/circom/safe-comparators.circom b/packages/circuits/circom/safe-comparators.circom new file mode 100644 index 000000000..68227d750 --- /dev/null +++ b/packages/circuits/circom/safe-comparators.circom @@ -0,0 +1,56 @@ +pragma circom 2.0.0; + +include "./bitify.circom"; + +// Template for safely comparing if one input is less than another, +// ensuring inputs are within a specified bit-length. +template SafeLessThan(n) { + // Ensure the bit-length does not exceed 252 bits. + assert(n <= 252); + + signal input in[2]; + signal output out; + + // Convert both inputs to their bit representations to ensure + // they fit within 'n' bits. + var n2b1[n]; + n2b1 = Num2Bits(n)(in[0]); + + var n2b2[n]; + n2b2 = Num2Bits(n)(in[1]); + + // Additional conversion to handle arithmetic operation and capture the comparison result. + var n2b[n+1]; + n2b = Num2Bits(n + 1)(in[0] + (1< { + describe("SafeLessThan", () => { + let circuit: WitnessTester<["in"], ["out"]> + + // Test values + const inValues = [5, 10] // in[0] < in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 5 is less than 10. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("SafeLessThan", { + file: "safe-comparators", + template: "SafeLessThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly compare two numbers", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("SafeLessEqThan", () => { + let circuit: WitnessTester<["in"], ["out"]> + + // Test values + const inValues = [5, 5] // in[0] === in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 5 is equal to 5. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("SafeLessEqThan", { + file: "safe-comparators", + template: "SafeLessEqThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly compare two numbers", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("SafeGreaterThan", () => { + let circuit: WitnessTester<["in"], ["out"]> + + // Test values + const inValues = [10, 5] // in[0] > in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 10 is greater than 5. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("SafeGreaterThan", { + file: "safe-comparators", + template: "SafeGreaterThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly compare two numbers", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("SafeGreaterEqThan", () => { + let circuit: WitnessTester<["in"], ["out"]> + + // Test values + const inValues = [5, 5] // in[0] === in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 5 is equal to 5. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("SafeGreaterEqThan", { + file: "safe-comparators", + template: "SafeGreaterEqThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly compare two numbers", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) +}) From dd0a4537258cf654930cc086a545497b2d48aa14 Mon Sep 17 00:00:00 2001 From: Jeeiii Date: Wed, 14 Feb 2024 16:50:36 +0100 Subject: [PATCH 3/6] feat: add float circuits library --- packages/circuits/circom/circuits.json | 40 ++++ packages/circuits/circom/float.circom | 192 +++++++++++++++++++ packages/circuits/tests/float.test.ts | 252 +++++++++++++++++++++++++ 3 files changed, 484 insertions(+) create mode 100644 packages/circuits/circom/float.circom create mode 100644 packages/circuits/tests/float.test.ts diff --git a/packages/circuits/circom/circuits.json b/packages/circuits/circom/circuits.json index f850c99c0..5880b8b51 100644 --- a/packages/circuits/circom/circuits.json +++ b/packages/circuits/circom/circuits.json @@ -35,5 +35,45 @@ "template": "Ecdh", "pubs": ["publicKey"], "params": [] + }, + "msb": { + "file": "float", + "template": "MSB", + "params": [252] + }, + "shift": { + "file": "float", + "template": "Shift", + "params": [252] + }, + "integer-division": { + "file": "float", + "template": "IntegerDivision", + "params": [252] + }, + "to-float": { + "file": "float", + "template": "ToFloat", + "params": [74] + }, + "division-from-float": { + "file": "float", + "template": "DivisionFromFloat", + "params": [74, 251] + }, + "division-from-normal": { + "file": "float", + "template": "DivisionFromNormal", + "params": [74, 251] + }, + "multiplication-from-float": { + "file": "float", + "template": "MultiplicationFromFloat", + "params": [74, 251] + }, + "multiplication-from-normal": { + "file": "float", + "template": "MultiplicationFromNormal", + "params": [74, 251] } } diff --git a/packages/circuits/circom/float.circom b/packages/circuits/circom/float.circom new file mode 100644 index 000000000..33796a413 --- /dev/null +++ b/packages/circuits/circom/float.circom @@ -0,0 +1,192 @@ +pragma circom 2.1.5; + +include "./bitify.circom"; +include "./comparators.circom"; +include "./mux1.circom"; + +// Template to determine the most significant bit (MSB) of an input number. +template MSB(n) { + signal input in; + signal output out; + + // Convert the number to its bit representation. + var n2b[n]; + n2b = Num2Bits(n)(in); + + // Assign the MSB to the output. + out <== n2b[n-1]; +} + +// Template for bit-shifting a dividend and partial remainder. +template Shift(n) { + signal input divident; // Dividend input. + signal input rem; // Partial remainder input. + + signal output divident1; // Output for the shifted dividend. + signal output rem1; // Output for the updated partial remainder. + + // Determine the MSB of the dividend. + var lmsb; + lmsb = MSB(n)(divident); + + // Shift the dividend. + divident1 <== divident - lmsb * 2 ** (n - 1); + + // Update the partial remainder. + rem1 <== rem * 2 + lmsb; +} + +// Template for performing integer division. +template IntegerDivision(n) { + signal input a; // Dividend. + signal input b; // Divisor. + + signal output c; // Quotient. + + // Ensure inputs are within the valid range. + var lta; + var ltb; + + lta = LessThan(252)([a, 2**n]); + ltb = LessThan(252)([b, 2**n]); + + assert(lta == 1); + assert(ltb == 1); + + // Ensure the divisor 'b' is not zero. + var isz; + + isz = IsZero()(b); + + assert(isz == 0); + + // Prepare variables for division. + var divident = a; + var rem = 0; + + var bits[n]; + + // Loop to perform division through bit-shifting and subtraction. + for (var i = n - 1; i >= 0; i--) { + // Shift 'divident' and 'rem' and determine if 'b' can be subtracted from the new 'rem'. + var divident1; + var rem1; + + (divident1, rem1) = Shift(i + 1)(divident, rem); + + // Determine if 'b' <= 'rem'. + var canSubtract; + + canSubtract = LessEqThan(n)([b, rem1]); + + // Select 1 if 'b' can be subtracted (i.e., 'b' <= 'rem'), else select 0. + var subtractBit; + + subtractBit = Mux1()([0, 1], canSubtract); + + // Subtract 'b' from 'rem' if possible, and set the corresponding bit in 'bits'. + bits[i] = subtractBit; + + rem = rem1 - b * subtractBit; + + // Prepare 'divident' for the next iteration. + divident = divident1; + } + + // Convert the bit array representing the quotient into a number. + c <== Bits2Num(n)(bits); +} + +// Converts an integer to its floating-point representation by multiplying it with 10^W. +template ToFloat(W) { + // Assert W to ensure the result is within the range of 2^252. + assert(W < 75); + + signal input in; + + signal output out; + + // Ensure the input multiplied by 10^W is less than 10^75 to prevent overflow. + var lt; + + lt = LessEqThan(252)([in, 10 ** (75 - W)]); + + assert(lt == 1); + + // Convert the integer to floating-point by multiplying with 10^W. + out <== in * (10**W); +} + +// Performs division on floating-point numbers represented with W decimal digits. +template DivisionFromFloat(W, n) { + // Ensure W is within the valid range for floating-point representation. + assert(W < 75); + // Ensure n, the bit-width of inputs, is within a valid range. + assert(n < 252); + + signal input a; // Numerator. + signal input b; // Denominator. + + signal output c; // Quotient. + + // Ensure the numerator 'a' is within the range of valid floating-point numbers. + var lt; + + lt = LessEqThan(252)([a, 10 ** (75 - W)]); + + assert(lt == 1); + + // Use IntegerDivision for division operation. + c <== IntegerDivision(n)(a * (10 ** W), b); +} + +// Performs division on integers by first converting them to floating-point representation. +template DivisionFromNormal(W, n) { + signal input a; // Numerator. + signal input b; // Denominator. + + signal output c; // Quotient. + + // Convert input to float and perform division. + c <== DivisionFromFloat(W, n)(ToFloat(W)(a), ToFloat(W)(b)); +} + +// Performs multiplication on floating-point numbers and converts the result back to integer form. +template MultiplicationFromFloat(W, n) { + // Ensure W is within the valid range for floating-point representation. + assert(W < 75); + // Ensure n, the bit-width of inputs, is within a valid range. + assert(n < 252); + // Ensure scaling factor is within the range of 'n' bits. + assert(10**W < 2**n); + + signal input a; // Multiplicand. + signal input b; // Multiplier. + + signal output c; // Product. + + // Ensure both inputs 'a' and 'b' are within a valid range for multiplication. + var lta; + var ltb; + + lta = LessEqThan(252)([a, 2 ** 126]); + ltb = LessEqThan(252)([b, 2 ** 126]); + + assert(lta == 1); + assert(ltb == 1); + + // Perform integer division after multiplication to adjust the result back to W decimal digits. + c <== IntegerDivision(n)(a * b, 10 ** W); +} + +// Performs multiplication on integers by first converting them to floating-point representation. +template MultiplicationFromNormal(W, n) { + signal input a; // Multiplicand. + signal input b; // Multiplier. + + signal output c; // Product. + + + // Convert input to float and perform multiplication. + c <== MultiplicationFromFloat(W, n)(ToFloat(W)(a), ToFloat(W)(b)); +} \ No newline at end of file diff --git a/packages/circuits/tests/float.test.ts b/packages/circuits/tests/float.test.ts new file mode 100644 index 000000000..2c9bfe2f0 --- /dev/null +++ b/packages/circuits/tests/float.test.ts @@ -0,0 +1,252 @@ +import { WitnessTester } from "circomkit" +import { circomkit } from "./common" + +describe("float", () => { + describe("MSB", () => { + let circuit: WitnessTester<["in"], ["out"]> + + // Test values + const inValues = [1000] + const expectedOut = 0 + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("MSB", { + file: "float", + template: "MSB", + params: [252] // Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly find the most significant bit", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("Shift", () => { + let circuit: WitnessTester<["divident", "rem"], ["divident1", "rem1"]> + + // Test values + const inValues = { + divident: 10, + rem: 1 + } + + const INPUT = { + divident: inValues.divident, + rem: inValues.rem + } + + const OUTPUT = { + divident1: inValues.divident, + rem1: inValues.rem * 2 + } + + before(async () => { + circuit = await circomkit.WitnessTester("Shift", { + file: "float", + template: "Shift", + params: [252] // Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly bit-shifting a dividend and partial remainder", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("IntegerDivision", () => { + let circuit: WitnessTester<["a", "b"], ["c"]> + + // Test values + const inValues = { + a: 10, + b: 2 + } + const expectedOut = 5 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("IntegerDivision", { + file: "float", + template: "IntegerDivision", + params: [252] // Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly perform the integer division", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("ToFloat", () => { + let circuit: WitnessTester<["in"], ["out"]> + + // Test values + const inValues = 10 + const expectedOut = 1000 // 10.00 + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("ToFloat", { + file: "float", + template: "ToFloat", + params: [2] // Assuming we're working within the range of 2^252. + }) + }) + + it("Should correctly convert to float", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("DivisionFromFloat", () => { + let circuit: WitnessTester<["a", "b"], ["c"]> + + // Test values + const inValues = { + a: 1000, // 10.00 + b: 200 // 2.00 + } + const expectedOut = 500 // 5.00 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("DivisionFromFloat", { + file: "float", + template: "DivisionFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly perform the division from float", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("DivisionFromNormal", () => { + let circuit: WitnessTester<["a", "b"], ["c"]> + + // Test values + const inValues = { + a: 10, + b: 2 + } + const expectedOut = 500 // 5.00 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("DivisionFromNormal", { + file: "float", + template: "DivisionFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly perform the division from normal", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("MultiplicationFromFloat", () => { + let circuit: WitnessTester<["a", "b"], ["c"]> + + // Test values + const inValues = { + a: 1000, // 10.00 + b: 200 // 2.00 + } + const expectedOut = 2000 // 20.00 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("MultiplicationFromFloat", { + file: "float", + template: "MultiplicationFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly perform the multiplication from float", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) + + describe("MultiplicationFromNormal", () => { + let circuit: WitnessTester<["a", "b"], ["c"]> + + // Test values + const inValues = { + a: 10, + b: 2 + } + const expectedOut = 2000 // 20.00 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + before(async () => { + circuit = await circomkit.WitnessTester("MultiplicationFromNormal", { + file: "float", + template: "MultiplicationFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + }) + + it("Should correctly perform the multiplication from normal", async () => { + await circuit.expectPass(INPUT, OUTPUT) + }) + }) +}) From ff46d667cce2a7cf0b6a4e72f3e5ef87868be7ea Mon Sep 17 00:00:00 2001 From: Jeeiii Date: Wed, 14 Feb 2024 17:01:12 +0100 Subject: [PATCH 4/6] test: add more tests to comparators; style-convention for comments --- .../circuits/circom/safe-comparators.circom | 12 +- .../circuits/tests/safe-comparators.test.ts | 256 +++++++++++++++--- 2 files changed, 220 insertions(+), 48 deletions(-) diff --git a/packages/circuits/circom/safe-comparators.circom b/packages/circuits/circom/safe-comparators.circom index 68227d750..03cadc069 100644 --- a/packages/circuits/circom/safe-comparators.circom +++ b/packages/circuits/circom/safe-comparators.circom @@ -38,8 +38,10 @@ template SafeLessEqThan(n) { // Template for safely comparing if one input is greater than another. template SafeGreaterThan(n) { - signal input in[2]; // Two inputs to compare. - signal output out; // Output signal indicating comparison result. + // Two inputs to compare. + signal input in[2]; + // Output signal indicating comparison result. + signal output out; // Invert the inputs for SafeLessThan to check if in[1] is less than in[0]. out <== SafeLessThan(n)([in[1], in[0]]); @@ -47,8 +49,10 @@ template SafeGreaterThan(n) { // Template to check if one input is greater than or equal to another. template SafeGreaterEqThan(n) { - signal input in[2]; // Two inputs to compare. - signal output out; // Output signal indicating comparison result. + // Two inputs to compare. + signal input in[2]; + // Output signal indicating comparison result. + signal output out; // Invert the inputs and adjust for equality in SafeLessThan to // check if in[1] is less than or equal to in[0]. diff --git a/packages/circuits/tests/safe-comparators.test.ts b/packages/circuits/tests/safe-comparators.test.ts index 0d1a9967a..04b9907e6 100644 --- a/packages/circuits/tests/safe-comparators.test.ts +++ b/packages/circuits/tests/safe-comparators.test.ts @@ -5,27 +5,69 @@ describe("safe-comparators", () => { describe("SafeLessThan", () => { let circuit: WitnessTester<["in"], ["out"]> - // Test values - const inValues = [5, 10] // in[0] < in[1], expecting 'out' to be 1. - const expectedOut = 1 // Since 5 is less than 10. + it("Should correctly compare two numbers [x, x]", async () => { + // Test values + const inValues = [5, 5] // in[0] === in[1], expecting 'out' to be 0. + const expectedOut = 0 // Since 5 is equal to 5. - const INPUT = { - in: inValues - } + const INPUT = { + in: inValues + } - const OUTPUT = { - out: expectedOut - } + const OUTPUT = { + out: expectedOut + } - before(async () => { circuit = await circomkit.WitnessTester("SafeLessThan", { file: "safe-comparators", template: "SafeLessThan", params: [252] // Assuming we're working within 252-bit numbers. }) + + await circuit.expectPass(INPUT, OUTPUT) }) - it("Should correctly compare two numbers", async () => { + it("Should correctly compare two numbers [x-1, x]", async () => { + // Test values + const inValues = [5, 6] // in[0] < in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 5 is less than 6. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("SafeLessThan", { + file: "safe-comparators", + template: "SafeLessThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly compare two numbers [x, x-1]", async () => { + // Test values + const inValues = [6, 5] // in[0] > in[1], expecting 'out' to be 0. + const expectedOut = 0 // Since 6 is greater than 5. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("SafeLessThan", { + file: "safe-comparators", + template: "SafeLessThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) @@ -33,27 +75,69 @@ describe("safe-comparators", () => { describe("SafeLessEqThan", () => { let circuit: WitnessTester<["in"], ["out"]> - // Test values - const inValues = [5, 5] // in[0] === in[1], expecting 'out' to be 1. - const expectedOut = 1 // Since 5 is equal to 5. + it("Should correctly compare two numbers [x, x]", async () => { + // Test values + const inValues = [5, 5] // in[0] === in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 5 is equal to 5. - const INPUT = { - in: inValues - } + const INPUT = { + in: inValues + } - const OUTPUT = { - out: expectedOut - } + const OUTPUT = { + out: expectedOut + } - before(async () => { circuit = await circomkit.WitnessTester("SafeLessEqThan", { file: "safe-comparators", template: "SafeLessEqThan", params: [252] // Assuming we're working within 252-bit numbers. }) + + await circuit.expectPass(INPUT, OUTPUT) }) - it("Should correctly compare two numbers", async () => { + it("Should correctly compare two numbers [x-1, x]", async () => { + // Test values + const inValues = [5, 6] // in[0] < in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 5 is less than 6. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("SafeLessEqThan", { + file: "safe-comparators", + template: "SafeLessEqThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly compare two numbers [x, x-1]", async () => { + // Test values + const inValues = [6, 5] // in[0] > in[1], expecting 'out' to be 0. + const expectedOut = 0 // Since 6 is greater than 5. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("SafeLessEqThan", { + file: "safe-comparators", + template: "SafeLessEqThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) @@ -61,27 +145,69 @@ describe("safe-comparators", () => { describe("SafeGreaterThan", () => { let circuit: WitnessTester<["in"], ["out"]> - // Test values - const inValues = [10, 5] // in[0] > in[1], expecting 'out' to be 1. - const expectedOut = 1 // Since 10 is greater than 5. + it("Should correctly compare two numbers [x, x]", async () => { + // Test values + const inValues = [5, 5] // in[0] === in[1], expecting 'out' to be 0. + const expectedOut = 0 // Since 5 is equal to 5. - const INPUT = { - in: inValues - } + const INPUT = { + in: inValues + } - const OUTPUT = { - out: expectedOut - } + const OUTPUT = { + out: expectedOut + } - before(async () => { circuit = await circomkit.WitnessTester("SafeGreaterThan", { file: "safe-comparators", template: "SafeGreaterThan", params: [252] // Assuming we're working within 252-bit numbers. }) + + await circuit.expectPass(INPUT, OUTPUT) }) - it("Should correctly compare two numbers", async () => { + it("Should correctly compare two numbers [x-1, x]", async () => { + // Test values + const inValues = [5, 6] // in[0] < in[1], expecting 'out' to be 0. + const expectedOut = 0 // Since 5 is less than 6. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("SafeGreaterThan", { + file: "safe-comparators", + template: "SafeGreaterThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly compare two numbers [x, x-1]", async () => { + // Test values + const inValues = [6, 5] // in[0] > in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 6 is greater than 5. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("SafeGreaterThan", { + file: "safe-comparators", + template: "SafeGreaterThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) @@ -89,27 +215,69 @@ describe("safe-comparators", () => { describe("SafeGreaterEqThan", () => { let circuit: WitnessTester<["in"], ["out"]> - // Test values - const inValues = [5, 5] // in[0] === in[1], expecting 'out' to be 1. - const expectedOut = 1 // Since 5 is equal to 5. + it("Should correctly compare two numbers [x, x]", async () => { + // Test values + const inValues = [5, 5] // in[0] === in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 5 is equal to 5. - const INPUT = { - in: inValues - } + const INPUT = { + in: inValues + } - const OUTPUT = { - out: expectedOut - } + const OUTPUT = { + out: expectedOut + } - before(async () => { circuit = await circomkit.WitnessTester("SafeGreaterEqThan", { file: "safe-comparators", template: "SafeGreaterEqThan", params: [252] // Assuming we're working within 252-bit numbers. }) + + await circuit.expectPass(INPUT, OUTPUT) }) - it("Should correctly compare two numbers", async () => { + it("Should correctly compare two numbers [x-1, x]", async () => { + // Test values + const inValues = [5, 6] // in[0] < in[1], expecting 'out' to be 0. + const expectedOut = 0 // Since 5 is less than 6. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("SafeGreaterEqThan", { + file: "safe-comparators", + template: "SafeGreaterEqThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly compare two numbers [x, x-1]", async () => { + // Test values + const inValues = [6, 5] // in[0] > in[1], expecting 'out' to be 1. + const expectedOut = 1 // Since 6 is greater than 5. + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("SafeGreaterEqThan", { + file: "safe-comparators", + template: "SafeGreaterEqThan", + params: [252] // Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) From 169b67f90d50f936d5f6a1fd110677a6ab5250e8 Mon Sep 17 00:00:00 2001 From: Jeeiii Date: Wed, 14 Feb 2024 17:43:04 +0100 Subject: [PATCH 5/6] test: add more tests to float math templates --- packages/circuits/circom/float.circom | 95 +++-- packages/circuits/tests/float.test.ts | 592 ++++++++++++++++++++++---- 2 files changed, 553 insertions(+), 134 deletions(-) diff --git a/packages/circuits/circom/float.circom b/packages/circuits/circom/float.circom index 33796a413..6a793260e 100644 --- a/packages/circuits/circom/float.circom +++ b/packages/circuits/circom/float.circom @@ -19,29 +19,35 @@ template MSB(n) { // Template for bit-shifting a dividend and partial remainder. template Shift(n) { - signal input divident; // Dividend input. - signal input rem; // Partial remainder input. - - signal output divident1; // Output for the shifted dividend. - signal output rem1; // Output for the updated partial remainder. + // Dividend. + signal input dividend; + // Remainder. + signal input remainder; + + // Shifted dividend. + signal output outDividend; + // Partial remainder (updated). + signal output outRemainder; // Determine the MSB of the dividend. var lmsb; - lmsb = MSB(n)(divident); + lmsb = MSB(n)(dividend); // Shift the dividend. - divident1 <== divident - lmsb * 2 ** (n - 1); + outDividend <== dividend - lmsb * 2 ** (n - 1); // Update the partial remainder. - rem1 <== rem * 2 + lmsb; + outRemainder <== remainder * 2 + lmsb; } // Template for performing integer division. template IntegerDivision(n) { - signal input a; // Dividend. - signal input b; // Divisor. - - signal output c; // Quotient. + // Dividend. + signal input a; + // Divisor. + signal input b; + // Quotient. + signal output c; // Ensure inputs are within the valid range. var lta; @@ -61,23 +67,23 @@ template IntegerDivision(n) { assert(isz == 0); // Prepare variables for division. - var divident = a; - var rem = 0; + var dividend = a; + var remainder = 0; var bits[n]; // Loop to perform division through bit-shifting and subtraction. for (var i = n - 1; i >= 0; i--) { - // Shift 'divident' and 'rem' and determine if 'b' can be subtracted from the new 'rem'. - var divident1; - var rem1; + // Shift 'dividend' and 'rem' and determine if 'b' can be subtracted from the new 'rem'. + var shiftedDividend; + var shiftedRem; - (divident1, rem1) = Shift(i + 1)(divident, rem); + (shiftedDividend, shiftedRem) = Shift(i + 1)(dividend, remainder); // Determine if 'b' <= 'rem'. var canSubtract; - canSubtract = LessEqThan(n)([b, rem1]); + canSubtract = LessEqThan(n)([b, shiftedRem]); // Select 1 if 'b' can be subtracted (i.e., 'b' <= 'rem'), else select 0. var subtractBit; @@ -87,10 +93,10 @@ template IntegerDivision(n) { // Subtract 'b' from 'rem' if possible, and set the corresponding bit in 'bits'. bits[i] = subtractBit; - rem = rem1 - b * subtractBit; + remainder = shiftedRem - b * subtractBit; - // Prepare 'divident' for the next iteration. - divident = divident1; + // Prepare 'dividend' for the next iteration. + dividend = shiftedDividend; } // Convert the bit array representing the quotient into a number. @@ -123,11 +129,13 @@ template DivisionFromFloat(W, n) { assert(W < 75); // Ensure n, the bit-width of inputs, is within a valid range. assert(n < 252); - - signal input a; // Numerator. - signal input b; // Denominator. - - signal output c; // Quotient. + + // Numerator. + signal input a; + // Denominator. + signal input b; + // Quotient. + signal output c; // Ensure the numerator 'a' is within the range of valid floating-point numbers. var lt; @@ -142,10 +150,12 @@ template DivisionFromFloat(W, n) { // Performs division on integers by first converting them to floating-point representation. template DivisionFromNormal(W, n) { - signal input a; // Numerator. - signal input b; // Denominator. - - signal output c; // Quotient. + // Numerator. + signal input a; + // Denominator. + signal input b; + // Quotient. + signal output c; // Convert input to float and perform division. c <== DivisionFromFloat(W, n)(ToFloat(W)(a), ToFloat(W)(b)); @@ -159,11 +169,13 @@ template MultiplicationFromFloat(W, n) { assert(n < 252); // Ensure scaling factor is within the range of 'n' bits. assert(10**W < 2**n); - - signal input a; // Multiplicand. - signal input b; // Multiplier. - - signal output c; // Product. + + // Multiplicand. + signal input a; + // Multiplier. + signal input b; + // Product. + signal output c; // Ensure both inputs 'a' and 'b' are within a valid range for multiplication. var lta; @@ -181,11 +193,12 @@ template MultiplicationFromFloat(W, n) { // Performs multiplication on integers by first converting them to floating-point representation. template MultiplicationFromNormal(W, n) { - signal input a; // Multiplicand. - signal input b; // Multiplier. - - signal output c; // Product. - + // Multiplicand. + signal input a; + // Multiplier. + signal input b; + // Product. + signal output c; // Convert input to float and perform multiplication. c <== MultiplicationFromFloat(W, n)(ToFloat(W)(a), ToFloat(W)(b)); diff --git a/packages/circuits/tests/float.test.ts b/packages/circuits/tests/float.test.ts index 2c9bfe2f0..370833515 100644 --- a/packages/circuits/tests/float.test.ts +++ b/packages/circuits/tests/float.test.ts @@ -5,48 +5,63 @@ describe("float", () => { describe("MSB", () => { let circuit: WitnessTester<["in"], ["out"]> - // Test values - const inValues = [1000] - const expectedOut = 0 - - const INPUT = { - in: inValues - } + it("Should throw when the number is negative", async () => { + // Test values + const inValues = [-1] - const OUTPUT = { - out: expectedOut - } + const INPUT = { + in: inValues + } - before(async () => { circuit = await circomkit.WitnessTester("MSB", { file: "float", template: "MSB", params: [252] // Assuming we're working within 252-bit numbers. }) + + await circuit.expectFail(INPUT) }) it("Should correctly find the most significant bit", async () => { + // Test values + const inValues = [1] + const expectedOut = 0 + + const INPUT = { + in: inValues + } + + const OUTPUT = { + out: expectedOut + } + + circuit = await circomkit.WitnessTester("MSB", { + file: "float", + template: "MSB", + params: [252] // Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) describe("Shift", () => { - let circuit: WitnessTester<["divident", "rem"], ["divident1", "rem1"]> + let circuit: WitnessTester<["dividend", "remainder"], ["outDividend", "outRemainder"]> // Test values const inValues = { - divident: 10, - rem: 1 + dividend: 10, + remainder: 1 } const INPUT = { - divident: inValues.divident, - rem: inValues.rem + dividend: inValues.dividend, + remainder: inValues.remainder } const OUTPUT = { - divident1: inValues.divident, - rem1: inValues.rem * 2 + outDividend: inValues.dividend, + outRemainder: inValues.remainder * 2 } before(async () => { @@ -65,31 +80,96 @@ describe("float", () => { describe("IntegerDivision", () => { let circuit: WitnessTester<["a", "b"], ["c"]> - // Test values - const inValues = { - a: 10, - b: 2 - } - const expectedOut = 5 + it("Should throw when trying to perform division per zero [x, 0]", async () => { + // Test values + const inValues = { + a: 10, + b: 0 + } - const INPUT = { - a: inValues.a, - b: inValues.b - } + const INPUT = { + a: inValues.a, + b: inValues.b + } - const OUTPUT = { - c: expectedOut - } + circuit = await circomkit.WitnessTester("IntegerDivision", { + file: "float", + template: "IntegerDivision", + params: [252] // Assuming we're working within 252-bit numbers. + }) + + await circuit.expectFail(INPUT) + }) + + it("Should throw when trying to perform division per negative number [x, -x]", async () => { + // Test values + const inValues = { + a: 10, + b: -10 + } + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + circuit = await circomkit.WitnessTester("IntegerDivision", { + file: "float", + template: "IntegerDivision", + params: [252] // Assuming we're working within 252-bit numbers. + }) + + await circuit.expectFail(INPUT) + }) + + it("Should correctly perform the integer division [0, x]", async () => { + // Test values + const inValues = { + a: 0, + b: 10 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } - before(async () => { circuit = await circomkit.WitnessTester("IntegerDivision", { file: "float", template: "IntegerDivision", params: [252] // Assuming we're working within 252-bit numbers. }) + + await circuit.expectPass(INPUT, OUTPUT) }) - it("Should correctly perform the integer division", async () => { + it("Should correctly perform the integer division [x, y]", async () => { + // Test values + const inValues = { + a: 10, + b: 2 + } + const expectedOut = 5 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("IntegerDivision", { + file: "float", + template: "IntegerDivision", + params: [252] // Assuming we're working within 252-bit numbers. + }) await circuit.expectPass(INPUT, OUTPUT) }) }) @@ -125,31 +205,97 @@ describe("float", () => { describe("DivisionFromFloat", () => { let circuit: WitnessTester<["a", "b"], ["c"]> - // Test values - const inValues = { - a: 1000, // 10.00 - b: 200 // 2.00 - } - const expectedOut = 500 // 5.00 + it("Should throw when trying to perform division per zero [x, 0]", async () => { + // Test values + const inValues = { + a: 1000, + b: 0 + } - const INPUT = { - a: inValues.a, - b: inValues.b - } + const INPUT = { + a: inValues.a, + b: inValues.b + } - const OUTPUT = { - c: expectedOut - } + circuit = await circomkit.WitnessTester("DivisionFromFloat", { + file: "float", + template: "DivisionFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectFail(INPUT) + }) + + it("Should throw when trying to perform division per negative number [x, -x]", async () => { + // Test values + const inValues = { + a: 1000, + b: -1000 + } + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + circuit = await circomkit.WitnessTester("DivisionFromFloat", { + file: "float", + template: "DivisionFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectFail(INPUT) + }) + + it("Should correctly perform the integer division [0, x]", async () => { + // Test values + const inValues = { + a: 0, + b: 1000 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } - before(async () => { circuit = await circomkit.WitnessTester("DivisionFromFloat", { file: "float", template: "DivisionFromFloat", params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. }) + + await circuit.expectPass(INPUT, OUTPUT) }) - it("Should correctly perform the division from float", async () => { + it("Should correctly perform the integer division [x, y]", async () => { + // Test values + const inValues = { + a: 1000, + b: 200 + } + const expectedOut = 500 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("DivisionFromFloat", { + file: "float", + template: "DivisionFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) @@ -157,31 +303,97 @@ describe("float", () => { describe("DivisionFromNormal", () => { let circuit: WitnessTester<["a", "b"], ["c"]> - // Test values - const inValues = { - a: 10, - b: 2 - } - const expectedOut = 500 // 5.00 + it("Should throw when trying to perform division per zero [x, 0]", async () => { + // Test values + const inValues = { + a: 10, + b: 0 + } - const INPUT = { - a: inValues.a, - b: inValues.b - } + const INPUT = { + a: inValues.a, + b: inValues.b + } - const OUTPUT = { - c: expectedOut - } + circuit = await circomkit.WitnessTester("DivisionFromNormal", { + file: "float", + template: "DivisionFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectFail(INPUT) + }) + + it("Should throw when trying to perform division per negative number [x, -x]", async () => { + // Test values + const inValues = { + a: 10, + b: -10 + } + + const INPUT = { + a: inValues.a, + b: inValues.b + } - before(async () => { circuit = await circomkit.WitnessTester("DivisionFromNormal", { file: "float", template: "DivisionFromNormal", params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. }) + + await circuit.expectFail(INPUT) }) - it("Should correctly perform the division from normal", async () => { + it("Should correctly perform the integer division [0, x]", async () => { + // Test values + const inValues = { + a: 0, + b: 10 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("DivisionFromNormal", { + file: "float", + template: "DivisionFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly perform the integer division [x, y]", async () => { + // Test values + const inValues = { + a: 10, + b: 2 + } + const expectedOut = 500 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("DivisionFromNormal", { + file: "float", + template: "DivisionFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) @@ -189,31 +401,128 @@ describe("float", () => { describe("MultiplicationFromFloat", () => { let circuit: WitnessTester<["a", "b"], ["c"]> - // Test values - const inValues = { - a: 1000, // 10.00 - b: 200 // 2.00 - } - const expectedOut = 2000 // 20.00 + it("Should throw when trying to perform multiplication per negative number [-x, x]", async () => { + // Test values + const inValues = { + a: -100, + b: 100 + } - const INPUT = { - a: inValues.a, - b: inValues.b - } + const INPUT = { + a: inValues.a, + b: inValues.b + } - const OUTPUT = { - c: expectedOut - } + circuit = await circomkit.WitnessTester("MultiplicationFromFloat", { + file: "float", + template: "MultiplicationFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectFail(INPUT) + }) + + it("Should correctly perform the multiplication from float [0, 0]", async () => { + // Test values + const inValues = { + a: 0, + b: 0 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("MultiplicationFromFloat", { + file: "float", + template: "MultiplicationFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly perform the multiplication from float [x, 0]", async () => { + // Test values + const inValues = { + a: 100, + b: 0 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("MultiplicationFromFloat", { + file: "float", + template: "MultiplicationFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly perform the multiplication from float [0, x]", async () => { + // Test values + const inValues = { + a: 0, + b: 100 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } - before(async () => { circuit = await circomkit.WitnessTester("MultiplicationFromFloat", { file: "float", template: "MultiplicationFromFloat", params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. }) + + await circuit.expectPass(INPUT, OUTPUT) }) - it("Should correctly perform the multiplication from float", async () => { + it("Should correctly perform the multiplication from float [x, y]", async () => { + // Test values + const inValues = { + a: 100, + b: 200 + } + const expectedOut = 200 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("MultiplicationFromFloat", { + file: "float", + template: "MultiplicationFromFloat", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) @@ -221,31 +530,128 @@ describe("float", () => { describe("MultiplicationFromNormal", () => { let circuit: WitnessTester<["a", "b"], ["c"]> - // Test values - const inValues = { - a: 10, - b: 2 - } - const expectedOut = 2000 // 20.00 + it("Should throw when trying to perform multiplication per negative number [-x, x]", async () => { + // Test values + const inValues = { + a: -1, + b: 1 + } - const INPUT = { - a: inValues.a, - b: inValues.b - } + const INPUT = { + a: inValues.a, + b: inValues.b + } - const OUTPUT = { - c: expectedOut - } + circuit = await circomkit.WitnessTester("MultiplicationFromNormal", { + file: "float", + template: "MultiplicationFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectFail(INPUT) + }) + + it("Should correctly perform the multiplication from float [0, 0]", async () => { + // Test values + const inValues = { + a: 0, + b: 0 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("MultiplicationFromNormal", { + file: "float", + template: "MultiplicationFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly perform the multiplication from float [x, 0]", async () => { + // Test values + const inValues = { + a: 1, + b: 0 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("MultiplicationFromNormal", { + file: "float", + template: "MultiplicationFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + + await circuit.expectPass(INPUT, OUTPUT) + }) + + it("Should correctly perform the multiplication from float [0, x]", async () => { + // Test values + const inValues = { + a: 0, + b: 1 + } + const expectedOut = 0 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } - before(async () => { circuit = await circomkit.WitnessTester("MultiplicationFromNormal", { file: "float", template: "MultiplicationFromNormal", params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. }) + + await circuit.expectPass(INPUT, OUTPUT) }) - it("Should correctly perform the multiplication from normal", async () => { + it("Should correctly perform the multiplication from float [x, y]", async () => { + // Test values + const inValues = { + a: 1, + b: 2 + } + const expectedOut = 200 + + const INPUT = { + a: inValues.a, + b: inValues.b + } + + const OUTPUT = { + c: expectedOut + } + + circuit = await circomkit.WitnessTester("MultiplicationFromNormal", { + file: "float", + template: "MultiplicationFromNormal", + params: [2, 251] // W decimal digits, N Assuming we're working within 252-bit numbers. + }) + await circuit.expectPass(INPUT, OUTPUT) }) }) From 8027bf7d81f4b5c207ef025d83e72195776f4cc9 Mon Sep 17 00:00:00 2001 From: Jeeiii Date: Tue, 5 Mar 2024 10:49:57 +0100 Subject: [PATCH 6/6] feat: add unpack element circuit template and tests --- packages/circuits/circom/circuits.json | 6 +++ .../circuits/circom/unpack-element.circom | 31 ++++++++++++ .../circuits/tests/unpack-element.test.ts | 47 +++++++++++++++++++ 3 files changed, 84 insertions(+) create mode 100644 packages/circuits/circom/unpack-element.circom create mode 100644 packages/circuits/tests/unpack-element.test.ts diff --git a/packages/circuits/circom/circuits.json b/packages/circuits/circom/circuits.json index 6c0e24fc9..bffc07348 100644 --- a/packages/circuits/circom/circuits.json +++ b/packages/circuits/circom/circuits.json @@ -59,5 +59,11 @@ "template": "SafeGreaterEqThan", "pubs": ["in"], "params": [252] + }, + "unpack-element": { + "file": "unpack-element", + "template": "UnpackElement", + "params": [4], + "pubs": ["in"] } } diff --git a/packages/circuits/circom/unpack-element.circom b/packages/circuits/circom/unpack-element.circom new file mode 100644 index 000000000..6942966b5 --- /dev/null +++ b/packages/circuits/circom/unpack-element.circom @@ -0,0 +1,31 @@ +pragma circom 2.1.5; + +include "./bitify.circom"; + +// Template to convert a single field element into multiple 50-bit elements. +template UnpackElement(n) { + // A field element. + signal input in; + // An array of n elements, each 50 bits long. + signal output out[n]; + + // Ensure the number of outputs is more than 1 and up to 5. + assert(n > 1 && n <= 5); + + // Convert the input signal to its bit representation. + var bits[254]; + bits = Num2Bits_strict()(in); + + for (var i = 0; i < n; i++) { + var tempBits[50]; + + // Select and assign the appropriate 50-bit segment of the input's bit representation. + for (var j = 0; j < 50; j++) { + // Calculate the bit's index, considering the output element's position. + tempBits[j] = bits[((n - i - 1) * 50) + j]; + } + + // Assign the numerical value of the 50-bit segment to the output signal. + out[i] <== Bits2Num(50)(tempBits); + } +} diff --git a/packages/circuits/tests/unpack-element.test.ts b/packages/circuits/tests/unpack-element.test.ts new file mode 100644 index 000000000..c92d232c8 --- /dev/null +++ b/packages/circuits/tests/unpack-element.test.ts @@ -0,0 +1,47 @@ +import { WitnessTester } from "circomkit" +import { circomkit } from "./common" + +describe("unpack-element", () => { + let circuit: WitnessTester<["in"], ["out"]> + + it("Should unpack a field element with 3 / 4 / 5 packed values correctly", async () => { + for (let n = 3; n <= 5; n += 1) { + const elements: string[] = [] + + circuit = await circomkit.WitnessTester("unpack-element", { + file: "unpack-element", + template: "UnpackElement", + params: [n] + }) + + for (let i = 0; i < n; i += 1) { + let e = (BigInt(i) % BigInt(2 ** 50)).toString(2) + + while (e.length < 50) { + e = `0${e}` + } + + elements.push(e) + } + + const INPUT = { + in: BigInt(`0b${elements.join("")}`) + } + + const witness = await circuit.calculateWitness(INPUT) + await circuit.expectConstraintPass(witness) + + const outputs = [ + [BigInt(0), BigInt(1), BigInt(2)], + [BigInt(0), BigInt(1), BigInt(2), BigInt(3)], + [BigInt(0), BigInt(1), BigInt(2), BigInt(3), BigInt(4)] + ] + + const OUTPUT = { + out: outputs[n - 3] + } + + await circuit.expectPass(INPUT, OUTPUT) + } + }) +})