Skip to content

Commit

Permalink
BLS: Refactor mask-bit settings, improve encoding resiliency
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmillr committed Nov 10, 2023
1 parent fb02e93 commit 2f1460a
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 67 deletions.
149 changes: 83 additions & 66 deletions src/bls12-381.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import {
numberToBytesBE,
bytesToNumberBE,
bitLen,
bitSet,
bitGet,
Hex,
bitMask,
Expand Down Expand Up @@ -1019,22 +1018,40 @@ const htfDefaults = Object.freeze({

// Encoding utils
// Point on G1 curve: (x, y)
const C_BIT_POS = Fp.BITS; // C_bit, compression bit for serialization flag
const I_BIT_POS = Fp.BITS + 1; // I_bit, point-at-infinity bit for serialization flag
const S_BIT_POS = Fp.BITS + 2; // S_bit, sign bit for serialization flag

// Compressed point of infinity
const COMPRESSED_ZERO = Fp.toBytes(bitSet(bitSet(_0n, I_BIT_POS, true), S_BIT_POS, true)); // set compressed & point-at-infinity bits
const COMPRESSED_ZERO = setMask(Fp.toBytes(_0n), { infinity: true, compressed: true }); // set compressed & point-at-infinity bits

function parseMask(bytes: Uint8Array) {
// Copy, so we can remove mask data. It will be removed also later, when Fp.create will call modulo.
bytes = bytes.slice();
const mask = bytes[0] & 0b1110_0000;
const compressed = !!((mask >> 7) & 1); // compression bit (0b1000_0000)
const infinity = !!((mask >> 6) & 1); // point at infinity bit (0b0100_0000)
const sort = !!((mask >> 5) & 1); // sort bit (0b0010_0000)
bytes[0] &= 0b0001_1111; // clear mask (zero first 3 bits)
return { compressed, infinity, sort, value: bytes };
}

function setMask(
bytes: Uint8Array,
mask: { compressed?: boolean; infinity?: boolean; sort?: boolean }
) {
if (bytes[0] & 0b1110_0000) throw new Error('setMask: non-empty mask');
if (mask.compressed) bytes[0] |= 0b1000_0000;
if (mask.infinity) bytes[0] |= 0b0100_0000;
if (mask.sort) bytes[0] |= 0b0010_0000;
return bytes;
}

function signatureG1ToRawBytes(point: ProjPointType<Fp>) {
point.assertValidity();
const isZero = point.equals(bls12_381.G1.ProjectivePoint.ZERO);
const { x, y } = point.toAffine();
if (isZero) return COMPRESSED_ZERO.slice();
const P = Fp.ORDER;
let num;
num = bitSet(x, C_BIT_POS, Boolean((y * _2n) / P)); // set aflag
num = bitSet(num, S_BIT_POS, true);
return numberToBytesBE(num, Fp.BYTES);
const sort = Boolean((y * _2n) / P);
return setMask(numberToBytesBE(x, Fp.BYTES), { compressed: true, sort });
}

function signatureG2ToRawBytes(point: ProjPointType<Fp2>) {
Expand All @@ -1047,10 +1064,12 @@ function signatureG2ToRawBytes(point: ProjPointType<Fp2>) {
const { re: x0, im: x1 } = Fp2.reim(x);
const { re: y0, im: y1 } = Fp2.reim(y);
const tmp = y1 > _0n ? y1 * _2n : y0 * _2n;
const aflag1 = Boolean((tmp / Fp.ORDER) & _1n);
const z1 = bitSet(bitSet(x1, 381, aflag1), S_BIT_POS, true);
const sort = Boolean((tmp / Fp.ORDER) & _1n);
const z2 = x0;
return concatB(numberToBytesBE(z1, len), numberToBytesBE(z2, len));
return concatB(
setMask(numberToBytesBE(x1, len), { sort, compressed: true }),
numberToBytesBE(z2, len)
);
}

// To verify curve parameters, see pairing-friendly-curves spec:
Expand Down Expand Up @@ -1131,26 +1150,30 @@ export const bls12_381: CurveFn<Fp, Fp2, Fp6, Fp12> = bls({
return isogenyMapG1(x, y);
},
fromBytes: (bytes: Uint8Array): AffinePoint<Fp> => {
bytes = bytes.slice();
if (bytes.length === 48) {
const { compressed, infinity, sort, value } = parseMask(bytes);
if (value.length === 48 && compressed) {
// TODO: Fp.bytes
const P = Fp.ORDER;
const compressedValue = bytesToNumberBE(bytes);
const bflag = bitGet(compressedValue, I_BIT_POS);
const compressedValue = bytesToNumberBE(value);
// Zero
if (bflag === _1n) return { x: _0n, y: _0n };
const x = Fp.create(compressedValue & Fp.MASK);
if (infinity) {
if (x !== _0n) throw new Error('G1: non-empty compressed point at infinity');
return { x: _0n, y: _0n };
}
const right = Fp.add(Fp.pow(x, _3n), Fp.create(bls12_381.params.G1b)); // y² = x³ + b
let y = Fp.sqrt(right);
if (!y) throw new Error('Invalid compressed G1 point');
const aflag = bitGet(compressedValue, C_BIT_POS);
if ((y * _2n) / P !== aflag) y = Fp.neg(y);
if ((y * _2n) / P !== BigInt(sort)) y = Fp.neg(y);
return { x: Fp.create(x), y: Fp.create(y) };
} else if (bytes.length === 96) {
} else if (value.length === 96 && !compressed) {
// Check if the infinity flag is set
if ((bytes[0] & (1 << 6)) !== 0) return bls12_381.G1.ProjectivePoint.ZERO.toAffine();
const x = bytesToNumberBE(bytes.subarray(0, Fp.BYTES));
const y = bytesToNumberBE(bytes.subarray(Fp.BYTES));
const x = bytesToNumberBE(value.subarray(0, Fp.BYTES));
const y = bytesToNumberBE(value.subarray(Fp.BYTES));
if (infinity) {
if (x !== _0n || y !== _0n) throw new Error('G1: non-empty point at infinity');
return bls12_381.G1.ProjectivePoint.ZERO.toAffine();
}
return { x: Fp.create(x), y: Fp.create(y) };
} else {
throw new Error('Invalid point G1, expected 48/96 bytes');
Expand All @@ -1162,10 +1185,8 @@ export const bls12_381: CurveFn<Fp, Fp2, Fp6, Fp12> = bls({
if (isCompressed) {
if (isZero) return COMPRESSED_ZERO.slice();
const P = Fp.ORDER;
let num;
num = bitSet(x, C_BIT_POS, Boolean((y * _2n) / P)); // set aflag
num = bitSet(num, S_BIT_POS, true);
return numberToBytesBE(num, Fp.BYTES);
const sort = Boolean((y * _2n) / P);
return setMask(numberToBytesBE(x, Fp.BYTES), { compressed: true, sort });
} else {
if (isZero) {
// 2x PUBLIC_KEY_LENGTH
Expand All @@ -1178,18 +1199,16 @@ export const bls12_381: CurveFn<Fp, Fp2, Fp6, Fp12> = bls({
},
ShortSignature: {
fromHex(hex: Hex): ProjPointType<Fp> {
const bytes = ensureBytes('signatureHex', hex, 48);

const { infinity, sort, value } = parseMask(ensureBytes('signatureHex', hex, 48));
const P = Fp.ORDER;
const compressedValue = bytesToNumberBE(bytes);
const bflag = bitGet(compressedValue, I_BIT_POS);
const compressedValue = bytesToNumberBE(value);
// Zero
if (bflag === _1n) return bls12_381.G1.ProjectivePoint.ZERO;
if (infinity) return bls12_381.G1.ProjectivePoint.ZERO;
const x = Fp.create(compressedValue & Fp.MASK);
const right = Fp.add(Fp.pow(x, _3n), Fp.create(bls12_381.params.G1b)); // y² = x³ + b
let y = Fp.sqrt(right);
if (!y) throw new Error('Invalid compressed G1 point');
const aflag = bitGet(compressedValue, C_BIT_POS);
const aflag = BigInt(sort);
if ((y * _2n) / P !== aflag) y = Fp.neg(y);
const point = bls12_381.G1.ProjectivePoint.fromAffine({ x, y });
point.assertValidity();
Expand Down Expand Up @@ -1273,45 +1292,45 @@ export const bls12_381: CurveFn<Fp, Fp2, Fp6, Fp12> = bls({
return Q; // [x²-x-1]P + [x-1]Ψ(P) + Ψ²(2P)
},
fromBytes: (bytes: Uint8Array): AffinePoint<Fp2> => {
bytes = bytes.slice();
const m_byte = bytes[0] & 0xe0;
if (m_byte === 0x20 || m_byte === 0x60 || m_byte === 0xe0) {
throw new Error('Invalid encoding flag: ' + m_byte);
const { compressed, infinity, sort, value } = parseMask(bytes);
if (
(!compressed && !infinity && sort) || // 00100000
(!compressed && infinity && sort) || // 01100000
(sort && infinity && compressed) // 11100000
) {
throw new Error('Invalid encoding flag: ' + (bytes[0] & 0b1110_0000));
}
const bitC = m_byte & 0x80; // compression bit
const bitI = m_byte & 0x40; // point at infinity bit
const bitS = m_byte & 0x20; // sign bit
const L = Fp.BYTES;
const slc = (b: Uint8Array, from: number, to?: number) => bytesToNumberBE(b.slice(from, to));
if (bytes.length === 96 && bitC) {
if (value.length === 96 && compressed) {
const b = bls12_381.params.G2b;
const P = Fp.ORDER;

bytes[0] = bytes[0] & 0x1f; // clear flags
if (bitI) {
if (infinity) {
// check that all bytes are 0
if (bytes.reduce((p, c) => (p !== 0 ? c + 1 : c), 0) > 0) {
if (value.reduce((p, c) => (p !== 0 ? c + 1 : c), 0) > 0) {
throw new Error('Invalid compressed G2 point');
}
return { x: Fp2.ZERO, y: Fp2.ZERO };
}
const x_1 = slc(bytes, 0, L);
const x_0 = slc(bytes, L, 2 * L);
const x_1 = slc(value, 0, L);
const x_0 = slc(value, L, 2 * L);
const x = Fp2.create({ c0: Fp.create(x_0), c1: Fp.create(x_1) });
const right = Fp2.add(Fp2.pow(x, _3n), b); // y² = x³ + 4 * (u+1) = x³ + b
let y = Fp2.sqrt(right);
const Y_bit = y.c1 === _0n ? (y.c0 * _2n) / P : (y.c1 * _2n) / P ? _1n : _0n;
y = bitS > 0 && Y_bit > 0 ? y : Fp2.neg(y);
y = sort && Y_bit > 0 ? y : Fp2.neg(y);
return { x, y };
} else if (bytes.length === 192 && !bitC) {
// Check if the infinity flag is set
if ((bytes[0] & (1 << 6)) !== 0) {
} else if (value.length === 192 && !compressed) {
if (infinity) {
if (value.reduce((p, c) => (p !== 0 ? c + 1 : c), 0) > 0) {
throw new Error('Invalid uncompressed G2 point');
}
return { x: Fp2.ZERO, y: Fp2.ZERO };
}
const x1 = slc(bytes, 0, L);
const x0 = slc(bytes, L, 2 * L);
const y1 = slc(bytes, 2 * L, 3 * L);
const y0 = slc(bytes, 3 * L, 4 * L);
const x1 = slc(value, 0, L);
const x0 = slc(value, L, 2 * L);
const y1 = slc(value, 2 * L, 3 * L);
const y0 = slc(value, 3 * L, 4 * L);
return { x: Fp2.fromBigTuple([x0, x1]), y: Fp2.fromBigTuple([y0, y1]) };
} else {
throw new Error('Invalid point G2, expected 96/192 bytes');
Expand All @@ -1324,10 +1343,10 @@ export const bls12_381: CurveFn<Fp, Fp2, Fp6, Fp12> = bls({
if (isCompressed) {
if (isZero) return concatB(COMPRESSED_ZERO, numberToBytesBE(_0n, len));
const flag = Boolean(y.c1 === _0n ? (y.c0 * _2n) / P : (y.c1 * _2n) / P);
// set compressed & sign bits (looks like different offsets than for G1/Fp?)
let x_1 = bitSet(x.c1, C_BIT_POS, flag);
x_1 = bitSet(x_1, S_BIT_POS, true);
return concatB(numberToBytesBE(x_1, len), numberToBytesBE(x.c0, len));
return concatB(
setMask(numberToBytesBE(x.c1, len), { compressed: true, sort: flag }),
numberToBytesBE(x.c0, len)
);
} else {
if (isZero) return concatB(new Uint8Array([0x40]), new Uint8Array(4 * len - 1)); // bytes[0] |= 1 << 6;
const { re: x0, im: x1 } = Fp2.reim(x);
Expand All @@ -1343,17 +1362,15 @@ export const bls12_381: CurveFn<Fp, Fp2, Fp6, Fp12> = bls({
Signature: {
// TODO: Optimize, it's very slow because of sqrt.
fromHex(hex: Hex): ProjPointType<Fp2> {
hex = ensureBytes('signatureHex', hex);
const { infinity, sort, value } = parseMask(ensureBytes('signatureHex', hex));
const P = Fp.ORDER;
const half = hex.length / 2;
if (half !== 48 && half !== 96)
throw new Error('Invalid compressed signature length, must be 96 or 192');
const z1 = bytesToNumberBE(hex.slice(0, half));
const z2 = bytesToNumberBE(hex.slice(half));
const z1 = bytesToNumberBE(value.slice(0, half));
const z2 = bytesToNumberBE(value.slice(half));
// Indicates the infinity point
const bflag1 = bitGet(z1, I_BIT_POS);
if (bflag1 === _1n) return bls12_381.G2.ProjectivePoint.ZERO;

if (infinity) return bls12_381.G2.ProjectivePoint.ZERO;
const x1 = Fp.create(z1 & Fp.MASK);
const x2 = Fp.create(z2);
const x = Fp2.create({ c0: x2, c1: x1 });
Expand All @@ -1365,7 +1382,7 @@ export const bls12_381: CurveFn<Fp, Fp2, Fp6, Fp12> = bls({
// Choose the y whose leftmost bit of the imaginary part is equal to the a_flag1
// If y1 happens to be zero, then use the bit of y0
const { re: y0, im: y1 } = Fp2.reim(y);
const aflag1 = bitGet(z1, 381);
const aflag1 = BigInt(sort);
const isGreater = y1 > _0n && (y1 * _2n) / P !== aflag1;
const isZero = y1 === _0n && (y0 * _2n) / P !== aflag1;
if (isGreater || isZero) y = Fp2.neg(y);
Expand Down
35 changes: 34 additions & 1 deletion test/bls12-381.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import { describe, should } from 'micro-should';
import { wNAF } from '../esm/abstract/curve.js';
import { bytesToHex, utf8ToBytes } from '../esm/abstract/utils.js';
import { hash_to_field } from '../esm/abstract/hash-to-curve.js';
import { bls12_381 as bls } from '../esm/bls12-381.js';
import { bls12_381 as bls, bls12_381 } from '../esm/bls12-381.js';

import * as utils from '../esm/abstract/utils.js';

import zkVectors from './bls12-381/zkcrypto/converted.json' assert { type: 'json' };
import pairingVectors from './bls12-381/go_pairing_vectors/pairing.json' assert { type: 'json' };
Expand Down Expand Up @@ -1415,6 +1417,37 @@ describe('bls12-381 deterministic', () => {
}
}
});
should(`zkcrypt/G1 & G2 encoding edge cases`, () => {
const Fp = bls12_381.fields.Fp;
const S_BIT_POS = Fp.BITS; // C_bit, compression bit for serialization flag
const I_BIT_POS = Fp.BITS + 1; // I_bit, point-at-infinity bit for serialization flag
const C_BIT_POS = Fp.BITS + 2; // S_bit, sort bit for serialization flag
const VECTORS = [
{ pos: C_BIT_POS, shift: 7 }, // compression_flag_set = Choice::from((bytes[0] >> 7) & 1);
{ pos: I_BIT_POS, shift: 6 }, // infinity_flag_set = Choice::from((bytes[0] >> 6) & 1)
{ pos: S_BIT_POS, shift: 5 }, // sort_flag_set = Choice::from((bytes[0] >> 5) & 1)
];
for (const { pos, shift } of VECTORS) {
const d = utils.numberToBytesBE(utils.bitSet(0n, pos, Boolean(true)), Fp.BYTES);
deepStrictEqual((d[0] >> shift) & 1, 1, `${pos}`);
}
const baseC = G1Point.BASE.toRawBytes();
deepStrictEqual(baseC.length, 48);
const baseU = G1Point.BASE.toRawBytes(false);
deepStrictEqual(baseU.length, 96);
const compressedBit = baseU.slice();
compressedBit[0] |= 0b1000_0000; // add compression bit
throws(() => G1Point.fromHex(compressedBit), 'compressed bit'); // uncompressed point with compressed length
const uncompressedBit = baseC.slice();
uncompressedBit[0] &= 0b0111_1111; // remove compression bit
throws(() => G1Point.fromHex(uncompressedBit), 'uncompressed bit');
const infinityUncompressed = baseU.slice();
infinityUncompressed[0] |= 0b0100_0000;
throws(() => G1Point.fromHex(compressedBit), 'infinity uncompressed');
const infinityCompressed = baseC.slice();
infinityCompressed[0] |= 0b0100_0000;
throws(() => G1Point.fromHex(compressedBit), 'infinity compressed');
});
});

// ESM is broken.
Expand Down

0 comments on commit 2f1460a

Please sign in to comment.